From f889c1f20f03ab386436d09376c272ef9540579e Mon Sep 17 00:00:00 2001 From: HongYu <20734616+james77777778@users.noreply.github.com> Date: Thu, 16 Nov 2023 06:55:43 +0800 Subject: [PATCH] Apply `backend.result_type` to `append`, `average`, `broadcast_to`, `concatenate`, `copy`, `count_nonzero`, `cross`, `diag*`, `diff` and `digitize` (#18779) * Apply `backend.result_type` to `append`, `average`, `broadcast_to`, `concatenate` and `copy` * Apply `backend.result_type` to `count_nonzero` * Apply `backend.result_type` to `cross`, `diag` and `diagonal` * dtype test cleanup for `convert_to_tensor` * Apply `backend.result_type` to `diff` and `digitize` * Revert tf's `digitize` * Revert `test_convert_to_tensor` --- keras/backend/jax/numpy.py | 17 ++- keras/backend/numpy/numpy.py | 36 +++++- keras/backend/tensorflow/numpy.py | 50 ++++++-- keras/backend/torch/numpy.py | 36 ++++-- keras/ops/core_test.py | 73 ++++++----- keras/ops/numpy.py | 28 +++-- keras/ops/numpy_test.py | 195 ++++++++++++++++++++++++++++++ 7 files changed, 357 insertions(+), 78 deletions(-) diff --git a/keras/backend/jax/numpy.py b/keras/backend/jax/numpy.py index 4bf7e701fda..68934afe58a 100644 --- a/keras/backend/jax/numpy.py +++ b/keras/backend/jax/numpy.py @@ -110,11 +110,7 @@ def amin(x, axis=None, keepdims=False): return jnp.amin(x, axis=axis, keepdims=keepdims) -def append( - x1, - x2, - axis=None, -): +def append(x1, x2, axis=None): x1 = convert_to_tensor(x1) x2 = convert_to_tensor(x2) return jnp.append(x1, x2, axis=axis) @@ -213,6 +209,15 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = cast(x, dtype) + if weights is not None: + weights = cast(weights, dtype) return jnp.average(x, weights=weights, axis=axis) @@ -268,7 +273,7 @@ def cosh(x): def count_nonzero(x, axis=None): - return jnp.count_nonzero(x, axis=axis) + return cast(jnp.count_nonzero(x, axis=axis), "int32") def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): diff --git a/keras/backend/numpy/numpy.py b/keras/backend/numpy/numpy.py index 879dd64c205..0394edd2c17 100644 --- a/keras/backend/numpy/numpy.py +++ b/keras/backend/numpy/numpy.py @@ -1,4 +1,5 @@ import numpy as np +import tree from keras.backend import config from keras.backend import standardize_dtype @@ -94,12 +95,13 @@ def amin(x, axis=None, keepdims=False): return np.amin(x, axis=axis, keepdims=keepdims) -def append( - x1, - x2, - axis=None, -): +def append(x1, x2, axis=None): axis = tuple(axis) if isinstance(axis, list) else axis + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) return np.append(x1, x2, axis=axis) @@ -205,6 +207,15 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): axis = tuple(axis) if isinstance(axis, list) else axis + x = convert_to_tensor(x) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = x.astype(dtype) + if weights is not None: + weights = weights.astype(dtype) return np.average(x, weights=weights, axis=axis) @@ -259,6 +270,12 @@ def clip(x, x_min, x_max): def concatenate(xs, axis=0): axis = tuple(axis) if isinstance(axis, list) else axis + dtype_set = 
set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tree.map_structure( + lambda x: convert_to_tensor(x).astype(dtype), xs + ) return np.concatenate(xs, axis=axis) @@ -296,11 +313,18 @@ def cosh(x): def count_nonzero(x, axis=None): axis = tuple(axis) if isinstance(axis, list) else axis - return np.count_nonzero(x, axis=axis) + # np.count_nonzero will return python int when axis=None, so we need + # to convert_to_tensor + return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32") def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): axis = tuple(axis) if isinstance(axis, list) else axis + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = x1.astype(dtype) + x2 = x2.astype(dtype) return np.cross( x1, x2, diff --git a/keras/backend/tensorflow/numpy.py b/keras/backend/tensorflow/numpy.py index 142a0ec0d5b..ffe884961bc 100644 --- a/keras/backend/tensorflow/numpy.py +++ b/keras/backend/tensorflow/numpy.py @@ -322,11 +322,12 @@ def amin(x, axis=None, keepdims=False): return tfnp.amin(x, axis=axis, keepdims=keepdims) -def append( - x1, - x2, - axis=None, -): +def append(x1, x2, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) return tfnp.append(x1, x2, axis=axis) @@ -446,12 +447,22 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): + x = convert_to_tensor(x) if not isinstance(axis, (list, tuple)): axis = (axis,) + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = tf.cast(x, dtype) + if weights is not None: + weights = tf.cast(weights, dtype) for a in axis: # `tfnp.average` does not handle multiple axes. 
x = tfnp.average(x, weights=weights, axis=a) - return x + # TODO: tfnp.average incorrectly promote bfloat16 to float64 + return tf.cast(x, dtype) def broadcast_to(x, shape): @@ -485,6 +496,10 @@ def concatenate(xs, axis=0): tf.sparse.to_dense(x) if isinstance(x, tf.SparseTensor) else x for x in xs ] + dtype_set = set([getattr(x, "dtype", type(x)) for x in xs]) + if len(dtype_set) > 1: + dtype = dtypes.result_type(*dtype_set) + xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs) return tfnp.concatenate(xs, axis=axis) @@ -528,10 +543,15 @@ def cosh(x): def count_nonzero(x, axis=None): - return tfnp.count_nonzero(x, axis=axis) + return tf.cast(tfnp.count_nonzero(x, axis=axis), "int32") def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None): + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + dtype = dtypes.result_type(x1.dtype, x2.dtype) + x1 = tf.cast(x1, dtype) + x2 = tf.cast(x2, dtype) return tfnp.cross( x1, x2, @@ -568,7 +588,22 @@ def diff(a, n=1, axis=-1): def digitize(x, bins): + x = convert_to_tensor(x) bins = list(bins) + + # bins must be float type + bins = tf.nest.map_structure(lambda x: float(x), bins) + + # TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8 + # int16, uint8, uint16, uint32 + ori_dtype = standardize_dtype(x.dtype) + if ori_dtype in ("bool", "int8", "int16", "uint8", "uint16"): + x = tf.cast(x, "int32") + elif ori_dtype == "uint32": + x = tf.cast(x, "int64") + elif ori_dtype in ("bfloat16", "float16"): + x = tf.cast(x, "float32") + if isinstance(x, tf.RaggedTensor): return tf.ragged.map_flat_values( lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), x @@ -579,7 +614,6 @@ def digitize(x, bins): values=tf.raw_ops.Bucketize(input=x.values, boundaries=bins), dense_shape=tf.identity(x.dense_shape), ) - x = convert_to_tensor(x) return tf.raw_ops.Bucketize(input=x, boundaries=bins) diff --git a/keras/backend/torch/numpy.py b/keras/backend/torch/numpy.py index 1bccdc83c86..793ca3ca902 100644 --- a/keras/backend/torch/numpy.py +++ b/keras/backend/torch/numpy.py @@ -206,11 +206,7 @@ def amin(x, axis=None, keepdims=False): return torch.amin(x, dim=axis, keepdim=keepdims) -def append( - x1, - x2, - axis=None, -): +def append(x1, x2, axis=None): x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) if axis is None: return torch.cat((x1.flatten(), x2.flatten())) @@ -316,13 +312,18 @@ def array(x, dtype=None): def average(x, axis=None, weights=None): x = convert_to_tensor(x) - # Conversion to float necessary for `torch.mean` - x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x + dtypes_to_resolve = [x.dtype, float] + if weights is not None: + weights = convert_to_tensor(weights) + dtypes_to_resolve.append(weights.dtype) + dtype = dtypes.result_type(*dtypes_to_resolve) + x = cast(x, dtype) + if weights is not None: + weights = cast(weights, dtype) if axis == () or axis == []: # Torch handles the empty axis case differently from numpy. return x if weights is not None: - weights = convert_to_tensor(weights) return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum( weights, dim=-1 ) @@ -432,7 +433,7 @@ def count_nonzero(x, axis=None): if axis == () or axis == []: # Torch handles the empty axis case differently from numpy. 
return cast(torch.ne(x, 0), "int32") - return torch.count_nonzero(x, dim=axis).T + return cast(torch.count_nonzero(x, dim=axis).T, "int32") def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1): @@ -442,8 +443,19 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1): f"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. Please " "use `axis` arg in torch backend." ) - x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2) - return torch.cross(x1, x2, dim=axis) + x1 = convert_to_tensor(x1) + x2 = convert_to_tensor(x2) + compute_dtype = dtypes.result_type(x1.dtype, x2.dtype) + result_dtype = compute_dtype + # TODO: torch.cross doesn't support bfloat16 with gpu + if get_device() == "cuda" and compute_dtype == "bfloat16": + compute_dtype = "float32" + # TODO: torch.cross doesn't support float16 with cpu + elif get_device() == "cpu" and compute_dtype == "float16": + compute_dtype = "float32" + x1 = cast(x1, compute_dtype) + x2 = cast(x2, compute_dtype) + return cast(torch.cross(x1, x2, dim=axis), result_dtype) def cumprod(x, axis=None): @@ -485,6 +497,8 @@ def diff(a, n=1, axis=-1): def digitize(x, bins): x = convert_to_tensor(x) bins = convert_to_tensor(bins) + if standardize_dtype(x.dtype) == "bool": + x = cast(x, "uint8") return cast(torch.bucketize(x, bins, right=True), "int32") diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 043f9475e8e..d636b2f6f3b 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -11,7 +11,6 @@ from keras import ops from keras import optimizers from keras import testing -from keras.backend.common import standardize_dtype from keras.backend.common.keras_tensor import KerasTensor from keras.backend.common.variables import ALLOWED_DTYPES from keras.ops import core @@ -300,28 +299,17 @@ def test_convert_to_tensor(self): self.assertAllEqual(x, (1, 1)) self.assertIsInstance(x, np.ndarray) - # Empty lists should give an empty array with the default float type. + # Empty lists should give an empty array. x = ops.convert_to_tensor([]) - x = ops.convert_to_numpy(x) + np_x = ops.convert_to_numpy(x) + self.assertTrue(ops.is_tensor(x)) self.assertAllEqual(x, []) - self.assertIsInstance(x, np.ndarray) - self.assertEqual(x.dtype.name, "float32") + self.assertIsInstance(np_x, np.ndarray) # Partially converted. x = ops.convert_to_tensor((1, ops.array(2), 3)) self.assertAllEqual(x, (1, 2, 3)) - # Check dtype convertion. - x = [[1, 0, 1], [1, 1, 0]] - output = ops.convert_to_tensor(x, dtype="int32") - self.assertEqual(standardize_dtype(output.dtype), "int32") - x = [[1, 0, 1], [1, 1, 0]] - output = ops.convert_to_tensor(x, dtype="float32") - self.assertEqual(standardize_dtype(output.dtype), "float32") - x = [[1, 0, 1], [1, 1, 0]] - output = ops.convert_to_tensor(x, dtype="bool") - self.assertEqual(standardize_dtype(output.dtype), "bool") - with self.assertRaises(ValueError): ops.convert_to_numpy(KerasTensor((2,))) @@ -435,8 +423,18 @@ def fn(elems): backend.convert_to_numpy(output), 2 * np.ones((2, 3)) ) + def test_is_tensor(self): + np_x = np.array([[1, 2, 3], [3, 2, 1]]) + x = backend.convert_to_tensor(np_x) + if backend.backend() != "numpy": + self.assertFalse(ops.is_tensor(np_x)) + self.assertTrue(ops.is_tensor(x)) + self.assertFalse(ops.is_tensor([1, 2, 3])) + class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase): + import jax # enable bfloat16 for numpy + # TODO: Using uint64 will lead to weak type promotion (`float`), # resulting in different behavior between JAX and Keras. 
Currently, we # are skipping the test for uint64 @@ -451,23 +449,30 @@ class CoreOpsDtypeTest(testing.TestCase, parameterized.TestCase): ] @parameterized.parameters( - (bool(0), "bool"), - (int(0), "int32"), - (float(0), backend.floatx()), - ([False, True, False], "bool"), - ([1, 2, 3], "int32"), - ([1.0, 2.0, 3.0], backend.floatx()), - ([1, 2.0, 3], backend.floatx()), - ([[False], [True], [False]], "bool"), - ([[1], [2], [3]], "int32"), - ([[1], [2.0], [3]], backend.floatx()), + ((), None, backend.floatx()), + ([], None, backend.floatx()), + (bool(0), None, "bool"), + (int(0), None, "int32"), + (float(0), None, backend.floatx()), + ([False, True, False], None, "bool"), + ([1, 2, 3], None, "int32"), + ([1.0, 2.0, 3.0], None, backend.floatx()), + ([1, 2.0, 3], None, backend.floatx()), + ([[False], [True], [False]], None, "bool"), + ([[1], [2], [3]], None, "int32"), + ([[1], [2.0], [3]], None, backend.floatx()), *[ - (np.array(0, dtype=dtype), dtype) + (np.array(0, dtype=dtype), None, dtype) + for dtype in ALL_DTYPES + if dtype is not None + ], + *[ + ([[1, 0, 1], [1, 1, 0]], dtype, dtype) for dtype in ALL_DTYPES if dtype is not None ], ) - def test_convert_to_tensor(self, x, expected_dtype): + def test_convert_to_tensor(self, x, dtype, expected_dtype): # We have to disable x64 for jax backend since jnp.array doesn't respect # JAX_DEFAULT_DTYPE_BITS=32 in `./conftest.py`. We also need to downcast # the expected dtype from 64 bit to 32 bit. @@ -481,14 +486,8 @@ def test_convert_to_tensor(self, x, expected_dtype): with jax_disable_x64: self.assertEqual( - backend.standardize_dtype(ops.convert_to_tensor(x).dtype), + backend.standardize_dtype( + ops.convert_to_tensor(x, dtype=dtype).dtype + ), expected_dtype, ) - - def test_is_tensor(self): - np_x = np.array([[1, 2, 3], [3, 2, 1]]) - x = backend.convert_to_tensor(np_x) - if backend.backend() != "numpy": - self.assertFalse(ops.is_tensor(np_x)) - self.assertTrue(ops.is_tensor(x)) - self.assertFalse(ops.is_tensor([1, 2, 3])) diff --git a/keras/ops/numpy.py b/keras/ops/numpy.py index f529dfdc151..47c9a88fb87 100644 --- a/keras/ops/numpy.py +++ b/keras/ops/numpy.py @@ -604,12 +604,16 @@ def call(self, x1, x2): def compute_output_spec(self, x1, x2): x1_shape = x1.shape x2_shape = x2.shape + dtype = dtypes.result_type( + getattr(x1, "dtype", type(x1)), + getattr(x2, "dtype", type(x2)), + ) if self.axis is None: if None in x1_shape or None in x2_shape: output_shape = [None] else: output_shape = [int(np.prod(x1_shape) + np.prod(x2_shape))] - return KerasTensor(output_shape, dtype=x1.dtype) + return KerasTensor(output_shape, dtype=dtype) if not shape_equal(x1_shape, x2_shape, [self.axis]): raise ValueError( @@ -620,7 +624,7 @@ def compute_output_spec(self, x1, x2): output_shape = list(x1_shape) output_shape[self.axis] = x1_shape[self.axis] + x2_shape[self.axis] - return KerasTensor(output_shape, dtype=x1.dtype) + return KerasTensor(output_shape, dtype=dtype) @keras_export(["keras.ops.append", "keras.ops.numpy.append"]) @@ -1183,18 +1187,18 @@ def call(self, x, weights=None): return backend.numpy.average(x, weights=weights, axis=self.axis) def compute_output_spec(self, x, weights=None): + dtypes_to_resolve = [getattr(x, "dtype", type(x)), float] if weights is not None: shape_match = shape_equal(x.shape, weights.shape, allow_none=True) if self.axis is not None: shape_match_on_axis = shape_equal( [x.shape[self.axis]], weights.shape, allow_none=True ) + dtypes_to_resolve.append(getattr(weights, "dtype", type(weights))) + dtype = 
dtypes.result_type(*dtypes_to_resolve) if self.axis is None: if weights is None or shape_match: - return KerasTensor( - [], - dtype=x.dtype, - ) + return KerasTensor([], dtype=dtype) else: raise ValueError( "`weights` must have the same shape as `x` when " @@ -1204,8 +1208,7 @@ def compute_output_spec(self, x, weights=None): if weights is None or shape_match_on_axis or shape_match: return KerasTensor( - reduce_shape(x.shape, axis=[self.axis]), - dtype=x.dtype, + reduce_shape(x.shape, axis=[self.axis]), dtype=dtype ) else: # `weights` can either be a 1D array of length `x.shape[axis]` or @@ -1463,6 +1466,7 @@ def compute_output_spec(self, xs): first_shape = xs[0].shape total_size_on_axis = 0 all_sparse = True + dtypes_to_resolve = [] for x in xs: if not shape_equal( x.shape, first_shape, axis=[self.axis], allow_none=True @@ -1478,9 +1482,11 @@ def compute_output_spec(self, xs): else: total_size_on_axis += x.shape[self.axis] all_sparse = all_sparse and getattr(x, "sparse", False) + dtypes_to_resolve.append(getattr(x, "dtype", type(x))) output_shape = list(first_shape) output_shape[self.axis] = total_size_on_axis - return KerasTensor(output_shape, dtype=x.dtype, sparse=all_sparse) + dtype = dtypes.result_type(*dtypes_to_resolve) + return KerasTensor(output_shape, dtype=dtype, sparse=all_sparse) @keras_export( @@ -1718,7 +1724,9 @@ def compute_output_spec(self, x1, x2): output_shape = ( output_shape[: self.axisc] + value_size + output_shape[self.axisc :] ) - return KerasTensor(output_shape, dtype=x1.dtype) + + dtype = dtypes.result_type(x1.dtype, x2.dtype) + return KerasTensor(output_shape, dtype=dtype) @keras_export(["keras.ops.cross", "keras.ops.numpy.cross"]) diff --git a/keras/ops/numpy_test.py b/keras/ops/numpy_test.py index 68664f09707..fe05e81edd8 100644 --- a/keras/ops/numpy_test.py +++ b/keras/ops/numpy_test.py @@ -4610,6 +4610,26 @@ def test_any(self, dtype): # TODO: test_einsum + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_append(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.append(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.append(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Append().symbolic_call(x1, x2).dtype, expected_dtype + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_argmax(self, dtype): import jax.numpy as jnp @@ -4855,6 +4875,50 @@ def test_array(self, x, expected_dtype): ) # TODO: support the assertion of knp.Array + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_average(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.average(x1_jax, weights=x2_jax).dtype + ) + if dtype1 is not None and "float" not in dtype1: + if dtype2 is not None and "float" not in dtype2: + if "int64" in (dtype1, dtype2) or "uint32" in (dtype1, dtype2): + expected_dtype = backend.floatx() + + self.assertEqual( + standardize_dtype(knp.average(x1, weights=x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Average().symbolic_call(x1, weights=x2).dtype, expected_dtype + ) + + 
@parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_broadcast_to(self, dtype): + import jax.numpy as jnp + + x = knp.ones((3,), dtype=dtype) + x_jax = jnp.ones((3,), dtype=dtype) + expected_dtype = standardize_dtype( + jnp.broadcast_to(x_jax, (3, 3)).dtype + ) + + self.assertEqual( + standardize_dtype(knp.broadcast_to(x, (3, 3)).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.BroadcastTo((3, 3)).symbolic_call(x).dtype), + expected_dtype, + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_ceil(self, dtype): import jax.numpy as jnp @@ -4895,6 +4959,28 @@ def test_clip(self, dtype): expected_dtype, ) + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_concatenate(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1,), dtype=dtype1) + x2 = knp.ones((1,), dtype=dtype2) + x1_jax = jnp.ones((1,), dtype=dtype1) + x2_jax = jnp.ones((1,), dtype=dtype2) + expected_dtype = standardize_dtype( + jnp.concatenate([x1_jax, x2_jax]).dtype + ) + + self.assertEqual( + standardize_dtype(knp.concatenate([x1, x2]).dtype), expected_dtype + ) + self.assertEqual( + knp.Concatenate().symbolic_call([x1, x2]).dtype, expected_dtype + ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) def test_cos(self, dtype): import jax.numpy as jnp @@ -4927,6 +5013,115 @@ def test_cosh(self, dtype): expected_dtype, ) + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_copy(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.copy(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.copy(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Copy().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_count_nonzero(self, dtype): + x = knp.ones((1,), dtype=dtype) + expected_dtype = "int32" + + self.assertEqual( + standardize_dtype(knp.count_nonzero(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.CountNonzero().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters( + named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) + ) + def test_cross(self, dtypes): + import jax.numpy as jnp + + dtype1, dtype2 = dtypes + x1 = knp.ones((1, 1, 3), dtype=dtype1) + x2 = knp.ones((1, 1, 3), dtype=dtype2) + x1_jax = jnp.ones((1, 1, 3), dtype=dtype1) + x2_jax = jnp.ones((1, 1, 3), dtype=dtype2) + expected_dtype = standardize_dtype(jnp.cross(x1_jax, x2_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.cross(x1, x2).dtype), expected_dtype + ) + self.assertEqual( + knp.Cross().symbolic_call(x1, x2).dtype, expected_dtype + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diag(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diag(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.diag(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Diag().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diagonal(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1, 1, 1), dtype=dtype) + x_jax = jnp.ones((1, 1, 1), dtype=dtype) + expected_dtype = 
standardize_dtype(jnp.diagonal(x_jax).dtype) + + self.assertEqual( + standardize_dtype(knp.diagonal(x).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Diagonal().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_diff(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.diff(x_jax).dtype) + + self.assertEqual(standardize_dtype(knp.diff(x).dtype), expected_dtype) + self.assertEqual( + standardize_dtype(knp.Diff().symbolic_call(x).dtype), + expected_dtype, + ) + + @parameterized.named_parameters(named_product(dtype=ALL_DTYPES)) + def test_digitize(self, dtype): + import jax.numpy as jnp + + x = knp.ones((1,), dtype=dtype) + bins = knp.ones((1,), dtype=dtype) + x_jax = jnp.ones((1,), dtype=dtype) + x_bins = jnp.ones((1,), dtype=dtype) + expected_dtype = standardize_dtype(jnp.digitize(x_jax, x_bins).dtype) + + self.assertEqual( + standardize_dtype(knp.digitize(x, bins).dtype), expected_dtype + ) + self.assertEqual( + standardize_dtype(knp.Digitize().symbolic_call(x, bins).dtype), + expected_dtype, + ) + @parameterized.named_parameters( named_product(dtypes=itertools.combinations(ALL_DTYPES, 2)) )
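
The short snippet below is an editor's sketch, not part of the patch itself. It illustrates, in the same style as the new cases added to keras/ops/numpy_test.py, the behavior this change standardizes: mixed input dtypes are resolved through `backend.result_type` (following JAX's promotion rules, with `jax.numpy` assumed installed as the reference, exactly as in the tests), and `count_nonzero` is pinned to int32 on every backend.

    # Editor's sketch, assuming the patch above is applied; mirrors the new
    # dtype tests rather than adding anything beyond them.
    import jax.numpy as jnp

    from keras import backend
    from keras import ops

    x1 = ops.ones((2,), dtype="int8")
    x2 = ops.ones((2,), dtype="float32")

    # `append` now promotes int8 + float32 the same way jnp.append does.
    expected = backend.standardize_dtype(
        jnp.append(jnp.ones((2,), "int8"), jnp.ones((2,), "float32")).dtype
    )
    assert backend.standardize_dtype(ops.append(x1, x2).dtype) == expected

    # `count_nonzero` returns int32 regardless of the input dtype.
    assert backend.standardize_dtype(ops.count_nonzero(x1).dtype) == "int32"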