Skip to content

Commit

Permalink
[#179] Filter with math.isnan() instead of is not np.nan
Browse files Browse the repository at this point in the history
This lets us handle math.nan when aggregating threshold metrics results. It
keeps np.nan more contained to the code that actually cares about Pandas and
Numpy.
  • Loading branch information
riley-harper committed Dec 12, 2024
1 parent fd40c35 commit 1ecef81
Showing 1 changed file with 7 additions and 12 deletions.
19 changes: 7 additions & 12 deletions hlink/linking/model_exploration/link_step_train_test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,14 +658,9 @@ def _capture_prediction_results(
fn_count,
tn_count,
) = _get_confusion_matrix(predictions, dep_var)
precision_raw = metrics_core.precision(tp_count, fp_count)
recall_raw = metrics_core.recall(tp_count, fn_count)
mcc_raw = metrics_core.mcc(tp_count, tn_count, fp_count, fn_count)

# Convert Python's math.nan to np.nan for numpy/pandas processing
precision = precision_raw if not math.isnan(precision_raw) else np.nan
recall = recall_raw if not math.isnan(recall_raw) else np.nan
mcc = mcc_raw if not math.isnan(mcc_raw) else np.nan
precision = metrics_core.precision(tp_count, fp_count)
recall = metrics_core.recall(tp_count, fn_count)
mcc = metrics_core.mcc(tp_count, tn_count, fp_count, fn_count)

result = ThresholdTestResult(
precision=precision,
Expand Down Expand Up @@ -813,11 +808,11 @@ def _aggregate_per_threshold_results(

# Pull out columns to be aggregated
precision_test = [
r.precision for r in prediction_results if r.precision is not np.nan
r.precision for r in prediction_results if not math.isnan(r.precision)
]
recall_test = [r.recall for r in prediction_results if r.recall is not np.nan]
pr_auc_test = [r.pr_auc for r in prediction_results if r.pr_auc is not np.nan]
mcc_test = [r.mcc for r in prediction_results if r.mcc is not np.nan]
recall_test = [r.recall for r in prediction_results if not math.isnan(r.recall)]
pr_auc_test = [r.pr_auc for r in prediction_results if not math.isnan(r.pr_auc)]
mcc_test = [r.mcc for r in prediction_results if not math.isnan(r.mcc)]

# # variance requires at least two values
precision_test_sd = (
Expand Down

0 comments on commit 1ecef81

Please sign in to comment.