[Keras Ops] Add einops-style rearrange() to keras.ops #20733

Merged: 11 commits, merged on Jan 15, 2025
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -30,6 +30,7 @@
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
from keras.src.ops.einops import rearrange
from keras.src.ops.linalg import cholesky
from keras.src.ops.linalg import det
from keras.src.ops.linalg import eig
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -30,6 +30,7 @@
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
from keras.src.ops.einops import rearrange
from keras.src.ops.linalg import cholesky
from keras.src.ops.linalg import det
from keras.src.ops.linalg import eig
157 changes: 157 additions & 0 deletions keras/src/ops/einops.py
@@ -0,0 +1,157 @@
import re

from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend import any_symbolic_tensors
from keras.src.ops.core import shape
from keras.src.ops.numpy import prod
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import transpose
from keras.src.ops.operation import Operation


def _create_axes_map(axes, input_shape, axes_lengths):
Contributor Author: Do we want any documentation or code comments on these?

    axes_map = {}

    for axis, dim in zip(axes, input_shape):
        # Check for grouped axes pattern, e.g., "(h1 h)"
        grouped_axes = re.match(r"\(([\w\s]+)\)", axis)

        if grouped_axes:
            inner_axes = grouped_axes.group(1).split()
            known_axes = [a for a in inner_axes if a in axes_lengths]
            inferred_axes = [a for a in inner_axes if a not in axes_lengths]

            if inferred_axes:
                inferred_axis = inferred_axes[0]
                known_product = prod([axes_lengths[a] for a in known_axes])
                axes_lengths[inferred_axis] = dim // known_product

            axes_map.update({a: axes_lengths[a] for a in inner_axes})
        else:
            axes_map[axis] = dim

    return axes_map


def _create_grouped_axes(axes):
    grouped_output_axes = []
    for axis in axes:
        grouped_axes = re.match(r"\(([\w\s]+)\)", axis)

        if grouped_axes:
            inner_axes = grouped_axes.group(1).split()
            grouped_output_axes.append(inner_axes)
        else:
            grouped_output_axes.append([axis])

    return grouped_output_axes


def _flatten_group(axes):
    return [x for xs in axes for x in xs]


def _get_transpose_order(from_shape, to_shape):
    flattened_from_shape = _flatten_group(_create_grouped_axes(from_shape))

    return [flattened_from_shape.index(dim) for dim in to_shape]


def _compute_output_shape(axes_map, grouped_axes):
    output_shape = []
    for group in grouped_axes:
        size = 1
        for axis in group:
            size *= axes_map[axis]
        output_shape.append(size)

    return tuple(output_shape)


def _compute_decomposed_shape(input_axes, axes_lengths, axes_map):
    reshaped_input_axes = []
    reshaped_sizes = []

    for axis in input_axes:
        if "(" in axis:  # Decomposed axis
            inner_axes = re.findall(r"\w+", axis)
            sizes = [axes_lengths[a] for a in inner_axes]
            reshaped_input_axes.extend(inner_axes)
            reshaped_sizes.extend(sizes)
        else:
            reshaped_input_axes.append(axis)
            reshaped_sizes.append(axes_map[axis])

    return reshaped_sizes


class Rearrange(Operation):
    def call(self, tensor, pattern, **axes_lengths):
        return rearrange(tensor, pattern, **axes_lengths)

    def compute_output_spec(self, tensor, pattern, **axes_lengths):
        input_pattern, output_pattern = re.split(r"\s*->\s*", pattern)
        input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
        output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
        input_shape = shape(tensor)

        axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)
        grouped_output_axes = _create_grouped_axes(output_axes)
        output_shape = _compute_output_shape(axes_map, grouped_output_axes)

        return KerasTensor(shape=output_shape, dtype=tensor.dtype)


@keras_export("keras.ops.rearrange")
def rearrange(tensor, pattern, **axes_lengths):
Collaborator: An op should be able to run on either symbolic Keras tensors or backend-native eager tensors, and it should render as a single node in the op graph. This requires creating a class for the op, with a compute_output_spec method (see how other ops are implemented).

Contributor Author: Ah, sorry, forgot to add the class.
"""
Rearranges the axes of a Keras tensor according to a specified pattern.
DavidLandup0 marked this conversation as resolved.
Show resolved Hide resolved

Args:
tensor: Input Keras tensor.
pattern: String describing the rearrangement in einops notation.
**axes_lengths: Keyword arguments specifying lengths of axes
when axes decomposition is used.

Returns:
Collaborator: Please also add a code example.

Contributor Author: Added examples to mirror: https://einops.rocks/api/rearrange/
        Tensor: A Keras tensor with rearranged axes.

    Follows the logic of:

    1. If decomposition is needed:
        - Reshape to match decomposed dimensions.
    2. Permute known and inferred axes to match the form of the output.
    3. Reshape to match the desired output shape.
    """

    if any_symbolic_tensors((tensor,)):
        return Rearrange().symbolic_call(tensor, pattern, **axes_lengths)

    # Split the input and output patterns
    input_pattern, output_pattern = re.split(r"\s*->\s*", pattern)
    input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
    output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
    input_shape = shape(tensor)

    # Create axes map, and flattened output group
    axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)
    grouped_output_axes = _create_grouped_axes(output_axes)
    flattened_output_axes = _flatten_group(grouped_output_axes)

    # 1. Axes decomposition
    decomposed_shapes = _compute_decomposed_shape(
        input_axes, axes_lengths, axes_map
    )
    if decomposed_shapes != tensor.shape:
        tensor = reshape(tensor, decomposed_shapes)

    # 2. Transpose to match target shape
    permute_order = _get_transpose_order(input_axes, flattened_output_axes)
    tensor = transpose(tensor, permute_order)

    # 3. Reshape to final target shape
    output_shape = _compute_output_shape(axes_map, grouped_output_axes)
    tensor = reshape(tensor, output_shape)

    return tensor
Collaborator: It's unusual to inline logic in a src/ops/ op rather than defining it N times in the backends in a backend-specific fashion. But it's done for a couple of other ops (image ops in particular). It's fine.

Contributor Author: Yeah, I was debating opening N backend operations instead of a single one here. But since this op only uses reshape() and transpose(), it gets the backend-specific implementations for free through keras.ops. I figured lower redundancy/copying is preferable in this case, especially since we could look into adding more operations to keras.src.ops.einops in the future.
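As a minimal usage sketch of that relationship (an illustration added here, not part of the diff, and assuming a Keras version that ships keras.ops.rearrange): the op's result matches a hand-written composition of keras.ops.reshape and keras.ops.transpose.

import numpy as np

from keras import ops

x = np.ones((2, 6, 4))  # (batch, h*w, channels)

# Split the middle axis into (h, w) and move channels to the front.
y = ops.rearrange(x, "b (h w) c -> b c h w", h=2, w=3)

# Equivalent composition of the primitives the op is built from:
# decompose, then permute.
y_manual = ops.transpose(ops.reshape(x, (2, 2, 3, 4)), (0, 3, 1, 2))

assert tuple(y.shape) == tuple(y_manual.shape) == (2, 4, 2, 3)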

24 changes: 24 additions & 0 deletions keras/src/ops/einops_test.py
@@ -0,0 +1,24 @@
import numpy as np
Collaborator: Please add this file to the list of excluded test files for OpenVINO, or otherwise fix the test. The OpenVINO tests are failing.

Contributor Author: Removed numpy in lieu of keras.ops. Some of the ops used in the tests themselves (all(), etc.) don't seem to be supported by OpenVINO, so those tests are skipped.
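The exact skip mechanism isn't shown in this diff; as a rough sketch of the pattern described above (assumed test name, not the PR's code), a backend-dependent test in Keras can be guarded like this:

import pytest

from keras.src import backend


@pytest.mark.skipif(
    backend.backend() == "openvino",
    reason="Ops used for verification are not supported by OpenVINO.",
)
def test_value_check_not_run_on_openvino():
    ...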


from keras.src import testing
from keras.src.backend.common import keras_tensor
from keras.src.ops.einops import rearrange


class RearrangeTest(testing.TestCase):
    def test_basic_rearrangement(self):
        x = keras_tensor.KerasTensor((2, 3, 4))
        y = rearrange(x, "b c h -> b h c")
        self.assertIsInstance(y, keras_tensor.KerasTensor)
        self.assertEqual(y.shape, (2, 4, 3))

    def test_basic_decomposition_and_rearrangement(self):
        x = keras_tensor.KerasTensor((6, 8))
        y = rearrange(x, "(h w) c -> h w c", h=2, w=3)
        self.assertIsInstance(y, keras_tensor.KerasTensor)
        self.assertEqual(y.shape, (2, 3, 8))

    def test_static_shape(self):
        x = np.ones([2, 3, 4])
        y = rearrange(x, "b c h -> b h c")
        np.testing.assert_array_equal(y, np.transpose(x, (0, 2, 1)))
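For reference, a quick eager check of axis grouping against a plain NumPy reshape (an illustrative sketch, not part of the PR's test file):

import numpy as np

from keras import ops
from keras.src.ops.einops import rearrange

x = np.arange(24).reshape(2, 3, 4)    # (b, h, w)
y = rearrange(x, "b h w -> b (h w)")  # group h and w into a single axis
np.testing.assert_array_equal(ops.convert_to_numpy(y), x.reshape(2, 12))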