Skip to content

Commit

Permalink
Fix for TF XLA compilation error for SpectralNormalization. (#19232)
Browse files Browse the repository at this point in the history
- On the TensorFlow backend, reimplemented all cases for `ops.linalg.norm` instead of sometimes relying on `tf.linalg.norm`. `tf.linalg.norm` sometimes fails to compile on XLA or returns tensors with no shape. Note that there is no special op used by `tf.linalg.norm`.
- Added more test cases for `ops.linalg.norm`:
  - vector norms are now also tested with 2-dimensional inputs
  - axis as a tuple of 2 values is tested
  - many more cases of invalid combinations of rank / ord / axis are verified
  • Loading branch information
hertschuh authored Feb 27, 2024
1 parent 0dcac0e commit ab1f404
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 43 deletions.
86 changes: 55 additions & 31 deletions keras/backend/tensorflow/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,30 +40,31 @@ def norm(x, ord=None, axis=None, keepdims=False):
axis = tuple(range(ndim))
elif isinstance(axis, int):
axis = (axis,)

if any(a < -ndim or a >= ndim for a in axis):
raise ValueError(
"All `axis` values must be in the range [-ndim, ndim). "
f"Received inputs with ndim={ndim}, while axis={axis}"
)
axis = axis[0] if len(axis) == 1 else axis
num_axes = 1 if isinstance(axis, int) else len(axis)

if num_axes == 1 and ord is None:
ord = "euclidean"
elif num_axes == 2 and ord is None:
ord = "fro"

if standardize_dtype(x.dtype) == "int64":
dtype = config.floatx()
else:
dtype = dtypes.result_type(x.dtype, float)
x = cast(x, dtype)

# Fast path to utilize `tf.linalg.norm`
if (num_axes == 1 and ord in ("euclidean", 1, 2, float("inf"))) or (
num_axes == 2 and ord in ("euclidean", "fro", 1, 2, float("inf"))
):
return tf.linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)

