diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 00ffe0d..e12e1a2 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -8,7 +8,7 @@ Changelog ========= -0.9.0 (2024-08-29) +0.9.0 (2024-08-30) ------------------ **New features** diff --git a/ndonnx/_core/_numericimpl.py b/ndonnx/_core/_numericimpl.py index 885c920..a3c92e7 100644 --- a/ndonnx/_core/_numericimpl.py +++ b/ndonnx/_core/_numericimpl.py @@ -568,18 +568,17 @@ def cumulative_sum( else: raise ValueError("axis must be specified for multi-dimensional arrays") - if dtype is None: - if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): - if ndx.iinfo(x.dtype).bits < 64: - out = x.astype(dtypes.int64) - else: - raise ndx.UnsupportedOperationError( - f"Cannot perform `cumulative_sum` using {x.dtype}" - ) + if isinstance(x.dtype, (dtypes.Unsigned, dtypes.NullableUnsigned)): + if ndx.iinfo(x.dtype).bits < 64: + out = x.astype(dtypes.int64) else: - out = x.astype(_determine_reduce_op_dtype(x, dtype, dtypes.int64)) + return NotImplemented + elif dtype == dtypes.uint64 or dtype == dtypes.nuint64: + raise ndx.UnsupportedOperationError( + f"Unsupported dtype parameter for cumulative_sum {dtype} due to missing kernel support" + ) else: - out = out.astype(dtype) + out = x.astype(_determine_reduce_op_dtype(x, None, dtypes.uint64)) out = from_corearray( opx.cumsum( @@ -589,10 +588,13 @@ def cumulative_sum( ) ) - if isinstance(x.dtype, dtypes.Unsigned): - out = out.astype(ndx.uint64) - elif isinstance(x.dtype, dtypes.NullableUnsigned): - out = out.astype(ndx.nuint64) + if dtype is None: + if isinstance(x.dtype, dtypes.Unsigned): + out = out.astype(ndx.uint64) + elif isinstance(x.dtype, dtypes.NullableUnsigned): + out = out.astype(ndx.nuint64) + else: + out = out.astype(dtype) # Exclude axis and create zeros of that shape if include_initial: diff --git a/ndonnx/_data_types/coretype.py b/ndonnx/_data_types/coretype.py index c82ee72..34fcdf5 100644 --- a/ndonnx/_data_types/coretype.py +++ b/ndonnx/_data_types/coretype.py @@ -69,7 +69,11 @@ def _parse_input(self, data: np.ndarray) -> dict[str, np.ndarray]: def _cast_from(self, array: Array) -> Array: if isinstance(array.dtype, CoreType): - return ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self)) + return ( + ndx.Array._from_fields(self, data=opx.cast(array._core(), to=self)) + if array.dtype != self + else array.copy() + ) else: raise CastError(f"Cannot cast from {array.dtype} to {self}") diff --git a/tests/test_core.py b/tests/test_core.py index 589df5b..b0eaae3 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -933,13 +933,14 @@ def test_dynamic_reshape_has_no_static_shape(x, shape): ) @pytest.mark.parametrize("include_initial", [True, False]) @pytest.mark.parametrize( - "dtype", + "array_dtype", [ndx.int32, ndx.int64, ndx.float32, ndx.float64, ndx.uint8, ndx.uint16, ndx.uint32], ) @pytest.mark.parametrize( "array, axis", [ ([1, 2, 3], None), + ([100, 100], None), ([1, 2, 3], 0), ([[1, 2], [3, 4]], 0), ([[1, 2], [3, 4]], 1), @@ -948,21 +949,36 @@ def test_dynamic_reshape_has_no_static_shape(x, shape): ([[[[1]]], [[[3]]]], 1), ], ) -def test_cumulative_sum(array, axis, include_initial, dtype): - a = ndx.asarray(array, dtype=dtype) +@pytest.mark.parametrize( + "cumsum_dtype", + [None, ndx.int32, ndx.float32, ndx.float64, ndx.uint8, ndx.int8], +) +def test_cumulative_sum(array, axis, include_initial, array_dtype, cumsum_dtype): + a = ndx.asarray(array, dtype=array_dtype) assert_array_equal( - ndx.cumulative_sum(a, include_initial=include_initial, axis=axis).to_numpy(), + ndx.cumulative_sum( + a, include_initial=include_initial, axis=axis, dtype=cumsum_dtype + ).to_numpy(), np.cumulative_sum( np.asarray(array, a.dtype.to_numpy_dtype()), include_initial=include_initial, axis=axis, + dtype=cumsum_dtype.to_numpy_dtype() if cumsum_dtype is not None else None, ), ) def test_no_unsafe_cumulative_sum_cast(): with pytest.raises( - ndx.UnsupportedOperationError, match="Cannot perform `cumulative_sum`" + ndx.UnsupportedOperationError, + match="Unsupported operand type for cumulative_sum", ): a = ndx.asarray([1, 2, 3], ndx.uint64) ndx.cumulative_sum(a) + + with pytest.raises( + ndx.UnsupportedOperationError, + match="Unsupported dtype parameter for cumulative_sum", + ): + a = ndx.asarray([1, 2, 3], ndx.int32) + ndx.cumulative_sum(a, dtype=ndx.uint64)