From 69970f3e868b6f150aa5e37ab3eded017e3e0a7b Mon Sep 17 00:00:00 2001 From: edumotya Date: Tue, 10 Oct 2023 21:48:10 +0200 Subject: [PATCH] fix macro when ignore_index is set --- src/torchmetrics/functional/classification/stat_scores.py | 2 ++ tests/unittests/classification/test_accuracy.py | 7 ++----- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/torchmetrics/functional/classification/stat_scores.py b/src/torchmetrics/functional/classification/stat_scores.py index 5153554253b..bd645ee3c37 100644 --- a/src/torchmetrics/functional/classification/stat_scores.py +++ b/src/torchmetrics/functional/classification/stat_scores.py @@ -416,6 +416,8 @@ def _multiclass_stat_scores_update( fp = confmat.sum(0) - tp fn = confmat.sum(1) - tp tn = confmat.sum() - (fp + fn + tp) + if ignore_index is not None: + fp[ignore_index] = 0 return tp, fp, tn, fn diff --git a/tests/unittests/classification/test_accuracy.py b/tests/unittests/classification/test_accuracy.py index 7501ee2f4ae..7906553f8f7 100644 --- a/tests/unittests/classification/test_accuracy.py +++ b/tests/unittests/classification/test_accuracy.py @@ -190,12 +190,9 @@ def _sklearn_accuracy_multiclass(preds, target, ignore_index, multidim_average, return _sklearn_accuracy(target, preds) confmat = sk_confusion_matrix(target, preds, labels=list(range(NUM_CLASSES))) acc_per_class = confmat.diagonal() / confmat.sum(axis=1) - acc_per_class[np.isnan(acc_per_class)] = 0.0 if average == "macro": - acc_per_class = acc_per_class[ - (np.bincount(preds, minlength=NUM_CLASSES) + np.bincount(target, minlength=NUM_CLASSES)) != 0.0 - ] - return acc_per_class.mean() + return np.nanmean(acc_per_class) + acc_per_class[np.isnan(acc_per_class)] = 0.0 if average == "weighted": weights = confmat.sum(1) return ((weights * acc_per_class) / weights.sum()).sum()