diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 8b477ae..86e6393 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -7,6 +7,12 @@ Changelog ========= +0.9.3 (2024-10-25) +------------------ + +- Reduced the number of unnecessary casts in :func:`ndonnx.argmax` and :func:`ndonnx.argmin`. + + 0.9.2 (2024-10-03) ------------------ diff --git a/docs/_static/classify_iris.png b/docs/_static/classify_iris.png index 6b33f25..ee04215 100644 Binary files a/docs/_static/classify_iris.png and b/docs/_static/classify_iris.png differ diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index a3c92e7..b8efde7 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -355,43 +355,39 @@ def matrix_transpose(self, x) -> ndx.Array: @validate_core def argmax(self, x, axis=None, keepdims=False): - if axis is None: - reshaped_x = ndx.reshape(x, [-1])._core() - if keepdims: - return from_corearray( - opx.reshape( - opx.arg_max(reshaped_x, axis=0, keepdims=False), - opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64), - ) - ) - else: - return from_corearray( - opx.reshape( - opx.arg_max(reshaped_x, axis=0, keepdims=False), - opx.const([], dtype=dtypes.int64), - ) - ) - return _via_i64_f64(lambda x: opx.arg_max(x, axis=axis, keepdims=keepdims), [x]) + out = via_upcast( + lambda x: opx.arg_max( + x, + axis=axis or 0, + keepdims=int(keepdims), + ), + [ndx.reshape(x, [-1]) if axis is None else x], + cast_return=False, + int_dtype=ndx.int32, + float_dtype=ndx.float64, + ) + + while keepdims and out.ndim < x.ndim: + out = ndx.expand_dims(out, axis=0) + return out @validate_core def argmin(self, x, axis=None, keepdims=False): - if axis is None: - reshaped_x = ndx.reshape(x, [-1])._core() - if keepdims: - return from_corearray( - opx.reshape( - opx.arg_min(reshaped_x, axis=0, keepdims=False), - opx.const([1 for x in range(x.ndim)], dtype=dtypes.int64), - ) - ) - else: - return from_corearray( - opx.reshape( - opx.arg_min(reshaped_x, axis=0, keepdims=False), - opx.const([], dtype=dtypes.int64), - ) - ) - return _via_i64_f64(lambda x: opx.arg_min(x, axis=axis, keepdims=keepdims), [x]) + out = via_upcast( + lambda x: opx.arg_min( + x, + axis=axis or 0, + keepdims=int(keepdims), + ), + [ndx.reshape(x, [-1]) if axis is None else x], + cast_return=False, + int_dtype=ndx.int32, + float_dtype=ndx.float64, + ) + + while keepdims and out.ndim < x.ndim: + out = ndx.expand_dims(out, axis=0) + return out @validate_core def nonzero(self, x) -> tuple[Array, ...]: diff --git a/ndonnx/_funcs.py b/ndonnx/_funcs.py index d15dd16..3021b2d 100644 --- a/ndonnx/_funcs.py +++ b/ndonnx/_funcs.py @@ -670,7 +670,7 @@ def argmax(x, axis=None, keepdims=False): def argmin(x, axis=None, keepdims=False): if ( - out := x.dtype._ops.argmax(x, axis=axis, keepdims=keepdims) + out := x.dtype._ops.argmin(x, axis=axis, keepdims=keepdims) ) is not NotImplemented: return out raise UnsupportedOperationError(f"Unsupported operand type for argmin: '{x.dtype}'") diff --git a/tests/test_core.py b/tests/test_core.py index 5063954..a9bfe0f 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -4,6 +4,7 @@ from __future__ import annotations import re +import warnings import numpy as np import pytest @@ -992,3 +993,48 @@ def test_no_unsafe_cumulative_sum_cast(): ): a = ndx.asarray([1, 2, 3], ndx.int32) ndx.cumulative_sum(a, dtype=ndx.uint64) + + +@pytest.mark.parametrize("keepdims", [True, False]) +@pytest.mark.parametrize( + "func, x", + [ + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int32)), + (np.argmax, np.array([[1, 2, 3], [4, 5, 6]], dtype=np.int32)), + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int8)), + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.float64)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float32)), + (np.argmin, np.array([[-11, 2, 3], [4, 5, -6]], dtype=np.int32)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.float64)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int16)), + ], +) +def test_argmaxmin(func, x, keepdims): + np_result = func(x, keepdims=keepdims) + ndx_result = getattr(ndx, func.__name__)( + ndx.asarray(x), keepdims=keepdims + ).to_numpy() + assert_array_equal(np_result, ndx_result) + + +# Pending ORT 1.19 conda-forge release before this becomes supported: +# https://github.com/conda-forge/onnxruntime-feedstock/pull/128 +@pytest.mark.parametrize( + "func, x", + [ + (np.argmax, np.array([1, 2, 3, 4, 5], dtype=np.int64)), + (np.argmin, np.array([1, 2, 3, 4, 5], dtype=np.int64)), + ], +) +def test_argmaxmin_unsupported_kernels(func, x): + import onnxruntime as ort + + if ort.__version__.startswith("19"): + warnings.warn( + "Please remove this test and update `argmax` and `argmin` to reflect expanded kernel support.", + Warning, + ) + + with pytest.raises(TypeError): + getattr(ndx, func.__name__)(ndx.asarray(x)) diff --git a/xfails.txt b/xfails.txt index 69431ee..9665406 100644 --- a/xfails.txt +++ b/xfails.txt @@ -90,8 +90,6 @@ array_api_tests/test_operators_and_elementwise_functions.py::test_signbit array_api_tests/test_operators_and_elementwise_functions.py::test_sinh array_api_tests/test_operators_and_elementwise_functions.py::test_sqrt array_api_tests/test_operators_and_elementwise_functions.py::test_tan -array_api_tests/test_searching_functions.py::test_argmax -array_api_tests/test_searching_functions.py::test_argmin array_api_tests/test_searching_functions.py::test_nonzero_zerodim_error array_api_tests/test_searching_functions.py::test_searchsorted array_api_tests/test_searching_functions.py::test_where