
Commit

[#174] Replace a SQL query with the equivalent spark expression
This prevents a possible SQL injection if alpha_threshold were ever set to
something weird. It's also a bit easier to read and work with in my
experience, and it's more composable since you can build up the expression
instead of having to write all of the SQL at once.
riley-harper committed Dec 5, 2024
1 parent 5424513 commit dd16360
Showing 1 changed file with 3 additions and 5 deletions.
8 changes: 3 additions & 5 deletions hlink/linking/core/threshold.py
@@ -7,7 +7,7 @@

from pyspark.sql import DataFrame
from pyspark.sql.window import Window
-from pyspark.sql.functions import rank, lead
+from pyspark.sql.functions import col, lead, rank, when


def get_threshold_ratio(
@@ -79,10 +79,8 @@ def predict_using_thresholds(


def _apply_alpha_threshold(pred_df: DataFrame, alpha_threshold: float) -> DataFrame:
-    return pred_df.selectExpr(
-        "*",
-        f"CASE WHEN probability >= {alpha_threshold} THEN 1 ELSE 0 END AS prediction",
-    )
+    prediction = when(col("probability") >= alpha_threshold, 1).otherwise(0)
+    return pred_df.withColumn("prediction", prediction)


def _apply_threshold_ratio(
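For anyone skimming the diff, here is a minimal standalone sketch of the composability point from the commit message. It is not part of the commit; the toy DataFrame and the alpha_threshold value below are made up for illustration.

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, when

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Toy data standing in for the real predictions DataFrame.
pred_df = spark.createDataFrame([(0.25,), (0.9,)], ["probability"])

alpha_threshold = 0.8

# The expression is just a Column object, so it can be built up, passed
# around, or combined with other Columns before it touches any DataFrame.
prediction = when(col("probability") >= alpha_threshold, 1).otherwise(0)

# alpha_threshold is compared as a value rather than interpolated into a SQL
# string, so an unexpected value cannot change the shape of the query.
pred_df.withColumn("prediction", prediction).show()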
