Skip to content

Commit

Permalink
SVC search now uses GridSearchCV instead of RandomizedSearchCV - the randomized search took too long
Browse files Browse the repository at this point in the history
  • Loading branch information
yiwen-h committed Sep 12, 2023
1 parent 4feb571 commit 8ffb954
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 39 deletions.
9 changes: 6 additions & 3 deletions pxtextmining/factories/factory_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sklearn.compose import make_column_transformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.model_selection import RandomizedSearchCV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV
from sklearn.multioutput import MultiOutputClassifier
from sklearn.naive_bayes import MultinomialNB
from sklearn.neighbors import KNeighborsClassifier
Expand Down Expand Up @@ -358,6 +358,10 @@ def create_sklearn_pipeline(model_type, tokenizer=None, additional_features=True
"rbf",
"sigmoid",
]
if "columntransformer__tfidfvectorizer__min_df" in params:
params["columntransformer__tfidfvectorizer__min_df"] = [0, 1, 2, 3, 4, 5]
else:
params["tfidfvectorizer__min_df"] = [0, 1, 2, 3, 4, 5]
if model_type == "rfc":
pipe = make_pipeline(preproc, RandomForestClassifier(n_jobs=-1))
params["randomforestclassifier__max_depth"] = stats.randint(5, 50)
Expand Down Expand Up @@ -418,11 +422,10 @@ def search_sklearn_pipelines(
)
start_time = time.time()
if model_type == "svm":
search = RandomizedSearchCV(
search = GridSearchCV(
pipe,
params,
scoring="average_precision",
n_iter=100,
cv=4,
refit=True,
verbose=1,
Expand Down
72 changes: 36 additions & 36 deletions pxtextmining/pipelines/multilabel_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,42 +333,42 @@ def run_bert_pipeline(


if __name__ == "__main__":
run_svc_pipeline(
additional_features=False,
target=minor_cats,
path="test_multilabel/0906threshold/svc_noq",
include_analysis=True,
custom_threshold=True,
)
run_svc_pipeline(
additional_features=True,
target=minor_cats,
path="test_multilabel/0906threshold/svc",
include_analysis=True,
custom_threshold=True,
)
run_sklearn_pipeline(
additional_features=True,
target=minor_cats,
models_to_try=["xgb", "knn"],
path="test_multilabel/0906threshold/xgb",
include_analysis=True,
custom_threshold=True,
)
run_bert_pipeline(
additional_features=True,
path="test_multilabel/0906threshold/bert",
target=minor_cats,
include_analysis=True,
custom_threshold=True,
)
run_bert_pipeline(
additional_features=False,
path="test_multilabel/0906threshold/bert_noq",
target=minor_cats,
include_analysis=True,
custom_threshold=True,
)
# run_svc_pipeline(
# additional_features=False,
# target=minor_cats,
# path="test_multilabel/0906threshold/svc_noq",
# include_analysis=True,
# custom_threshold=True,
# )
# run_svc_pipeline(
# additional_features=True,
# target=minor_cats,
# path="test_multilabel/0906threshold/svc",
# include_analysis=True,
# custom_threshold=True,
# )
# run_sklearn_pipeline(
# additional_features=True,
# target=minor_cats,
# models_to_try=["xgb", "knn"],
# path="test_multilabel/0906threshold/xgb",
# include_analysis=True,
# custom_threshold=True,
# )
# run_bert_pipeline(
# additional_features=True,
# path="test_multilabel/0906threshold/bert",
# target=minor_cats,
# include_analysis=True,
# custom_threshold=True,
# )
# run_bert_pipeline(
# additional_features=False,
# path="test_multilabel/0906threshold/bert_noq",
# target=minor_cats,
# include_analysis=True,
# custom_threshold=True,
# )
run_sklearn_pipeline(
additional_features=True,
target=minor_cats,
Expand Down

0 comments on commit 8ffb954

Please sign in to comment.