diff --git a/hlink/linking/core/threshold.py b/hlink/linking/core/threshold.py index b0f57a0..b0523d3 100644 --- a/hlink/linking/core/threshold.py +++ b/hlink/linking/core/threshold.py @@ -7,7 +7,7 @@ from pyspark.sql import DataFrame from pyspark.sql.window import Window -from pyspark.sql.functions import rank, lead +from pyspark.sql.functions import col, lead, rank, when def get_threshold_ratio( @@ -79,10 +79,8 @@ def predict_using_thresholds( def _apply_alpha_threshold(pred_df: DataFrame, alpha_threshold: float) -> DataFrame: - return pred_df.selectExpr( - "*", - f"CASE WHEN probability >= {alpha_threshold} THEN 1 ELSE 0 END AS prediction", - ) + prediction = when(col("probability") >= alpha_threshold, 1).otherwise(0) + return pred_df.withColumn("prediction", prediction) def _apply_threshold_ratio(