-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Keras Ops] Add einops-style rearrange()
to keras.ops
#20733
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #20733 +/- ##
==========================================
+ Coverage 81.95% 81.98% +0.02%
==========================================
Files 553 554 +1
Lines 51446 51549 +103
Branches 7957 7972 +15
==========================================
+ Hits 42164 42262 +98
- Misses 7346 7347 +1
- Partials 1936 1940 +4
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the PR!
|
||
|
||
@keras_export("keras.ops.rearrange") | ||
def rearrange(tensor, pattern, **axes_lengths): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
An op should be able to run on either symbolic Keras tensors or backend native eager tensors. And they should render as a single node in the op graph. This would require creating a class for the op, with a compute_output_spec method (see how other ops are implemented)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah, sorry, forgot to add the class.
@fchollet Thank you for the feedback! Changes should be reflected in the latest commit. |
from keras.src.ops.operation import Operation | ||
|
||
|
||
def _create_axes_map(axes, input_shape, axes_lengths): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we want any documentation or code comments on these?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the update!
**axes_lengths: Keyword arguments specifying lengths of axes | ||
when axes decomposition is used. | ||
|
||
Returns: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please also add a code example.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added examples to mirror: https://einops.rocks/api/rearrange/
output_shape = _compute_output_shape(axes_map, grouped_output_axes) | ||
tensor = reshape(tensor, output_shape) | ||
|
||
return tensor |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's unusual to inline logic in a src/ops/
op rather than defining it N times in the backends in a backend specific fashion. But it's done for a couple other ops (image ops in particular). It's fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, I was debating opening N backend operations instead of one here. Though, since it just uses reshape()
and transpose()
, it gets to use backend-equal implementations by virtue of keras.ops
by default. Figured that lower redundancy/copying is preferred in this case, especially since we could look into adding more operations in keras.src.ops.einops
in the future.
keras/src/ops/einops_test.py
Outdated
@@ -0,0 +1,24 @@ | |||
import numpy as np |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add this file to the list of excluded test files for openVINO, or otherwise fix the test. OpenVINO tests are failing.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Removed numpy in lieu of keras.ops
- some of the ops in the tests themselves (all()
, etc.) don't seem to be supported by openVINO. Skipped those.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thank you
Context
Closes #20332.
Add an
einops
stylerearrange()
operation to thekeras.src.ops.einops
.Making a new module for this makes sense if we plan to add other similar operations and want to re-use some of the private-marked utility methods here or generalize the logic further.
It makes use of solely
reshape()
andtranspose()
, so these natively run on all backends. Furthermore, only these two operations are applied so it should be fine with symbolic tensors (sometimes,einops
doesn't work correctly with symbolic tensors in a computation graph).Comparison with
einops
Direct comparison with
einops
, using the examples from the official documentation:Results in:
Other
TODO: