Skip to content

Commit

Permalink
Refactor axis logic across all backends and add support for multipl…
Browse files Browse the repository at this point in the history
…e axes in `expand_dims` and `squeeze` (#19252)

* Introduce `canonicalize_axis`, `standardize_axis_for_numpy` and `to_tuple_or_list`

* Add support for multiple axes to `expand_dims` and `squeeze`

* Fix the strings

* Fix `pytest.warn`

* Fix `ops.squeeze`
  • Loading branch information
james77777778 authored Mar 4, 2024
1 parent 0d28a01 commit 1665377
Show file tree
Hide file tree
Showing 9 changed files with 228 additions and 111 deletions.
33 changes: 33 additions & 0 deletions keras/backend/common/backend_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import operator
import warnings


Expand Down Expand Up @@ -255,3 +256,35 @@ def compute_conv_transpose_output_shape(
else:
output_shape = [input_shape[0], filters] + output_shape
return output_shape


def canonicalize_axis(axis, num_dims):
    """Map an axis from `[-num_dims, num_dims)` onto `[0, num_dims)`.

    Args:
        axis: Integer-like axis index; it is converted with
            `operator.index`, which raises `TypeError` for objects that
            are not integer-like.
        num_dims: Number of dimensions the axis is validated against.

    Returns:
        The equivalent non-negative axis as an integer.

    Raises:
        ValueError: If `axis` lies outside `[-num_dims, num_dims)`.
    """
    axis = operator.index(axis)
    if axis < -num_dims or axis >= num_dims:
        raise ValueError(
            f"axis {axis} is out of bounds for an array with dimension "
            f"{num_dims}."
        )
    # Negative axes count from the end, so shift them into the valid range.
    return axis + num_dims if axis < 0 else axis


def standardize_axis_for_numpy(axis):
    """Return `axis` as a tuple when it is a list, for the numpy backend.

    Any value that is not a list (e.g. `None`, an int or a tuple) is
    returned unchanged.
    """
    if isinstance(axis, list):
        return tuple(axis)
    return axis


def to_tuple_or_list(value):
    """Convert a non-`None` value to either a tuple or a list.

    Args:
        value: `None`, an integer-like object (an `int` or anything that
            supports `operator.index`), a tuple or a list.

    Returns:
        `None` if `value` is `None`; `value` itself if it is already a
        tuple or a list; otherwise a 1-tuple holding the integer.

    Raises:
        ValueError: If `value` is none of the supported kinds.
    """
    if value is None or isinstance(value, (tuple, list)):
        return value
    if isinstance(value, int):
        # Plain ints are wrapped untouched so existing callers see the
        # exact same result as before.
        return (value,)
    try:
        # Accept integer-like objects through the `__index__` protocol, the
        # same protocol `canonicalize_axis` relies on, so both helpers agree
        # on what counts as a valid axis.
        return (operator.index(value),)
    except TypeError:
        raise ValueError(
            "`value` must be an integer, tuple or list. "
            f"Received: value={value}"
        ) from None
21 changes: 6 additions & 15 deletions keras/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from keras.backend import config
from keras.backend.common import dtypes
from keras.backend.common.backend_utils import canonicalize_axis
from keras.backend.common.backend_utils import to_tuple_or_list
from keras.backend.common.variables import standardize_dtype
from keras.backend.jax import sparse
from keras.backend.jax.core import cast
Expand Down Expand Up @@ -367,14 +369,7 @@ def concatenate(xs, axis=0):
bcoo_count = builtins.sum(isinstance(x, jax_sparse.BCOO) for x in xs)
if bcoo_count:
if bcoo_count == len(xs):
ndim = len(xs[0].shape)
if not -ndim <= axis < ndim:
raise ValueError(
f"In `axis`, axis {axis} is out of bounds for array "
f"of dimension {ndim}"
)
if axis < 0:
axis = axis + ndim
axis = canonicalize_axis(axis, len(xs[0].shape))
return jax_sparse.bcoo_concatenate(xs, dimension=axis)
else:
xs = [
Expand Down Expand Up @@ -1040,8 +1035,7 @@ def squeeze(x, axis=None):
if isinstance(x, jax_sparse.BCOO):
if axis is None:
axis = tuple(i for i, d in enumerate(x.shape) if d == 1)
elif isinstance(axis, int):
axis = (axis,)
axis = to_tuple_or_list(axis)
return jax_sparse.bcoo_squeeze(x, dimensions=axis)
return jnp.squeeze(x, axis=axis)

Expand All @@ -1055,11 +1049,8 @@ def transpose(x, axes=None):
else:
permutation = []
for a in axes:
if not -num_dims <= a < num_dims:
raise ValueError(
f"axis {a} out of bounds for tensor of rank {num_dims}"
)
permutation.append(a if a >= 0 else a + num_dims)
a = canonicalize_axis(a, num_dims)
permutation.append(a)
return jax_sparse.bcoo_transpose(x, permutation=permutation)
return jnp.transpose(x, axes=axes)

Expand Down
78 changes: 37 additions & 41 deletions keras/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from keras.backend import config
from keras.backend import standardize_dtype
from keras.backend.common import dtypes
from keras.backend.common.backend_utils import standardize_axis_for_numpy
from keras.backend.numpy.core import convert_to_tensor


Expand Down Expand Up @@ -82,7 +83,7 @@ def multiply(x1, x2):


def mean(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
ori_dtype = standardize_dtype(x.dtype)
if "int" in ori_dtype or ori_dtype == "bool":
Expand All @@ -93,7 +94,7 @@ def mean(x, axis=None, keepdims=False):


def max(x, axis=None, keepdims=False, initial=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.max(x, axis=axis, keepdims=keepdims, initial=initial)


Expand All @@ -116,27 +117,27 @@ def abs(x):


def all(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.all(x, axis=axis, keepdims=keepdims)


def any(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.any(x, axis=axis, keepdims=keepdims)


def amax(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.amax(x, axis=axis, keepdims=keepdims)


def amin(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.amin(x, axis=axis, keepdims=keepdims)


def append(x1, x2, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype)
Expand Down Expand Up @@ -227,17 +228,17 @@ def arctanh(x):


def argmax(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.argmax(x, axis=axis).astype("int32")


def argmin(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.argmin(x, axis=axis).astype("int32")


def argsort(x, axis=-1):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.argsort(x, axis=axis).astype("int32")


Expand All @@ -246,7 +247,7 @@ def array(x, dtype=None):


def average(x, axis=None, weights=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
dtypes_to_resolve = [x.dtype, float]
if weights is not None:
Expand Down Expand Up @@ -311,7 +312,7 @@ def clip(x, x_min, x_max):


def concatenate(xs, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
Expand Down Expand Up @@ -354,14 +355,14 @@ def cosh(x):


def count_nonzero(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
# np.count_nonzero will return python int when axis=None, so we need
# to convert_to_tensor
return convert_to_tensor(np.count_nonzero(x, axis=axis)).astype("int32")


def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype)
Expand All @@ -378,15 +379,15 @@ def cross(x1, x2, axisa=-1, axisb=-1, axisc=-1, axis=None):


def cumprod(x, axis=None, dtype=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
dtype = dtypes.result_type(dtype or x.dtype)
if dtype == "bool":
dtype = "int32"
return np.cumprod(x, axis=axis, dtype=dtype)


def cumsum(x, axis=None, dtype=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
dtype = dtypes.result_type(dtype or x.dtype)
if dtype == "bool":
dtype = "int32"
Expand All @@ -398,14 +399,9 @@ def diag(x, k=0):


def diagonal(x, offset=0, axis1=0, axis2=1):
axis1 = tuple(axis1) if isinstance(axis1, list) else axis1
axis2 = tuple(axis2) if isinstance(axis2, list) else axis2
return np.diagonal(
x,
offset=offset,
axis1=axis1,
axis2=axis2,
)
axis1 = standardize_axis_for_numpy(axis1)
axis2 = standardize_axis_for_numpy(axis2)
return np.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)


def diff(a, n=1, axis=-1):
Expand Down Expand Up @@ -443,7 +439,7 @@ def exp(x):


def expand_dims(x, axis):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.expand_dims(x, axis)


Expand All @@ -456,7 +452,7 @@ def expm1(x):


def flip(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.flip(x, axis=axis)


Expand Down Expand Up @@ -534,7 +530,7 @@ def less_equal(x1, x2):
def linspace(
start, stop, num=50, endpoint=True, retstep=False, dtype=None, axis=0
):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
if dtype is None:
dtypes_to_resolve = [
getattr(start, "dtype", type(start)),
Expand Down Expand Up @@ -657,7 +653,7 @@ def meshgrid(*x, indexing="xy"):


def min(x, axis=None, keepdims=False, initial=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.min(x, axis=axis, keepdims=keepdims, initial=initial)


Expand Down Expand Up @@ -737,7 +733,7 @@ def pad(x, pad_width, mode="constant", constant_values=None):


def prod(x, axis=None, keepdims=False, dtype=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
if dtype is None:
dtype = dtypes.result_type(x.dtype)
Expand All @@ -749,7 +745,7 @@ def prod(x, axis=None, keepdims=False, dtype=None):


def quantile(x, q, axis=None, method="linear", keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)

ori_dtype = standardize_dtype(x.dtype)
Expand Down Expand Up @@ -818,17 +814,17 @@ def size(x):


def sort(x, axis=-1):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.sort(x, axis=axis)


def split(x, indices_or_sections, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.split(x, indices_or_sections, axis=axis)


def stack(x, axis=0):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
dtype_set = set([getattr(a, "dtype", type(a)) for a in x])
if len(dtype_set) > 1:
dtype = dtypes.result_type(*dtype_set)
Expand All @@ -837,7 +833,7 @@ def stack(x, axis=0):


def std(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
ori_dtype = standardize_dtype(x.dtype)
if "int" in ori_dtype or ori_dtype == "bool":
Expand All @@ -850,12 +846,12 @@ def swapaxes(x, axis1, axis2):


def take(x, indices, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.take(x, indices, axis=axis)


def take_along_axis(x, indices, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.take_along_axis(x, indices, axis=axis)


Expand Down Expand Up @@ -898,8 +894,8 @@ def tile(x, repeats):


def trace(x, offset=0, axis1=0, axis2=1):
axis1 = tuple(axis1) if isinstance(axis1, list) else axis1
axis2 = tuple(axis2) if isinstance(axis2, list) else axis2
axis1 = standardize_axis_for_numpy(axis1)
axis2 = standardize_axis_for_numpy(axis2)
x = convert_to_tensor(x)
dtype = standardize_dtype(x.dtype)
if dtype not in ("int64", "uint32", "uint64"):
Expand Down Expand Up @@ -1027,7 +1023,7 @@ def sqrt(x):


def squeeze(x, axis=None):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
return np.squeeze(x, axis=axis)


Expand All @@ -1037,7 +1033,7 @@ def transpose(x, axes=None):


def var(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
x = convert_to_tensor(x)
compute_dtype = dtypes.result_type(x.dtype, "float32")
result_dtype = dtypes.result_type(x.dtype, float)
Expand All @@ -1047,7 +1043,7 @@ def var(x, axis=None, keepdims=False):


def sum(x, axis=None, keepdims=False):
axis = tuple(axis) if isinstance(axis, list) else axis
axis = standardize_axis_for_numpy(axis)
dtype = standardize_dtype(x.dtype)
# follow jax's rule
if dtype in ("bool", "int8", "int16"):
Expand Down
Loading

0 comments on commit 1665377

Please sign in to comment.