Skip to content

Commit

Permalink
Precision issue in get_confusion_matrix (#7187)
Browse files Browse the repository at this point in the history
Fixes #7186

### Description
remove unnecessary float()

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: KumoLiu <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
  • Loading branch information
KumoLiu and wyli authored Nov 2, 2023
1 parent d42440b commit 2658b00
Show file tree
Hide file tree
Showing 7 changed files with 41 additions and 50 deletions.
13 changes: 5 additions & 8 deletions monai/metrics/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,9 +153,6 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

y = y.float()
y_pred = y_pred.float()

if y.shape != y_pred.shape:
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

Expand All @@ -165,12 +162,12 @@ def get_confusion_matrix(y_pred: torch.Tensor, y: torch.Tensor, include_backgrou
# As for classification tasks, S equals to 1.
y_pred = y_pred.reshape(batch_size, n_class, -1)
y = y.reshape(batch_size, n_class, -1)
tp = ((y_pred + y) == 2).float()
tn = ((y_pred + y) == 0).float()
tp = (y_pred + y) == 2
tn = (y_pred + y) == 0

tp = tp.sum(dim=[2])
tn = tn.sum(dim=[2])
p = y.sum(dim=[2])
tp = tp.sum(dim=[2]).float()
tn = tn.sum(dim=[2]).float()
p = y.sum(dim=[2]).float()
n = y.shape[-1] - p

fn = p - tp
Expand Down
13 changes: 5 additions & 8 deletions monai/metrics/f_beta_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,6 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

y = y.float()
y_pred = y_pred.float()

if y.shape != y_pred.shape:
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

Expand All @@ -75,12 +72,12 @@ def get_f_beta_score(y_pred: torch.Tensor, y: torch.Tensor, include_background:
# As for classification tasks, S equals to 1.
y_pred = y_pred.view(batch_size, n_class, -1)
y = y.view(batch_size, n_class, -1)
tp = ((y_pred + y) == 2).float()
tn = ((y_pred + y) == 0).float()
tp = (y_pred + y) == 2
tn = (y_pred + y) == 0

tp = tp.sum(dim=[2])
tn = tn.sum(dim=[2])
p = y.sum(dim=[2])
tp = tp.sum(dim=[2]).float()
tn = tn.sum(dim=[2]).float()
p = y.sum(dim=[2]).float()
n = y.shape[-1] - p

fn = p - tp
Expand Down
3 changes: 0 additions & 3 deletions monai/metrics/meaniou.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,9 +130,6 @@ def compute_iou(
if not include_background:
y_pred, y = ignore_background(y_pred=y_pred, y=y)

y = y.float()
y_pred = y_pred.float()

if y.shape != y_pred.shape:
raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")

Expand Down
12 changes: 0 additions & 12 deletions monai/metrics/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
self.sq_func = partial(torch.pow, exponent=2.0)

def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
y_pred = y_pred.float()
y = y.float()

return compute_mean_error_metrics(y_pred, y, func=self.sq_func)


Expand Down Expand Up @@ -143,9 +140,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
self.abs_func = torch.abs

def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
y_pred = y_pred.float()
y = y.float()

return compute_mean_error_metrics(y_pred, y, func=self.abs_func)


Expand Down Expand Up @@ -176,9 +170,6 @@ def __init__(self, reduction: MetricReduction | str = MetricReduction.MEAN, get_
self.sq_func = partial(torch.pow, exponent=2.0)

def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
y_pred = y_pred.float()
y = y.float()

mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
return torch.sqrt(mse_out)

Expand Down Expand Up @@ -218,9 +209,6 @@ def __init__(
self.sq_func = partial(torch.pow, exponent=2.0)

def _compute_metric(self, y_pred: torch.Tensor, y: torch.Tensor) -> Any:
y_pred = y_pred.float()
y = y.float()

mse_out = compute_mean_error_metrics(y_pred, y, func=self.sq_func)
return 20 * math.log10(self.max_val) - 10 * torch.log10(mse_out)

Expand Down
3 changes: 0 additions & 3 deletions monai/metrics/surface_dice.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,9 +228,6 @@ def compute_surface_dice(
f"y_pred and y should have same shape, but instead, shapes are {y_pred.shape} (y_pred) and {y.shape} (y)."
)

y = y.float()
y_pred = y_pred.float()

batch_size, n_class = y_pred.shape[:2]

if n_class != len(class_thresholds):
Expand Down
32 changes: 16 additions & 16 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,37 +95,37 @@ def do_metric_reduction(
# some elements might be Nan (if ground truth y was missing (zeros))
# we need to account for it
nans = torch.isnan(f)
not_nans = (~nans).float()
not_nans = ~nans

t_zero = torch.zeros(1, device=f.device, dtype=f.dtype)
t_zero = torch.zeros(1, device=f.device, dtype=torch.float)
reduction = look_up_option(reduction, MetricReduction)
if reduction == MetricReduction.NONE:
return f, not_nans
return f, not_nans.float()

f[nans] = 0
if reduction == MetricReduction.MEAN:
# 2 steps, first, mean by channel (accounting for nans), then by batch
not_nans = not_nans.sum(dim=1)
f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average
not_nans = not_nans.sum(dim=1).float()
f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average

not_nans = (not_nans > 0).float().sum(dim=0)
f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average
not_nans = (not_nans > 0).sum(dim=0).float()
f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average

elif reduction == MetricReduction.SUM:
not_nans = not_nans.sum(dim=[0, 1])
not_nans = not_nans.sum(dim=[0, 1]).float()
f = torch.sum(f, dim=[0, 1]) # sum over the batch and channel dims
elif reduction == MetricReduction.MEAN_BATCH:
not_nans = not_nans.sum(dim=0)
f = torch.where(not_nans > 0, f.sum(dim=0) / not_nans, t_zero) # batch average
not_nans = not_nans.sum(dim=0).float()
f = torch.where(not_nans > 0, f.sum(dim=0).float() / not_nans, t_zero) # batch average
elif reduction == MetricReduction.SUM_BATCH:
not_nans = not_nans.sum(dim=0)
f = f.sum(dim=0) # the batch sum
not_nans = not_nans.sum(dim=0).float()
f = f.sum(dim=0).float() # the batch sum
elif reduction == MetricReduction.MEAN_CHANNEL:
not_nans = not_nans.sum(dim=1)
f = torch.where(not_nans > 0, f.sum(dim=1) / not_nans, t_zero) # channel average
not_nans = not_nans.sum(dim=1).float()
f = torch.where(not_nans > 0, f.sum(dim=1).float() / not_nans, t_zero) # channel average
elif reduction == MetricReduction.SUM_CHANNEL:
not_nans = not_nans.sum(dim=1)
f = f.sum(dim=1) # the channel sum
not_nans = not_nans.sum(dim=1).float()
f = f.sum(dim=1).float() # the channel sum
elif reduction != MetricReduction.NONE:
raise ValueError(
f"Unsupported reduction: {reduction}, available options are "
Expand Down
15 changes: 15 additions & 0 deletions tests/test_compute_confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@

TEST_CASES_CLF = [data_clf.copy(), result_clf]

TEST_CASE_PRECISION = [
{
"y_pred": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
"y": torch.zeros([1, 1, 1024, 1024, 44], device=_device),
},
torch.tensor([[[0.0, 0.0, 46137344.0, 0.0]]]),
]


class TestConfusionMatrix(unittest.TestCase):
@parameterized.expand([TEST_CASE_CONFUSION_MATRIX])
Expand Down Expand Up @@ -274,6 +282,13 @@ def test_clf_with_nan(self, input_data, expected_value):
expected_value = compute_confusion_matrix_metric("tpr", expected_value)
assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)

@parameterized.expand([TEST_CASE_PRECISION])
def test_precision(self, input_data, expected_value):
# include or ignore background
result = get_confusion_matrix(**input_data)
assert_allclose(result, expected_value, atol=1e-4, rtol=1e-4)
np.testing.assert_equal(result.device, input_data["y_pred"].device)


if __name__ == "__main__":
unittest.main()

0 comments on commit 2658b00

Please sign in to comment.