diff --git a/keras/src/backend/jax/numpy.py b/keras/src/backend/jax/numpy.py index 19d6b81e78d..eae67ba6588 100644 --- a/keras/src/backend/jax/numpy.py +++ b/keras/src/backend/jax/numpy.py @@ -223,10 +223,8 @@ def absolute(x): return jnp.absolute(x) -@sparse.elementwise_unary(linear=False) def abs(x): - x = convert_to_tensor(x) - return jnp.absolute(x) + return absolute(x) def all(x, axis=None, keepdims=False): diff --git a/keras/src/backend/tensorflow/numpy.py b/keras/src/backend/tensorflow/numpy.py index 154249b7b08..3712fb3bf84 100644 --- a/keras/src/backend/tensorflow/numpy.py +++ b/keras/src/backend/tensorflow/numpy.py @@ -615,6 +615,7 @@ def zeros(shape, dtype=None): @sparse.elementwise_unary def absolute(x): + x = convert_to_tensor(x) # uintx and bool are always non-negative dtype = standardize_dtype(x.dtype) if "uint" in dtype or dtype == "bool": @@ -622,7 +623,6 @@ def absolute(x): return tf.abs(x) -@sparse.elementwise_unary def abs(x): return absolute(x) @@ -2405,4 +2405,4 @@ def correlate(x1, x2, mode="valid"): def select(condlist, choicelist, default=0): - return tfnp.select(condlist, choicelist, default=default) + return tf.experimental.numpy.select(condlist, choicelist, default=default) diff --git a/keras/src/backend/torch/numpy.py b/keras/src/backend/torch/numpy.py index 7edbdfeb454..56cf9614078 100644 --- a/keras/src/backend/torch/numpy.py +++ b/keras/src/backend/torch/numpy.py @@ -201,10 +201,6 @@ def zeros_like(x, dtype=None): def absolute(x): - return abs(x) - - -def abs(x): x = convert_to_tensor(x) # bool are always non-negative if standardize_dtype(x.dtype) == "bool": @@ -212,6 +208,10 @@ def abs(x): return torch.abs(x) +def abs(x): + return absolute(x) + + def all(x, axis=None, keepdims=False): x = convert_to_tensor(x) if axis is None: