Skip to content

Commit

Permalink
[Keras Ops] Add einops-style rearrange() to keras.ops (#20733)
Browse files Browse the repository at this point in the history
* Add einops-style rearrange to keras.ops.einops

* Address PR comments

* Add any_symbolic_tensors() check on call

* Pass all arguments in symbolic_call

* Remove constructor and fix call

* Add basic couple of tests

* Add more tests

* Add examples to docstring

* Skip tests if backend is openvino

* Remove numpy from tests in lieu of keras.ops

* Skip tests for openvino when the testing operation isn't supported
DavidLandup0 authored Jan 15, 2025
1 parent e37ee79 commit 617b821
Showing 4 changed files with 242 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
from keras.src.ops.einops import rearrange
from keras.src.ops.linalg import cholesky
from keras.src.ops.linalg import det
from keras.src.ops.linalg import eig
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -30,6 +30,7 @@
from keras.src.ops.core import unstack
from keras.src.ops.core import vectorized_map
from keras.src.ops.core import while_loop
from keras.src.ops.einops import rearrange
from keras.src.ops.linalg import cholesky
from keras.src.ops.linalg import det
from keras.src.ops.linalg import eig
189 changes: 189 additions & 0 deletions keras/src/ops/einops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
import re

from keras.src.api_export import keras_export
from keras.src.backend import KerasTensor
from keras.src.backend import any_symbolic_tensors
from keras.src.ops.core import shape
from keras.src.ops.numpy import prod
from keras.src.ops.numpy import reshape
from keras.src.ops.numpy import transpose
from keras.src.ops.operation import Operation


def _create_axes_map(axes, input_shape, axes_lengths):
axes_map = {}

for axis, dim in zip(axes, input_shape):
# Check for grouped axes pattern, e.g., "(h1 h)"
grouped_axes = re.match(r"\(([\w\s]+)\)", axis)

if grouped_axes:
inner_axes = grouped_axes.group(1).split()
known_axes = [a for a in inner_axes if a in axes_lengths]
inferred_axes = [a for a in inner_axes if a not in axes_lengths]

if inferred_axes:
inferred_axis = inferred_axes[0]
known_product = prod([axes_lengths[a] for a in known_axes])
axes_lengths[inferred_axis] = dim // known_product

axes_map.update({a: axes_lengths[a] for a in inner_axes})
else:
axes_map[axis] = dim

return axes_map


def _create_grouped_axes(axes):
grouped_output_axes = []
for axis in axes:
grouped_axes = re.match(r"\(([\w\s]+)\)", axis)

if grouped_axes:
inner_axes = grouped_axes.group(1).split()
grouped_output_axes.append(inner_axes)
else:
grouped_output_axes.append([axis])

return grouped_output_axes


def _flatten_group(axes):
return [x for xs in axes for x in xs]


def _get_transpose_order(from_shape, to_shape):
flattened_from_shape = _flatten_group(_create_grouped_axes(from_shape))

return [flattened_from_shape.index(dim) for dim in to_shape]


def _compute_output_shape(axes_map, grouped_axes):
output_shape = []
for group in grouped_axes:
size = 1
for axis in group:
size *= axes_map[axis]
output_shape.append(size)

return tuple(output_shape)


def _compute_decomposed_shape(input_axes, axes_lengths, axes_map):
reshaped_input_axes = []
reshaped_sizes = []

for axis in input_axes:
if "(" in axis: # Decomposed axis
inner_axes = re.findall(r"\w+", axis)
sizes = [axes_lengths[a] for a in inner_axes]
reshaped_input_axes.extend(inner_axes)
reshaped_sizes.extend(sizes)
else:
reshaped_input_axes.append(axis)
reshaped_sizes.append(axes_map[axis])

return reshaped_sizes


class Rearrange(Operation):
def call(self, tensor, pattern, **axes_lengths):
return rearrange(tensor, pattern, **axes_lengths)

def compute_output_spec(self, tensor, pattern, **axes_lengths):
input_pattern, output_pattern = re.split(r"\s*->\s*", pattern)
input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
input_shape = shape(tensor)

axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)
grouped_output_axes = _create_grouped_axes(output_axes)
output_shape = _compute_output_shape(axes_map, grouped_output_axes)

return KerasTensor(shape=output_shape, dtype=tensor.dtype)


@keras_export("keras.ops.rearrange")
def rearrange(tensor, pattern, **axes_lengths):
"""Rearranges the axes of a Keras tensor according to a specified pattern,
einops-style.
Args:
tensor: Input Keras tensor.
pattern: String describing the rearrangement in einops notation.
**axes_lengths: Keyword arguments specifying lengths of axes
when axes decomposition is used.
Returns:
Tensor: A Keras tensor with rearranged axes.
Follows the logic of:
1. If decomposition is needed, reshape to match decomposed dimensions.
2. Permute known and inferred axes to match the form of the output.
3. Reshape to match the desired output shape.
Example Usage:
```
>>> import numpy as np
>>> from keras.ops import rearrange
>>> images = np.random.rand(32, 30, 40, 3) # BHWC format
# Reordering to BCHW
>>> rearrange(images, 'b h w c -> b c h w').shape
TensorShape([32, 3, 30, 40])
# "Merge" along first axis - concat images from a batch
>>> rearrange(images, 'b h w c -> (b h) w c').shape
TensorShape([960, 40, 3])
# "Merge" along second axis - concat images horizontally
>>> rearrange(images, 'b h w c -> h (b w) c').shape
TensorShape([30, 1280, 3])
# Flatten images into a CHW vector
>>> rearrange(images, 'b h w c -> b (c h w)').shape
TensorShape([32, 3600])
# Decompose H and W axes into 4 smaller patches
>>> rearrange(images, 'b (h1 h) (w1 w) c -> (b h1 w1) h w c', h1=2, w1=2).shape
TensorShape([128, 15, 20, 3])
# Space-to-depth decomposition of input axes
>>> rearrange(images, 'b (h h1) (w w1) c -> b h w (c h1 w1)', h1=2, w1=2).shape
TensorShape([32, 15, 20, 12])
```
""" # noqa: E501

if any_symbolic_tensors((tensor,)):
return Rearrange().symbolic_call(tensor, pattern, **axes_lengths)

# Split the input and output patterns
input_pattern, output_pattern = re.split(r"\s*->\s*", pattern)
input_axes = re.findall(r"\w+|\(.*?\)", input_pattern)
output_axes = re.findall(r"\w+|\(.*?\)", output_pattern)
input_shape = shape(tensor)

# Create axes map, and flattened output group
axes_map = _create_axes_map(input_axes, input_shape, axes_lengths)
grouped_output_axes = _create_grouped_axes(output_axes)
flattened_output_axes = _flatten_group(grouped_output_axes)

# 1. Axes decomposition
decomposed_shapes = _compute_decomposed_shape(
input_axes, axes_lengths, axes_map
)
if decomposed_shapes != tensor.shape:
tensor = reshape(tensor, decomposed_shapes)

# 2. Transpose to match target shape
permute_order = _get_transpose_order(input_axes, flattened_output_axes)
tensor = transpose(tensor, permute_order)

# 3. Reshape to final target shape
output_shape = _compute_output_shape(axes_map, grouped_output_axes)
tensor = reshape(tensor, output_shape)

return tensor
51 changes: 51 additions & 0 deletions keras/src/ops/einops_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from conftest import skip_if_backend
from keras.src import ops
from keras.src import testing
from keras.src.backend.common import keras_tensor
from keras.src.ops.einops import rearrange


class RearrangeTest(testing.TestCase):
def test_basic_rearrangement_symbolic(self):
x = keras_tensor.KerasTensor((2, 3, 4))
y = rearrange(x, "b c h -> b h c")
self.assertIsInstance(y, keras_tensor.KerasTensor)
self.assertEqual(y.shape, (2, 4, 3))

@skip_if_backend("openvino", "Test operation not supported by openvino")
def test_basic_rearrangement(self):
x = ops.random.uniform((2, 3, 4))
y = rearrange(x, "b c h -> b h c")
self.assertEqual(y.shape, (2, 4, 3))
self.assertTrue(ops.all(ops.equal(y, ops.transpose(x, (0, 2, 1)))))

@skip_if_backend("openvino", "Test operation not supported by openvino")
def test_output_composition(self):
x = ops.random.uniform((2, 4, 4, 3))
y = rearrange(x, "b h w c -> (b h) w c")
target_shape = (8, 4, 3)
self.assertEqual(y.shape, target_shape)
self.assertTrue(ops.all(ops.equal(y, ops.reshape(x, (8, 4, 3)))))

def test_basic_decomposition_and_rearrangement_symbolic(self):
x = keras_tensor.KerasTensor((6, 8))
y = rearrange(x, "(h w) c -> h w c", h=2, w=3)
self.assertIsInstance(y, keras_tensor.KerasTensor)
self.assertEqual(y.shape, (2, 3, 8))

def test_basic_decomposition_and_rearrangement(self):
x = ops.random.uniform((6, 8))
y = rearrange(x, "(h w) c -> h w c", h=2, w=3)
self.assertEqual(y.shape, (2, 3, 8))

@skip_if_backend("openvino", "Test operation not supported by openvino")
def test_unchanged_shape(self):
x = ops.ones([2, 3, 4])
y = rearrange(x, "b h c -> b h c")
self.assertTrue(ops.all(ops.equal(y, x)))
self.assertTrue(x.shape, y.shape)

def test_unchanged_shape_symbolic(self):
x = keras_tensor.KerasTensor((2, 3, 4))
y = rearrange(x, "b h c -> b h c")
self.assertTrue(x.shape, y.shape)

0 comments on commit 617b821

Please sign in to comment.