-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Keras Ops] Add einops-style rearrange()
to keras.ops
#20733
Changes from 6 commits
a48de7e
46defa2
cd39981
35f520e
aa07f7d
a7be939
586e143
29980f3
7be977b
76911c7
90078cf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,157 @@ | ||
import re | ||
|
||
from keras.src.api_export import keras_export | ||
from keras.src.backend import KerasTensor | ||
from keras.src.backend import any_symbolic_tensors | ||
from keras.src.ops.core import shape | ||
from keras.src.ops.numpy import prod | ||
from keras.src.ops.numpy import reshape | ||
from keras.src.ops.numpy import transpose | ||
from keras.src.ops.operation import Operation | ||
|
||
|
||
def _create_axes_map(axes, input_shape, axes_lengths): | ||
axes_map = {} | ||
|
||
for axis, dim in zip(axes, input_shape): | ||
# Check for grouped axes pattern, e.g., "(h1 h)" | ||
grouped_axes = re.match(r"\(([\w\s]+)\)", axis) | ||
|
||
if grouped_axes: | ||
inner_axes = grouped_axes.group(1).split() | ||
known_axes = [a for a in inner_axes if a in axes_lengths] | ||
inferred_axes = [a for a in inner_axes if a not in axes_lengths] | ||
|
||
if inferred_axes: | ||
inferred_axis = inferred_axes[0] | ||
known_product = prod([axes_lengths[a] for a in known_axes]) | ||
axes_lengths[inferred_axis] = dim // known_product | ||
|
||
axes_map.update({a: axes_lengths[a] for a in inner_axes}) | ||
else: | ||
axes_map[axis] = dim | ||
|
||
return axes_map | ||
|
||
|
||
def _create_grouped_axes(axes): | ||
grouped_output_axes = [] | ||
for axis in axes: | ||
grouped_axes = re.match(r"\(([\w\s]+)\)", axis) | ||
|
||
if grouped_axes: | ||
inner_axes = grouped_axes.group(1).split() | ||
grouped_output_axes.append(inner_axes) | ||
else: | ||
grouped_output_axes.append([axis]) | ||
|
||
return grouped_output_axes | ||
|
||
|
||
def _flatten_group(axes): | ||
return [x for xs in axes for x in xs] | ||
|
||
|
||
def _get_transpose_order(from_shape, to_shape): | ||
flattened_from_shape = _flatten_group(_create_grouped_axes(from_shape)) | ||
|
||
return [flattened_from_shape.index(dim) for dim in to_shape] | ||
|
||
|
||
def _compute_output_shape(axes_map, grouped_axes): | ||
output_shape = [] | ||
for group in grouped_axes: | ||
size = 1 | ||
for axis in group: | ||
size *= axes_map[axis] | ||
output_shape.append(size) | ||
|
||
return tuple(output_shape) | ||
|
||
|
||
def _compute_decomposed_shape(input_axes, axes_lengths, axes_map): | ||
reshaped_input_axes = [] | ||
reshaped_sizes = [] | ||
|
||
for axis in input_axes: | ||
if "(" in axis: # Decomposed axis | ||
inner_axes = re.findall(r"\w+", axis) | ||
sizes = [axes_lengths[a] for a in inner_axes] | ||
reshaped_input_axes.extend(inner_axes) | ||
reshaped_sizes.extend(sizes) | ||
else: | ||
reshaped_input_axes.append(axis) | ||
reshaped_sizes.append(axes_map[axis]) | ||
|
||
return reshaped_sizes | ||
|
||
|
||
class Rearrange(Operation): | ||
def call(self, tensor, pattern, **axes_lengths): | ||
return rearrange(tensor, pattern, **axes_lengths) | ||
|
||
def compute_output_spec(self, tensor, pattern, **axes_lengths): | ||
input_pattern, output_pattern = re.split(r"\s*->\s*", pattern) | ||
input_axes = re.findall(r"\w+|\(.*?\)", input_pattern) | ||
output_axes = re.findall(r"\w+|\(.*?\)", output_pattern) | ||
input_shape = shape(tensor) | ||
|
||
axes_map = _create_axes_map(input_axes, input_shape, axes_lengths) | ||
grouped_output_axes = _create_grouped_axes(output_axes) | ||
output_shape = _compute_output_shape(axes_map, grouped_output_axes) | ||
|
||
return KerasTensor(shape=output_shape, dtype=tensor.dtype) | ||
|
||
|
||
@keras_export("keras.ops.rearrange") | ||
def rearrange(tensor, pattern, **axes_lengths): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An op should be able to run on either symbolic Keras tensors or backend native eager tensors. And they should render as a single node in the op graph. This would require creating a class for the op, with a compute_output_spec method (see how other ops are implemented) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, sorry, forgot to add the class. |
||
""" | ||
Rearranges the axes of a Keras tensor according to a specified pattern. | ||
DavidLandup0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
Args: | ||
tensor: Input Keras tensor. | ||
pattern: String describing the rearrangement in einops notation. | ||
**axes_lengths: Keyword arguments specifying lengths of axes | ||
when axes decomposition is used. | ||
|
||
Returns: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please also add a code example. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added examples to mirror: https://einops.rocks/api/rearrange/ |
||
Tensor: A Keras tensor with rearranged axes. | ||
|
||
Follows the logic of: | ||
|
||
1. If decomposition is needed: | ||
- Reshape to match decomposed dimensions. | ||
DavidLandup0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
2. Permute known and inferred axes to match the form of the output. | ||
3. Reshape to match the desired output shape. | ||
""" | ||
|
||
if any_symbolic_tensors((tensor,)): | ||
return Rearrange().symbolic_call(tensor, pattern, **axes_lengths) | ||
|
||
# Split the input and output patterns | ||
input_pattern, output_pattern = re.split(r"\s*->\s*", pattern) | ||
input_axes = re.findall(r"\w+|\(.*?\)", input_pattern) | ||
output_axes = re.findall(r"\w+|\(.*?\)", output_pattern) | ||
input_shape = shape(tensor) | ||
|
||
# Create axes map, and flattened output group | ||
axes_map = _create_axes_map(input_axes, input_shape, axes_lengths) | ||
grouped_output_axes = _create_grouped_axes(output_axes) | ||
flattened_output_axes = _flatten_group(grouped_output_axes) | ||
|
||
# 1. Axes decomposition | ||
decomposed_shapes = _compute_decomposed_shape( | ||
input_axes, axes_lengths, axes_map | ||
) | ||
if decomposed_shapes != tensor.shape: | ||
tensor = reshape(tensor, decomposed_shapes) | ||
|
||
# 2. Transpose to match target shape | ||
permute_order = _get_transpose_order(input_axes, flattened_output_axes) | ||
tensor = transpose(tensor, permute_order) | ||
|
||
# 3. Reshape to final target shape | ||
output_shape = _compute_output_shape(axes_map, grouped_output_axes) | ||
tensor = reshape(tensor, output_shape) | ||
|
||
return tensor | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's unusual to inline logic in a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I was debating opening N backend operations instead of one here. Though, since it just uses |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import numpy as np | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please add this file to the list of excluded test files for openVINO, or otherwise fix the test. OpenVINO tests are failing. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Removed numpy in lieu of |
||
|
||
from keras.src import testing | ||
from keras.src.backend.common import keras_tensor | ||
from keras.src.ops.einops import rearrange | ||
|
||
|
||
class RearrangeTest(testing.TestCase): | ||
def test_basic_rearrangement(self): | ||
x = keras_tensor.KerasTensor((2, 3, 4)) | ||
y = rearrange(x, "b c h -> b h c") | ||
self.assertIsInstance(y, keras_tensor.KerasTensor) | ||
self.assertEqual(y.shape, (2, 4, 3)) | ||
|
||
def test_basic_decomposition_and_rearrangement(self): | ||
x = keras_tensor.KerasTensor((6, 8)) | ||
y = rearrange(x, "(h w) c -> h w c", h=2, w=3) | ||
self.assertIsInstance(y, keras_tensor.KerasTensor) | ||
self.assertEqual(y.shape, (2, 3, 8)) | ||
|
||
def test_static_shape(self): | ||
x = np.ones([2, 3, 4]) | ||
y = rearrange(x, "b c h -> b h c") | ||
np.testing.assert_array_equal(y, np.transpose(x, (0, 2, 1))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want any documentation or code comments on these?