Skip to content

Commit

Permalink
amended some tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwen-h committed Sep 12, 2023
1 parent 8ffb954 commit 95ec0f3
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 6 deletions.
2 changes: 1 addition & 1 deletion pxtextmining/factories/factory_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def create_sklearn_pipeline_sentiment(
cache_size=1000,
),
)
params["svc__C"] = stats.uniform(0.1, 20)
params["svc__C"] = [1, 5, 10, 15, 20]
params["svc__kernel"] = [
"linear",
"rbf",
Expand Down
14 changes: 10 additions & 4 deletions tests/test_factory_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ def test_create_sklearn_pipeline(model_type, tokenizer, additional_features):

@pytest.mark.parametrize("target", ["sentiment", None])
@patch("pxtextmining.factories.factory_pipeline.RandomizedSearchCV")
def test_search_sklearn_pipelines(mock_search, target, grab_test_X_additional_feats):
@patch("pxtextmining.factories.factory_pipeline.GridSearchCV")
def test_search_sklearn_pipelines(
mock_gridsearch, mock_randomsearch, target, grab_test_X_additional_feats
):
mock_instance = MagicMock()
mock_search.return_value = mock_instance
mock_gridsearch.return_value = mock_instance
mock_randomsearch.return_value = mock_instance
models_to_try = ["svm"]
X_train = grab_test_X_additional_feats
Y_train = np.array(
Expand Down Expand Up @@ -97,11 +101,13 @@ def test_search_sklearn_pipelines(mock_search, target, grab_test_X_additional_fe

@pytest.mark.parametrize("target", ["sentiment", None])
@patch("pxtextmining.factories.factory_pipeline.RandomizedSearchCV")
@patch("pxtextmining.factories.factory_pipeline.GridSearchCV")
def test_search_sklearn_pipelines_no_feats(
mock_search, target, grab_test_X_additional_feats
mock_gridsearch, mock_randomsearch, target, grab_test_X_additional_feats
):
mock_instance = MagicMock()
mock_search.return_value = mock_instance
mock_gridsearch.return_value = mock_instance
mock_randomsearch.return_value = mock_instance
models_to_try = ["svm"]
X_train = grab_test_X_additional_feats["FFT answer"]
Y_train = np.array(
Expand Down
2 changes: 1 addition & 1 deletion tests/test_model_performance.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def test_accuracy_per_class():


def test_parse_metrics_file():
metrics_file = "current_best_multilabel/bert_sentiment.txt"
metrics_file = "current_best_model/sentiment/bert_sentiment.txt"
labels = ["very positive", "positive", "neutral", "negative", "very negative"]
metrics_df = factory_model_performance.parse_metrics_file(metrics_file, labels)
assert metrics_df.shape == (5, 5)
Expand Down

0 comments on commit 95ec0f3

Please sign in to comment.