Skip to content

Commit

Permalink
[#174] Rewrite some thresholding code to use PySpark exprs instead of SQL
Browse files Browse the repository at this point in the history
  • Loading branch information
riley-harper committed Dec 5, 2024
1 parent dd16360 commit 647a751
Showing 1 changed file with 29 additions and 23 deletions.
52 changes: 29 additions & 23 deletions hlink/linking/core/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,38 +94,44 @@ def _apply_threshold_ratio(
'In order to calculate the threshold ratio based on probabilities, you need to have a "probability" column in your data.'
)

windowSpec = Window.partitionBy(df[id_a]).orderBy(
df["probability"].desc(), df[id_b]
)
windowSpec = Window.partitionBy(id_a).orderBy(col("probability").desc(), id_b)
prob_rank = rank().over(windowSpec)
prob_lead = lead(df["probability"], 1).over(windowSpec)
prob_lead = lead("probability", 1).over(windowSpec)

should_compute_probability_ratio = (
col("second_best_prob").isNotNull()
& (col("second_best_prob") >= alpha_threshold)
& (col("prob_rank") == 1)
)
# To be a match, the row must...
# 1. Have prob_rank 1, so that it's the most likely match,
# 2. Have a probability of at least alpha_threshold,
# and
# 3. Either have no ratio (since there's no second best probability of at
# least alpha_threshold), or have a ratio of more than threshold_ratio.
is_match = (
(col("probability") >= alpha_threshold)
& (col("prob_rank") == 1)
& ((col("ratio") > threshold_ratio) | col("ratio").isNull())
)
return (
df.select(
df["*"],
"*",
prob_rank.alias("prob_rank"),
prob_lead.alias("second_best_prob"),
)
.selectExpr(
.select(
"*",
f"""
IF(
second_best_prob IS NOT NULL
AND second_best_prob >= {alpha_threshold}
AND prob_rank == 1,
probability / second_best_prob,
NULL)
AS ratio
""",
when(
should_compute_probability_ratio,
col("probability") / col("second_best_prob"),
)
.otherwise(None)
.alias("ratio"),
)
.selectExpr(
.select(
"*",
f"""
CAST(
probability >= {alpha_threshold}
AND prob_rank == 1
AND (ratio > {threshold_ratio} OR ratio IS NULL)
AS INT) AS prediction
""",
is_match.cast("integer").alias("prediction"),
)
.drop("prob_rank")
)

0 comments on commit 647a751

Please sign in to comment.