-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
scan added #18515
scan added #18515
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for the PR! 👍
keras/backend/tensorflow/core.py
Outdated
length=None, | ||
reverse=False, | ||
unroll=1): | ||
return tf.scan( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please format the code via sh shell/format.sh
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
@@ -204,6 +204,16 @@ def slice_update(inputs, start_indices, updates): | |||
return backend.core.slice_update(inputs, start_indices, updates) | |||
|
|||
|
|||
@keras_export("keras.ops.scan") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There should be a correspond symbolic op class.
Also, please add a docstring for this function.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done.I am not sure about the corresponding symbolic class
@@ -215,6 +215,17 @@ def test_slice_update(self): | |||
outputs = core.slice_update(inputs, start_indices, updates) | |||
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) | |||
|
|||
|
|||
def test_scan(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's use jax.scan
as the reference and let's test consistency with the reference across multiple inputs and dtypes. Let's also test static shape inference.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I have added two tests one with int and one with float. The float is not passing due to rounding errors. Should I round it to remove such errors . Also I am not sure about what do you mean by static shape inference sorry.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I am not sure about what do you mean by static shape inference sorry.
This means testing the output shape and dtypes obtained when calling the op on a KerasTensor. There is a test case for this in this file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test class is called CoreOpsStaticShapeTest
and is right above this one. You can take a look at some of the test cases for the class to get an idea of how static shape inference testing works.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
keras/backend/tensorflow/core.py
Outdated
if reverse: | ||
np.flip(xs) | ||
|
||
init = (init, np.array(0, dtype=init.dtype)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can use tf.zeros_like
instead of np.array
keras/backend/tensorflow/core.py
Outdated
|
||
carry, ys = tf.scan(f, xs, initializer=init) | ||
|
||
ys = ys.numpy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Don't convert to numpy, keep TF tensors throughout.
keras/backend/tensorflow/core.py
Outdated
if xs is None: | ||
xs = [None] * length | ||
if reverse: | ||
np.flip(xs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use a TF op for this
keras/backend/torch/core.py
Outdated
ys.append(y) | ||
if reverse: | ||
ys.reverse() | ||
return carry, np.stack(ys) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use torch ops, not np
@@ -215,6 +215,17 @@ def test_slice_update(self): | |||
outputs = core.slice_update(inputs, start_indices, updates) | |||
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) | |||
|
|||
|
|||
def test_scan(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also I am not sure about what do you mean by static shape inference sorry.
This means testing the output shape and dtypes obtained when calling the op on a KerasTensor. There is a test case for this in this file.
Codecov ReportAttention:
Additional details and impacted files@@ Coverage Diff @@
## master #18515 +/- ##
==========================================
- Coverage 63.86% 55.65% -8.21%
==========================================
Files 336 339 +3
Lines 34848 35689 +841
Branches 6855 7032 +177
==========================================
- Hits 22255 19864 -2391
- Misses 11121 14335 +3214
- Partials 1472 1490 +18
Flags with carried forward coverage won't be shown. Click here to find out more. ☔ View full report in Codecov by Sentry. |
@fchollet I have made some changes |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the updates!
keras/backend/numpy/core.py
Outdated
@@ -204,6 +204,22 @@ def while_loop( | |||
return loop_vars | |||
|
|||
|
|||
def scan(f, init, xs, length=None, reverse=False, unroll=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unroll arg is unused
keras/backend/tensorflow/core.py
Outdated
@@ -194,6 +194,22 @@ def slice_update(inputs, start_indices, updates): | |||
return dynamic_update_slice(inputs, updates, start_indices) | |||
|
|||
|
|||
def scan(f, init, xs, length=None, reverse=False, unroll=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
keras/backend/torch/core.py
Outdated
@@ -335,6 +335,22 @@ def slice_update(inputs, start_indices, updates): | |||
return outputs | |||
|
|||
|
|||
def scan(f, init, xs, length=None, reverse=False, unroll=1): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here
shell/format.sh
Outdated
@@ -1,5 +1,5 @@ | |||
#!/bin/bash | |||
set -Eeuo pipefail | |||
set -Eeu pipefail |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can revert this
keras/ops/core_test.py
Outdated
ys_op[i] = round(ys_op[0],2) | ||
ys_jax[i] = round(ys_jax[0], 2) | ||
|
||
self.assertListEqual(ys_op, ys_jax) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use assertAllClose
for numerical checks.
Please test more input argument configurations. You could use a parameterize test for that.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
New test added. I couldn't use parameterize because I can't pass arguments.I used instead for cycle.
New commit |
Thanks for the update. Can you get the tests to pass? |
I think that test are passing now @fchollet |
keras/ops/core.py
Outdated
return backend.core.scan(f, init, xs, length, reverse, unroll) | ||
|
||
def compute_output_spec(self, f, init, xs, length, reverse, unroll): | ||
return KerasTensor(xs.shape, dtype=xs.dtype) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't look correct -- it should match the shape of the output tensors. To start with, there are 2 tensors in this case.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
okay I have added 2 tensors to the return statement but I am not sure if they are correct
keras/ops/core.py
Outdated
@@ -204,6 +212,46 @@ def slice_update(inputs, start_indices, updates): | |||
return backend.core.slice_update(inputs, start_indices, updates) | |||
|
|||
|
|||
@keras_export("keras.ops.scan") | |||
def scan(f, init, xs, length=None, reverse=False): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You can just add unroll
everywhere for consistency, even if it's not used by some backends.
@@ -215,6 +215,17 @@ def test_slice_update(self): | |||
outputs = core.slice_update(inputs, start_indices, updates) | |||
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) | |||
|
|||
|
|||
def test_scan(self): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test class is called CoreOpsStaticShapeTest
and is right above this one. You can take a look at some of the test cases for the class to get an idea of how static shape inference testing works.
Hi @vulkomilev Can you please check @fchollet's comments and resolve conflicts?. Thank you! |
yes I will do it today |
@fchollet I think that I have scan test in CoreOpsStaticShapeTest |
Hi @vulkomilev Can you please resolve conflicts? Thank you! |
@gbaned done |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Congrats on getting the tests to pass! This would still need a Scan
class and tests for static shape inference.
keras/backend/tensorflow/core.py
Outdated
|
||
carry, ys = tf.scan(f, xs, initializer=init) | ||
|
||
# if carry[0].dtype is tf.float64: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If the code is no longer useful, please remove it
keras/ops/core.py
Outdated
A scanned array and a carry element. | ||
""" | ||
if any_symbolic_tensors((init, xs)): | ||
return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So this should be Scan
. Then there should be a class Scan
object definition, similar to the other classes in this file
Hi @vulkomilev Can you please check @fchollet's comments and keep us posted ? Thank you! |
@gbaned done |
scan functionality added