Skip to content

Commit

Permalink
Apply `backend.result_type` to `append`, `average`, `broadcast_to`, `concatenate`, `copy`, `count_nonzero`, `cross`, `diag*`, `diff` and `digitize` (#18779)
Browse files Browse the repository at this point in the history

* Apply `backend.result_type` to `append`, `average`, `broadcast_to`, `concatenate` and `copy`

* Apply `backend.result_type` to `count_nonzero`

* Apply `backend.result_type` to `cross`, `diag` and `diagonal`

* dtype test cleanup for `convert_to_tensor`

* Apply `backend.result_type` to `diff` and `digitize`

* Revert tf's `digitize`

* Revert `test_convert_to_tensor`
  • Loading branch information
james77777778 authored Nov 15, 2023
1 parent fae2b0d commit f889c1f
Show file tree
Hide file tree
Showing 7 changed files with 357 additions and 78 deletions.
17 changes: 11 additions & 6 deletions keras/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,7 @@ def amin(x, axis=None, keepdims=False):
return jnp.amin(x, axis=axis, keepdims=keepdims)


def append(x1, x2, axis=None):
    """Append `x2` to `x1` with `jnp.append` (JAX backend).

    Args:
        x1: First input; converted with `convert_to_tensor`.
        x2: Second input; converted with `convert_to_tensor`.
        axis: Forwarded unchanged to `jnp.append`.

    Returns:
        The result of `jnp.append(x1, x2, axis=axis)`.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    return jnp.append(x1, x2, axis=axis)
Expand Down Expand Up @@ -213,6 +209,15 @@ def array(x, dtype=None):


def average(x, axis=None, weights=None):
    """Compute the optionally weighted average of `x` (JAX backend).

    The computation dtype is resolved by `dtypes.result_type` from the dtype
    of `x`, Python `float`, and the dtype of `weights` when provided; both
    operands are cast to it before `jnp.average` is called.

    Args:
        x: Input; converted with `convert_to_tensor`.
        axis: Forwarded unchanged to `jnp.average`.
        weights: Optional weights; converted with `convert_to_tensor`.

    Returns:
        The result of `jnp.average`.
    """
    x = convert_to_tensor(x)
    if weights is None:
        # NOTE(review): `float` is presumably included so integer inputs
        # resolve to a floating dtype -- confirm against `dtypes.result_type`.
        target_dtype = dtypes.result_type(x.dtype, float)
        return jnp.average(cast(x, target_dtype), weights=None, axis=axis)

    weights = convert_to_tensor(weights)
    target_dtype = dtypes.result_type(x.dtype, float, weights.dtype)
    return jnp.average(
        cast(x, target_dtype),
        weights=cast(weights, target_dtype),
        axis=axis,
    )


Expand Down Expand Up @@ -268,7 +273,7 @@ def cosh(x):


def count_nonzero(x, axis=None):
    """Count the non-zero entries of `x` (JAX backend).

    Args:
        x: Input passed to `jnp.count_nonzero`.
        axis: Forwarded unchanged to `jnp.count_nonzero`.

    Returns:
        The count, cast to `"int32"`.
    """
    nonzero_count = jnp.count_nonzero(x, axis=axis)
    return cast(nonzero_count, "int32")


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
Expand Down
36 changes: 30 additions & 6 deletions keras/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
import tree

from keras.backend import config
from keras.backend import standardize_dtype
Expand Down Expand Up @@ -94,12 +95,13 @@ def amin(x, axis=None, keepdims=False):
return np.amin(x, axis=axis, keepdims=keepdims)


def append(x1, x2, axis=None):
    """Append `x2` to `x1` with `np.append` (NumPy backend).

    Both inputs are converted with `convert_to_tensor` and then cast to the
    dtype returned by `dtypes.result_type` for their two dtypes.

    Args:
        x1: First input.
        x2: Second input.
        axis: Forwarded to `np.append`; a list is turned into a tuple first.

    Returns:
        The result of `np.append` on the two cast inputs.
    """
    axis = tuple(axis) if isinstance(axis, list) else axis
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(x1.dtype, x2.dtype)
    x1 = x1.astype(dtype)
    x2 = x2.astype(dtype)
    return np.append(x1, x2, axis=axis)


Expand Down Expand Up @@ -205,6 +207,15 @@ def array(x, dtype=None):

def average(x, axis=None, weights=None):
    """Compute the optionally weighted average of `x` (NumPy backend).

    The computation dtype is resolved by `dtypes.result_type` from the dtype
    of `x`, Python `float`, and the dtype of `weights` when provided.

    Args:
        x: Input; converted with `convert_to_tensor`.
        axis: Forwarded to `np.average`; a list is turned into a tuple first.
        weights: Optional weights; converted with `convert_to_tensor`.

    Returns:
        The result of `np.average` on the cast operands.
    """
    if isinstance(axis, list):
        axis = tuple(axis)
    x = convert_to_tensor(x)
    use_weights = weights is not None

    resolve_from = [x.dtype, float]
    if use_weights:
        weights = convert_to_tensor(weights)
        resolve_from.append(weights.dtype)
    out_dtype = dtypes.result_type(*resolve_from)

    x = x.astype(out_dtype)
    if use_weights:
        weights = weights.astype(out_dtype)
    return np.average(x, weights=weights, axis=axis)


Expand Down Expand Up @@ -259,6 +270,12 @@ def clip(x, x_min, x_max):

def concatenate(xs, axis=0):
    """Join the arrays in `xs` along `axis` with `np.concatenate`.

    When the inputs do not all report the same dtype, every input is first
    converted with `convert_to_tensor` and cast to the dtype chosen by
    `dtypes.result_type`. Inputs that already agree are passed through
    untouched.

    Args:
        xs: Sequence of arrays / array-likes to join.
        axis: Forwarded to `np.concatenate`; a list is turned into a tuple
            first.

    Returns:
        The result of `np.concatenate`.
    """
    axis = tuple(axis) if isinstance(axis, list) else axis
    # Inputs without a `dtype` attribute are represented by their Python type.
    dtype_set = {getattr(x, "dtype", type(x)) for x in xs}
    if len(dtype_set) > 1:
        dtype = dtypes.result_type(*dtype_set)
        xs = tree.map_structure(
            lambda x: convert_to_tensor(x).astype(dtype), xs
        )
    return np.concatenate(xs, axis=axis)


Expand Down Expand Up @@ -296,11 +313,18 @@ def cosh(x):

def count_nonzero(x, axis=None):
    """Count the non-zero entries of `x` (NumPy backend).

    Args:
        x: Input passed to `np.count_nonzero`.
        axis: Forwarded to `np.count_nonzero`; a list is turned into a tuple
            first.

    Returns:
        The count as an `"int32"` array.
    """
    axis = tuple(axis) if isinstance(axis, list) else axis
    # `np.count_nonzero` returns a Python int when `axis=None`, so the result
    # goes through `convert_to_tensor` before `.astype` is called on it.
    nonzero_count = convert_to_tensor(np.count_nonzero(x, axis=axis))
    return nonzero_count.astype("int32")


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype)
x1 = x1.astype(dtype)
x2 = x2.astype(dtype)
return np.cross(
x1,
x2,
Expand Down
50 changes: 42 additions & 8 deletions keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,11 +322,12 @@ def amin(x, axis=None, keepdims=False):
return tfnp.amin(x, axis=axis, keepdims=keepdims)


def append(x1, x2, axis=None):
    """Append `x2` to `x1` with `tfnp.append` (TensorFlow backend).

    Both inputs are converted with `convert_to_tensor` and then cast to the
    dtype returned by `dtypes.result_type` for their two dtypes.

    Args:
        x1: First input.
        x2: Second input.
        axis: Forwarded unchanged to `tfnp.append`.

    Returns:
        The result of `tfnp.append` on the two cast inputs.
    """
    x1 = convert_to_tensor(x1)
    x2 = convert_to_tensor(x2)
    dtype = dtypes.result_type(x1.dtype, x2.dtype)
    x1 = tf.cast(x1, dtype)
    x2 = tf.cast(x2, dtype)
    return tfnp.append(x1, x2, axis=axis)


Expand Down Expand Up @@ -446,12 +447,22 @@ def array(x, dtype=None):


def average(x, axis=None, weights=None):
    """Compute the optionally weighted average of `x` (TensorFlow backend).

    The computation dtype is resolved by `dtypes.result_type` from the dtype
    of `x`, Python `float`, and the dtype of `weights` when provided. The
    reduction is applied one axis at a time and the result is cast back to
    the resolved dtype.

    Args:
        x: Input; converted with `convert_to_tensor`.
        axis: A single axis (or `None`), or a list/tuple of axes.
        weights: Optional weights; converted with `convert_to_tensor`.

    Returns:
        The averaged tensor, cast to the resolved dtype.
    """
    x = convert_to_tensor(x)
    if not isinstance(axis, (list, tuple)):
        axis = (axis,)
    dtypes_to_resolve = [x.dtype, float]
    if weights is not None:
        weights = convert_to_tensor(weights)
        dtypes_to_resolve.append(weights.dtype)
    dtype = dtypes.result_type(*dtypes_to_resolve)
    x = tf.cast(x, dtype)
    if weights is not None:
        weights = tf.cast(weights, dtype)
    # NOTE(review): the same `weights` tensor is passed on every pass of the
    # loop below -- confirm this is intended when several axes are given.
    for a in axis:
        # `tfnp.average` does not handle multiple axes.
        x = tfnp.average(x, weights=weights, axis=a)
    # TODO: `tfnp.average` incorrectly promotes bfloat16 to float64, so the
    # result is cast back to the resolved dtype.
    return tf.cast(x, dtype)


def broadcast_to(x, shape):
Expand Down Expand Up @@ -485,6 +496,10 @@ def concatenate(xs, axis=0):
tf.sparse.to_dense(x) if isinstance(x, tf.SparseTensor) else x
for x in xs
]
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
xs = tf.nest.map_structure(lambda x: tf.cast(x, dtype), xs)
return tfnp.concatenate(xs, axis=axis)


Expand Down Expand Up @@ -528,10 +543,15 @@ def cosh(x):


def count_nonzero(x, axis=None):
    """Count the non-zero entries of `x` (TensorFlow backend).

    Args:
        x: Input passed to `tfnp.count_nonzero`.
        axis: Forwarded unchanged to `tfnp.count_nonzero`.

    Returns:
        The count, cast to `"int32"`.
    """
    nonzero_count = tfnp.count_nonzero(x, axis=axis)
    return tf.cast(nonzero_count, "int32")


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype)
x1 = tf.cast(x1, dtype)
x2 = tf.cast(x2, dtype)
return tfnp.cross(
x1,
x2,
Expand Down Expand Up @@ -568,7 +588,22 @@ def diff(a, n=1, axis=-1):


def digitize(x, bins):
x = convert_to_tensor(x)
bins = list(bins)

# bins must be float type
bins = tf.nest.map_structure(lambda x: float(x), bins)

# TODO: tf.raw_ops.Bucketize doesn't support bool, bfloat16, float16, int8
# int16, uint8, uint16, uint32
ori_dtype = standardize_dtype(x.dtype)
if ori_dtype in ("bool", "int8", "int16", "uint8", "uint16"):
x = tf.cast(x, "int32")
elif ori_dtype == "uint32":
x = tf.cast(x, "int64")
elif ori_dtype in ("bfloat16", "float16"):
x = tf.cast(x, "float32")

if isinstance(x, tf.RaggedTensor):
return tf.ragged.map_flat_values(
lambda y: tf.raw_ops.Bucketize(input=y, boundaries=bins), x
Expand All @@ -579,7 +614,6 @@ def digitize(x, bins):
values=tf.raw_ops.Bucketize(input=x.values, boundaries=bins),
dense_shape=tf.identity(x.dense_shape),
)
x = convert_to_tensor(x)
return tf.raw_ops.Bucketize(input=x, boundaries=bins)


Expand Down
36 changes: 25 additions & 11 deletions keras/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,11 +206,7 @@ def amin(x, axis=None, keepdims=False):
return torch.amin(x, dim=axis, keepdim=keepdims)


def append(
x1,
x2,
axis=None,
):
def append(x1, x2, axis=None):
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
if axis is None:
return torch.cat((x1.flatten(), x2.flatten()))
Expand Down Expand Up @@ -316,13 +312,18 @@ def array(x, dtype=None):

def average(x, axis=None, weights=None):
x = convert_to_tensor(x)
# Conversion to float necessary for `torch.mean`
x = cast(x, "float32") if x.dtype in TORCH_INT_TYPES else x
dtypes_to_resolve = [x.dtype, float]
if weights is not None:
weights = convert_to_tensor(weights)
dtypes_to_resolve.append(weights.dtype)
dtype = dtypes.result_type(*dtypes_to_resolve)
x = cast(x, dtype)
if weights is not None:
weights = cast(weights, dtype)
if axis == () or axis == []:
# Torch handles the empty axis case differently from numpy.
return x
if weights is not None:
weights = convert_to_tensor(weights)
return torch.sum(torch.mul(x, weights), dim=axis) / torch.sum(
weights, dim=-1
)
Expand Down Expand Up @@ -432,7 +433,7 @@ def count_nonzero(x, axis=None):
if axis == () or axis == []:
# Torch handles the empty axis case differently from numpy.
return cast(torch.ne(x, 0), "int32")
return torch.count_nonzero(x, dim=axis).T
return cast(torch.count_nonzero(x, dim=axis).T, "int32")


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1):
Expand All @@ -442,8 +443,19 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=-1):
f"Received: axisa={axisa}, axisb={axisb}, axisc={axisc}. Please "
"use `axis` arg in torch backend."
)
x1, x2 = convert_to_tensor(x1), convert_to_tensor(x2)
return torch.cross(x1, x2, dim=axis)
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
compute_dtype = dtypes.result_type(x1.dtype, x2.dtype)
result_dtype = compute_dtype
# TODO: torch.cross doesn't support bfloat16 with gpu
if get_device() == "cuda" and compute_dtype == "bfloat16":
compute_dtype = "float32"
# TODO: torch.cross doesn't support float16 with cpu
elif get_device() == "cpu" and compute_dtype == "float16":
compute_dtype = "float32"
x1 = cast(x1, compute_dtype)
x2 = cast(x2, compute_dtype)
return cast(torch.cross(x1, x2, dim=axis), result_dtype)


def cumprod(x, axis=None):
Expand Down Expand Up @@ -485,6 +497,8 @@ def diff(a, n=1, axis=-1):
def digitize(x, bins):
    """Return the bin index of each value in `x` (Torch backend).

    Args:
        x: Values to bucketize; converted with `convert_to_tensor`.
        bins: Bin boundaries; converted with `convert_to_tensor`.

    Returns:
        The output of `torch.bucketize(..., right=True)`, cast to `"int32"`.
    """
    x = convert_to_tensor(x)
    bins = convert_to_tensor(bins)
    # NOTE(review): bool inputs are widened to uint8 first, presumably because
    # `torch.bucketize` does not accept bool -- confirm.
    is_bool_input = standardize_dtype(x.dtype) == "bool"
    values = cast(x, "uint8") if is_bool_input else x
    bucket_ids = torch.bucketize(values, bins, right=True)
    return cast(bucket_ids, "int32")


Expand Down
Loading

0 comments on commit f889c1f

Please sign in to comment.