Skip to content

Commit

Permalink
Make ops.abs and ops.absolute consistent between backends. (#19563)
Browse files Browse the repository at this point in the history
- The TensorFlow implementation was missing `convert_to_tensor`
- The sparse annotation was unnecessarily applied twice
- Now `abs` calls `absolute` in all backends

Also fixed TensorFlow `ops.select`.
  • Loading branch information
hertschuh authored Apr 19, 2024
1 parent a431507 commit 29d10d1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 9 deletions.
4 changes: 1 addition & 3 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,8 @@ def absolute(x):
return jnp.absolute(x)


@sparse.elementwise_unary(linear=False)
def abs(x):
x = convert_to_tensor(x)
return jnp.absolute(x)
return absolute(x)


def all(x, axis=None, keepdims=False):
Expand Down
4 changes: 2 additions & 2 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -615,14 +615,14 @@ def zeros(shape, dtype=None):

@sparse.elementwise_unary
def absolute(x):
x = convert_to_tensor(x)
# uintx and bool are always non-negative
dtype = standardize_dtype(x.dtype)
if "uint" in dtype or dtype == "bool":
return x
return tf.abs(x)


@sparse.elementwise_unary
def abs(x):
return absolute(x)

Expand Down Expand Up @@ -2405,4 +2405,4 @@ def correlate(x1, x2, mode="valid"):


def select(condlist, choicelist, default=0):
return tfnp.select(condlist, choicelist, default=default)
return tf.experimental.numpy.select(condlist, choicelist, default=default)
8 changes: 4 additions & 4 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,17 @@ def zeros_like(x, dtype=None):


def absolute(x):
return abs(x)


def abs(x):
x = convert_to_tensor(x)
# bool are always non-negative
if standardize_dtype(x.dtype) == "bool":
return x
return torch.abs(x)


def abs(x):
return absolute(x)


def all(x, axis=None, keepdims=False):
x = convert_to_tensor(x)
if axis is None:
Expand Down

0 comments on commit 29d10d1

Please sign in to comment.