[#176] Rewrite _get_confusion_matrix() to avoid using 4 filters + counts
Using a single select() should let us take better advantage of Spark's
parallel/distributed computing. My initial results from profiling this are
pretty promising.
riley-harper committed Dec 10, 2024
1 parent 9755f73 commit c43b57d
Showing 1 changed file with 24 additions and 21 deletions.
45 changes: 24 additions & 21 deletions hlink/linking/model_exploration/link_step_train_test_models.py
@@ -21,7 +21,7 @@
 from pyspark.ml import Model, Transformer
 import pyspark.sql
 from pyspark.sql import DataFrame
-from pyspark.sql.functions import count, mean
+from pyspark.sql.functions import col, count, count_if, mean
 from functools import reduce
 import hlink.linking.core.threshold as threshold_core
 import hlink.linking.core.classifier as classifier_core
@@ -752,27 +752,30 @@ def _get_confusion_matrix(
     predictions: pyspark.sql.DataFrame,
     dep_var: str,
 ) -> tuple[int, int, int, int]:
-    TP = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 1))
-    TP_count = TP.count()
-
-    FP = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 1))
-    FP_count = FP.count()
-
-    # print(
-    #     f"Confusion matrix -- true positives and false positives: TP {TP_count} FP {FP_count}"
-    # )
-
-    FN = predictions.filter((predictions[dep_var] == 1) & (predictions.prediction == 0))
-    FN_count = FN.count()
-
-    TN = predictions.filter((predictions[dep_var] == 0) & (predictions.prediction == 0))
-    TN_count = TN.count()
-
-    # print(
-    #     f"Confusion matrix -- true negatives and false negatives: FN {FN_count} TN {TN_count}"
-    # )
-    return TP_count, FP_count, FN_count, TN_count
+    """
+    Compute the confusion matrix for the given DataFrame of predictions. The
+    confusion matrix is the count of true positives, false positives, false
+    negatives, and true negatives for the predictions.
+    Return a tuple (true_positives, false_positives, false_negatives,
+    true_negatives).
+    """
+    prediction_col = col("prediction")
+    label_col = col(dep_var)
+
+    confusion_matrix = predictions.select(
+        count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
+        count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
+        count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
+        count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
+    )
+    [confusion_row] = confusion_matrix.collect()
+    return (
+        confusion_row.true_positives,
+        confusion_row.false_positives,
+        confusion_row.false_negatives,
+        confusion_row.true_negatives,
+    )


 def _get_aggregate_metrics(
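For readers less familiar with Spark, here is a minimal, self-contained sketch of the same pattern on a toy DataFrame. It is not part of the commit: the "match" label column, the toy data, and the local SparkSession are assumptions made purely for illustration. Each filter(...).count() in the old code launched its own Spark job over the predictions, whereas a single select() of four count_if aggregates produces every cell of the confusion matrix in one pass (count_if is available in pyspark.sql.functions from PySpark 3.5 onward).

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, count_if

spark = SparkSession.builder.master("local[*]").getOrCreate()

# Toy predictions: the "match" label column and its values are made up for this sketch.
predictions = spark.createDataFrame(
    [(1, 1), (1, 0), (0, 1), (0, 0), (1, 1)],
    ["match", "prediction"],
)

label_col = col("match")
prediction_col = col("prediction")

# One select() computes all four confusion-matrix cells in a single pass over the data.
confusion_matrix = predictions.select(
    count_if((label_col == 1) & (prediction_col == 1)).alias("true_positives"),
    count_if((label_col == 0) & (prediction_col == 1)).alias("false_positives"),
    count_if((label_col == 1) & (prediction_col == 0)).alias("false_negatives"),
    count_if((label_col == 0) & (prediction_col == 0)).alias("true_negatives"),
)
[row] = confusion_matrix.collect()
print(row.true_positives, row.false_positives, row.false_negatives, row.true_negatives)
# Expected output for this toy data: 2 1 1 1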
