Skip to content

Commit

Permalink
[#174] Pass just decision into predict_with_thresholds() instead of t…
Browse files Browse the repository at this point in the history
…he whole training config

This makes it clear which part of the config predict_with_thresholds() is using
and makes it easier to call. It also means that predict_with_thresholds() does
not need to know about the structure of the config.
  • Loading branch information
riley-harper committed Dec 5, 2024
1 parent 28bcd03 commit ad6ce10
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 11 deletions.
8 changes: 4 additions & 4 deletions hlink/linking/core/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def predict_using_thresholds(
pred_df: DataFrame,
alpha_threshold: float,
threshold_ratio: float,
training_conf: dict[str, Any],
id_col: str,
decision: str | None,
) -> DataFrame:
"""Adds a prediction column to the given pred_df by applying thresholds.
Expand All @@ -57,17 +57,17 @@ def predict_using_thresholds(
to the "a" record's next best probability value.
Only used with the "drop_duplicate_with_threshold_ratio"
configuration value.
training_conf: dictionary
the training config section
id_col: string
the id column
decision: str | None
how to apply the thresholds
Returns
-------
A Spark DataFrame containing the "prediction" column as well as other intermediate columns generated to create the prediction.
"""
use_threshold_ratio = (
training_conf.get("decision", "") == "drop_duplicate_with_threshold_ratio"
decision is not None and decision == "drop_duplicate_with_threshold_ratio"
)

if use_threshold_ratio:
Expand Down
3 changes: 2 additions & 1 deletion hlink/linking/matching/link_step_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,13 @@ def _run(self):
threshold_ratio = threshold_core.get_threshold_ratio(
config[training_conf], chosen_model_params, default=1.3
)
decision = config[training_conf].get("decision")
predictions = threshold_core.predict_using_thresholds(
score_tmp,
alpha_threshold,
threshold_ratio,
config[training_conf],
config["id_column"],
decision,
)
predictions.write.mode("overwrite").saveAsTable(f"{table_prefix}predictions")
pmp = self.task.spark.table(f"{table_prefix}potential_matches_pipeline")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -411,20 +411,21 @@ def _evaluate_threshold_combinations(
f"{this_alpha_threshold=} and {this_threshold_ratio=}"
)
logger.debug(diag)
decision = training_settings.get("decision")
start_predict_time = perf_counter()
predictions = threshold_core.predict_using_thresholds(
thresholding_predictions,
this_alpha_threshold,
this_threshold_ratio,
training_settings,
id_column,
decision,
)
predict_train = threshold_core.predict_using_thresholds(
thresholding_predict_train,
this_alpha_threshold,
this_threshold_ratio,
training_settings,
id_column,
decision,
)

end_predict_time = perf_counter()
Expand Down
5 changes: 2 additions & 3 deletions hlink/tests/core/threshold_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_predict_using_thresholds_default_decision(spark: SparkSession) -> None:

# We are using the default decision, so threshold_ratio will be ignored
predictions = predict_using_thresholds(
df, alpha_threshold=0.6, threshold_ratio=0.0, training_conf={}, id_col="id"
df, alpha_threshold=0.6, threshold_ratio=0.0, id_col="id", decision=None
)

output_rows = (
Expand Down Expand Up @@ -64,13 +64,12 @@ def test_predict_using_thresholds_drop_duplicates_decision(spark: SparkSession)
(3, "E", 0.8),
]
df = spark.createDataFrame(input_rows, schema=["id_a", "id_b", "probability"])
training_conf = {"decision": "drop_duplicate_with_threshold_ratio"}
predictions = predict_using_thresholds(
df,
alpha_threshold=0.5,
threshold_ratio=2.0,
training_conf=training_conf,
id_col="id",
decision="drop_duplicate_with_threshold_ratio",
)

output_rows = (
Expand Down
2 changes: 1 addition & 1 deletion hlink/tests/matching_scoring_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ def test_step_2_alpha_beta_thresholds(
score_tmp,
alpha_threshold,
threshold_ratio,
matching_conf["training"],
matching_conf["id_column"],
matching_conf["training"].get("decision"),
)
predictions.write.mode("overwrite").saveAsTable("predictions")

Expand Down

0 comments on commit ad6ce10

Please sign in to comment.