diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py index 3f058f39b7e..e3ab6d0b170 100644 --- a/keras/backend/torch/numpy.py +++ b/keras/backend/torch/numpy.py @@ -892,9 +892,7 @@ def min(x, axis=None, keepdims=False, initial=None): if axis is None: result = torch.min(x) else: - if isinstance(axis, list): - axis = axis[-1] - result = torch.min(x, dim=axis, keepdim=keepdims) + result = amin(x, axis=axis, keepdims=keepdims) if isinstance(getattr(result, "values", None), torch.Tensor): result = result.values diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index a52e138f7d5..40043f5dd3f 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -3429,6 +3429,12 @@ def test_min(self): self.assertAllClose(knp.min(x), np.min(x)) self.assertAllClose(knp.Min()(x), np.min(x)) + self.assertAllClose(knp.min(x, axis=(0, 1)), np.min(x, (0, 1))) + self.assertAllClose(knp.Min((0, 1))(x), np.min(x, (0, 1))) + + self.assertAllClose(knp.min(x, axis=()), np.min(x, axis=())) + self.assertAllClose(knp.Min(())(x), np.min(x, axis=())) + self.assertAllClose(knp.min(x, 0), np.min(x, 0)) self.assertAllClose(knp.Min(0)(x), np.min(x, 0))