# Ref: jax.numpy.linalg.norm
if num_axes == 1 and ord not in ("fro", "nuc"):
if ord == float("-inf"):
if num_axes == 1:
if ord is None or ord == 2:
return tf.sqrt(
tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims)
)
elif ord == float("inf"):
return tf.math.reduce_max(
tf.math.abs(x), axis=axis, keepdims=keepdims
)
elif ord == float("-inf"):
return tf.math.reduce_min(
tf.math.abs(x), axis=axis, keepdims=keepdims
)
Expand All @@ -73,22 +74,30 @@ def norm(x, ord=None, axis=None, keepdims=False):
axis=axis,
keepdims=keepdims,
)
elif isinstance(ord, str):
raise ValueError(
f"Invalid `ord` argument for vector norm. Received: ord={ord}"
)
else:
ord = convert_to_tensor(ord, dtype=x.dtype)
out = tf.math.reduce_sum(
tf.pow(tf.math.abs(x), ord), axis=axis, keepdims=keepdims
)
return tf.pow(out, 1.0 / ord)
elif num_axes == 2 and ord in ("nuc", float("-inf"), -2, -1):
elif num_axes == 2:
row_axis, col_axis = axis[0], axis[1]
row_axis = row_axis + ndim if row_axis < 0 else row_axis
col_axis = col_axis + ndim if col_axis < 0 else col_axis
if ord == float("-inf"):
if not keepdims and row_axis > col_axis:
row_axis -= 1
x = tf.math.reduce_min(
tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
if ord is None or ord == "fro":
return tf.sqrt(
tf.reduce_sum(x * tf.math.conj(x), axis=axis, keepdims=keepdims)
)
elif ord == 1:
if not keepdims and col_axis > row_axis:
col_axis -= 1
x = tf.math.reduce_max(
tf.reduce_sum(tf.math.abs(x), axis=row_axis, keepdims=keepdims),
axis=col_axis,
keepdims=keepdims,
)
elif ord == -1:
Expand All @@ -99,29 +108,44 @@ def norm(x, ord=None, axis=None, keepdims=False):
axis=col_axis,
keepdims=keepdims,
)
else:
elif ord == float("inf"):
if not keepdims and row_axis > col_axis:
row_axis -= 1
x = tf.math.reduce_max(
tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord == float("-inf"):
if not keepdims and row_axis > col_axis:
row_axis -= 1
x = tf.math.reduce_min(
tf.reduce_sum(tf.math.abs(x), axis=col_axis, keepdims=keepdims),
axis=row_axis,
keepdims=keepdims,
)
elif ord in ("nuc", 2, -2):
x = tfnp.moveaxis(x, axis, (-2, -1))
if ord == -2:
x = tf.math.reduce_min(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
elif ord == 2:
x = tf.math.reduce_max(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
else:
x = tf.math.reduce_sum(
tf.linalg.svd(x, compute_uv=False), axis=-1
)
if keepdims:
x = tf.expand_dims(x, axis[0])
x = tf.expand_dims(x, axis[1])
else:
raise ValueError(
f"Invalid `ord` argument for matrix norm. Received: ord={ord}"
)
return x

if num_axes == 1:
raise ValueError(
f"Invalid `ord` argument for vector norm. Received: ord={ord}"
)
elif num_axes == 2:
raise ValueError(
f"Invalid `ord` argument for matrix norm. Received: ord={ord}"
)
else:
raise ValueError(f"Invalid axis values. Received: axis={axis}")

Expand Down
4 changes: 0 additions & 4 deletions keras/layers/normalization/spectral_normalization_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,6 @@

class SpectralNormalizationTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
@pytest.mark.skipif(
backend.backend() == "tensorflow",
reason="TODO: test fails on GPU. XLA related.",
)
def test_basic_spectralnorm(self):
self.run_layer_test(
layers.SpectralNormalization,
Expand Down
4 changes: 2 additions & 2 deletions keras/ops/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
- For matrices:
- `ord=None`: Frobenius norm
- `ord="fro"`: Frobenius norm
- `ord=nuc`: nuclear norm
- `ord="nuc"`: nuclear norm
- `ord=np.inf`: `max(sum(abs(x), axis=1))`
- `ord=-np.inf`: `min(sum(abs(x), axis=1))`
- `ord=0`: not supported
Expand All @@ -306,7 +306,7 @@ def norm(x, ord=None, axis=None, keepdims=False):
- For vectors:
- `ord=None`: 2-norm
- `ord="fro"`: not supported
- `ord=nuc`: not supported
- `ord="nuc"`: not supported
- `ord=np.inf`: `max(abs(x))`
- `ord=-np.inf`: `min(abs(x))`
- `ord=0`: `sum(x != 0)`
Expand Down
38 changes: 32 additions & 6 deletions keras/ops/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,19 +398,45 @@ def _reconstruct(lu, pivots, m, n):

@parameterized.named_parameters(
named_product(
ndim=[1, 2],
ord=[None, "fro", "nuc", -np.inf, -2, -1, 0, 1, 2, np.inf, 3],
axis=[None, 1, -1],
axis=[None, 1, -1, (0, 1)],
keepdims=[False, True],
)
)
def test_norm_vectors(self, ord, axis, keepdims):
if axis is None:
def test_norm(self, ndim, ord, axis, keepdims):
if ndim == 1:
x = np.random.random((5,))
else:
x = np.random.random((5, 6))
if ord in ("fro", "nuc"):
error = RuntimeError if backend.backend() == "torch" else ValueError
with self.assertRaises(error):

vector_norm = (ndim == 1) or isinstance(axis, int)

axis_out_of_bounds = ndim == 1 and (
axis == 1 or isinstance(axis, tuple)
)
expected_error = None
# The conditions under which an out-of-bounds axis triggers an IndexError on torch are complex
if (
axis_out_of_bounds
and (not isinstance(axis, tuple) or ord is None)
and ord not in ("fro", "nuc")
):
expected_error = IndexError
elif (
axis_out_of_bounds
or (vector_norm and isinstance(axis, tuple)) # inv. axis for vector
or (vector_norm and ord in ("fro", "nuc")) # invalid ord for vector
or (not vector_norm and ord in (0, 3)) # invalid ord for matrix
):
expected_error = RuntimeError

if expected_error is not None:
# Non-torch backends always throw a ValueError
expected_error = (
expected_error if backend.backend() == "torch" else ValueError
)
with self.assertRaises(expected_error):
linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
return
output = linalg.norm(x, ord=ord, axis=axis, keepdims=keepdims)
Expand Down

0 comments on commit ab1f404

Please sign in to comment.