From c43b57d787c74df2fbf74377330ca370151938eb Mon Sep 17 00:00:00 2001
From: rileyh <rileyh@umn.edu>
Date: Tue, 10 Dec 2024 14:19:58 -0600
Subject: [PATCH] [#176] Rewrite _get_confusion_matrix() to avoid using 4
 filters + counts

Computing all four counts in a single select() means Spark can make one
pass over the predictions instead of running four separate filter() +
count() jobs, which should take better advantage of its
parallel/distributed execution. My initial profiling results are
promising.
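
Here's a minimal, self-contained sketch of the approach (the toy data
and column names are for illustration only; count_if() requires a
fairly recent PySpark, I believe 3.5+):

    from pyspark.sql import SparkSession
    from pyspark.sql.functions import col, count_if

    spark = SparkSession.builder.getOrCreate()
    predictions = spark.createDataFrame(
        [(1, 1.0), (1, 0.0), (0, 1.0), (0, 0.0), (1, 1.0)],
        ["match", "prediction"],
    )

    # One select() computes all four counts in a single pass over the
    # data, instead of four separate filter().count() Spark jobs.
    [row] = predictions.select(
        count_if((col("match") == 1) & (col("prediction") == 1)).alias("tp"),
        count_if((col("match") == 0) & (col("prediction") == 1)).alias("fp"),
        count_if((col("match") == 1) & (col("prediction") == 0)).alias("fn"),
        count_if((col("match") == 0) & (col("prediction") == 0)).alias("tn"),
    ).collect()

    print(row.tp, row.fp, row.fn, row.tn)  # 2 1 1 1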
---
 .../link_step_train_test_models.py            | 45 ++++++++++---------
 1 file changed, 24 insertions(+), 21 deletions(-)

diff --git a/hlink/linking/model_exploration/link_step_train_test_models.py b/hlink/linking/model_exploration/link_step_train_test_models.py
index 6025998..d779121 100644
--- a/hlink/linking/model_exploration/link_step_train_test_models.py
+++ b/hlink/linking/model_exploration/link_step_train_test_models.py
@@ -21,7 +21,7 @@
 from pyspark.ml import Model, Transformer
 import pyspark.sql
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import count, mean
+from pyspark.sql.functions import col, count, count_if, mean
 from functools import reduce
 import hlink.linking.core.threshold as threshold_core
 import hlink.linking.core.classifier as classifier_core
@@ -752,27 +752,30 @@ def _get_confusion_matrix(
     predictions: pyspark.sql.DataFrame,
     dep_var: str,
 ) -> tuple[int, int, int, int]:
-    TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
-    TP_count = TP.count()
-
-    FP = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 1))
-    FP_count = FP.count()
-
-    # print(
-    #   f"Confusion matrix -- true positives and false positivesTP {TP_count} FP {FP_count}"
-    # )
-
-    FN = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 0))
-    FN_count = FN.count()
-
-    TN = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 0))
-    TN_count = TN.count()
-
-    # print(
-    #   f"Confusion matrix -- true negatives and false negatives: FN {FN_count}  TN {TN_count}"
-    # )
+    """
+    Compute the confusion matrix for the given DataFrame of predictions. The
+    confusion matrix is the count of true positives, false positives, false
+    negatives, and true negatives for the predictions.
 
-    return TP_count, FP_count, FN_count, TN_count
+    Return a tuple (true_positives, false_positives, false_negatives,
+    true_negatives).
+    """
+    prediction_col = col("prediction")
+    label_col = col(dep_var)
+
+    confusion_matrix = predictions.select(
+        count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
+        count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
+        count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
+        count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
+    )
+    [confusion_row] = confusion_matrix.collect()
+    return (
+        confusion_row.true_positives,
+        confusion_row.false_positives,
+        confusion_row.false_negatives,
+        confusion_row.true_negatives,
+    )
 
 
 def _get_aggregate_metrics(