Skip to content

Commit

Permalink
[#172] Don't handle threshold and threshold_ratio in choose_classifier()
Browse files: browse the repository at this point in the history
The caller is responsible for passing a dictionary of hyper-parameters
to choose_classifier(), and this dictionary should not include hlink's
threshold or threshold_ratio. Both of the places where we call
choose_classifier() (training and model exploration) already handle
this.
  • Loading branch information
riley-harper committed Dec 5, 2024
1 parent e57dad6 commit a736dd0
Showing 1 changed file with 4 additions and 22 deletions.
26 changes: 4 additions & 22 deletions hlink/linking/core/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,7 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
features_vector = "features_vector"
if model_type == "random_forest":
classifier = RandomForestClassifier(
- **{
- key: val
- for key, val in params.items()
- if key not in ["threshold", "threshold_ratio"]
- },
+ **params,
labelCol=dep_var,
featuresCol=features_vector,
seed=2133,
Expand Down Expand Up @@ -110,11 +106,7 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):

elif model_type == "gradient_boosted_trees":
classifier = GBTClassifier(
- **{
- key: val
- for key, val in params.items()
- if key not in ["threshold", "threshold_ratio"]
- },
+ **params,
featuresCol=features_vector,
labelCol=dep_var,
seed=2133,
Expand All @@ -130,13 +122,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"its dependencies. Try installing hlink with the lightgbm extra: "
"\n\n pip install hlink[lightgbm]"
)
- params_without_threshold = {
- key: val
- for key, val in params.items()
- if key not in {"threshold", "threshold_ratio"}
- }
  classifier = synapse.ml.lightgbm.LightGBMClassifier(
- **params_without_threshold,
+ **params,
featuresCol=features_vector,
labelCol=dep_var,
probabilityCol="probability_array",
Expand All @@ -151,13 +138,8 @@ def choose_classifier(model_type: str, params: dict[str, Any], dep_var: str):
"the xgboost library and its dependencies. Try installing hlink with "
"the xgboost extra:\n\n pip install hlink[xgboost]"
)
- params_without_threshold = {
- key: val
- for key, val in params.items()
- if key not in {"threshold", "threshold_ratio"}
- }
  classifier = xgboost.spark.SparkXGBClassifier(
- **params_without_threshold,
+ **params,
features_col=features_vector,
label_col=dep_var,
probability_col="probability_array",
Expand Down

0 comments on commit a736dd0

Please sign in to comment.