From 8ffb95416bd94cfb97a6dbeae50d27fa4213d881 Mon Sep 17 00:00:00 2001 From: YiWen Hon Date: Tue, 12 Sep 2023 10:29:54 +0100 Subject: [PATCH] svc now uses gridsearch not randomsearch - takes too long otherwise --- pxtextmining/factories/factory_pipeline.py | 9 ++- pxtextmining/pipelines/multilabel_pipeline.py | 72 +++++++++---------- 2 files changed, 42 insertions(+), 39 deletions(-) diff --git a/pxtextmining/factories/factory_pipeline.py b/pxtextmining/factories/factory_pipeline.py index 8e38574..e91a5e5 100644 --- a/pxtextmining/factories/factory_pipeline.py +++ b/pxtextmining/factories/factory_pipeline.py @@ -7,7 +7,7 @@ from sklearn.compose import make_column_transformer from sklearn.ensemble import RandomForestClassifier from sklearn.feature_extraction.text import TfidfVectorizer -from sklearn.model_selection import RandomizedSearchCV +from sklearn.model_selection import GridSearchCV, RandomizedSearchCV from sklearn.multioutput import MultiOutputClassifier from sklearn.naive_bayes import MultinomialNB from sklearn.neighbors import KNeighborsClassifier @@ -358,6 +358,10 @@ def create_sklearn_pipeline(model_type, tokenizer=None, additional_features=True "rbf", "sigmoid", ] + if "columntransformer__tfidfvectorizer__min_df" in params: + params["columntransformer__tfidfvectorizer__min_df"] = [0, 1, 2, 3, 4, 5] + else: + params["tfidfvectorizer__min_df"] = [0, 1, 2, 3, 4, 5] if model_type == "rfc": pipe = make_pipeline(preproc, RandomForestClassifier(n_jobs=-1)) params["randomforestclassifier__max_depth"] = stats.randint(5, 50) @@ -418,11 +422,10 @@ def search_sklearn_pipelines( ) start_time = time.time() if model_type == "svm": - search = RandomizedSearchCV( + search = GridSearchCV( pipe, params, scoring="average_precision", - n_iter=100, cv=4, refit=True, verbose=1, diff --git a/pxtextmining/pipelines/multilabel_pipeline.py b/pxtextmining/pipelines/multilabel_pipeline.py index 4f42fed..537ec4f 100644 --- a/pxtextmining/pipelines/multilabel_pipeline.py +++ b/pxtextmining/pipelines/multilabel_pipeline.py @@ -333,42 +333,42 @@ def run_bert_pipeline( if __name__ == "__main__": - run_svc_pipeline( - additional_features=False, - target=minor_cats, - path="test_multilabel/0906threshold/svc_noq", - include_analysis=True, - custom_threshold=True, - ) - run_svc_pipeline( - additional_features=True, - target=minor_cats, - path="test_multilabel/0906threshold/svc", - include_analysis=True, - custom_threshold=True, - ) - run_sklearn_pipeline( - additional_features=True, - target=minor_cats, - models_to_try=["xgb", "knn"], - path="test_multilabel/0906threshold/xgb", - include_analysis=True, - custom_threshold=True, - ) - run_bert_pipeline( - additional_features=True, - path="test_multilabel/0906threshold/bert", - target=minor_cats, - include_analysis=True, - custom_threshold=True, - ) - run_bert_pipeline( - additional_features=False, - path="test_multilabel/0906threshold/bert_noq", - target=minor_cats, - include_analysis=True, - custom_threshold=True, - ) + # run_svc_pipeline( + # additional_features=False, + # target=minor_cats, + # path="test_multilabel/0906threshold/svc_noq", + # include_analysis=True, + # custom_threshold=True, + # ) + # run_svc_pipeline( + # additional_features=True, + # target=minor_cats, + # path="test_multilabel/0906threshold/svc", + # include_analysis=True, + # custom_threshold=True, + # ) + # run_sklearn_pipeline( + # additional_features=True, + # target=minor_cats, + # models_to_try=["xgb", "knn"], + # path="test_multilabel/0906threshold/xgb", + # include_analysis=True, + # custom_threshold=True, + # ) + # run_bert_pipeline( + # additional_features=True, + # path="test_multilabel/0906threshold/bert", + # target=minor_cats, + # include_analysis=True, + # custom_threshold=True, + # ) + # run_bert_pipeline( + # additional_features=False, + # path="test_multilabel/0906threshold/bert_noq", + # target=minor_cats, + # include_analysis=True, + # custom_threshold=True, + # ) run_sklearn_pipeline( additional_features=True, target=minor_cats,