-
Notifications
You must be signed in to change notification settings - Fork 19.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
scan added #18515
scan added #18515
Changes from 3 commits
5594501
4051472
b83a8d8
aed1f55
bcfc456
e756750
f40911a
d5b1e8f
2fae37e
ef216c4
611e974
71d3173
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -194,6 +194,25 @@ def slice_update(inputs, start_indices, updates): | |
return dynamic_update_slice(inputs, updates, start_indices) | ||
|
||
|
||
def scan(f, init, xs, length=None, reverse=False, unroll=1): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
|
||
if xs is None: | ||
xs = [None] * length | ||
if reverse: | ||
np.flip(xs) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use a TF op for this |
||
|
||
init = (init, np.array(0, dtype=init.dtype)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can use |
||
|
||
carry, ys = tf.scan(f, xs, initializer=init) | ||
|
||
ys = ys.numpy() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Don't convert to numpy, keep TF tensors throughout. |
||
carry = carry.numpy() | ||
if reverse: | ||
np.flip(ys) | ||
|
||
return carry[0], ys | ||
|
||
|
||
def while_loop( | ||
cond, | ||
body, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -335,6 +335,21 @@ def slice_update(inputs, start_indices, updates): | |
return outputs | ||
|
||
|
||
def scan(f, init, xs, length=None, reverse=False, unroll=1): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same here |
||
if xs is None: | ||
xs = [None] * length | ||
carry = init | ||
ys = [] | ||
if reverse: | ||
xs.reverse() | ||
for x in xs: | ||
carry, y = f(carry, x) | ||
ys.append(y) | ||
if reverse: | ||
ys.reverse() | ||
return carry, np.stack(ys) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use torch ops, not np |
||
|
||
|
||
def while_loop( | ||
cond, | ||
body, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -168,6 +168,14 @@ def compute_output_spec(self, inputs, start_indices, updates): | |
return KerasTensor(inputs.shape, dtype=inputs.dtype) | ||
|
||
|
||
class Scan(Operation): | ||
def call(f, init, xs, length, reverse, unroll): | ||
return backend.core.scan(f, init, xs, length, reverse, unroll) | ||
fchollet marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def compute_output_spec(self, f, init, xs, length, reverse, unroll): | ||
return KerasTensor(xs.shape, dtype=xs.dtype) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't look correct -- it should match the shape of the output tensors. To start with, there are 2 tensors in this case. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay I have added 2 tensors to the return statement but I am not sure if they are correct |
||
|
||
|
||
@keras_export("keras.ops.slice_update") | ||
def slice_update(inputs, start_indices, updates): | ||
"""Update an input by slicing in a tensor of updated values. | ||
|
@@ -204,6 +212,46 @@ def slice_update(inputs, start_indices, updates): | |
return backend.core.slice_update(inputs, start_indices, updates) | ||
|
||
|
||
@keras_export("keras.ops.scan") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There should be a correspond symbolic op class. Also, please add a docstring for this function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done.I am not sure about the corresponding symbolic class |
||
def scan(f, init, xs, length=None, reverse=False, unroll=1): | ||
"""Scan a function over leading array axes while carrying along state. | ||
|
||
At a high level, this operation does | ||
carry, y = f(carry, x) and adds the y value to an array which is | ||
returned at the end along with the last carry.The x is taken from | ||
the array of xs. The initial state of the carry can be set by using | ||
the init argument.In the case of None argument for xs a None array | ||
will be initialized by using the length argument like this | ||
xs = [None]*length.Example: | ||
|
||
```python | ||
def f(carry, x): | ||
x += 1 | ||
carry = x | ||
return carry, x | ||
|
||
inputs = keras.ops.scan(f,1,[0,1,2,3],length=None,reverse=False,unroll=1) | ||
``` | ||
|
||
Args: | ||
f: The function that will be used to scan over the array xs. | ||
init: The initial state of the carry argument for the scan function. | ||
xs: The array that will be scanned over. | ||
length: The length of the None xs array in the case of None xs argument. | ||
reverse: If set to true the xs will be reversed at the start and | ||
the ys array will be reversed at the end. | ||
unroll: Optional positive int specifying, in the underlying operation of | ||
the scan primitive, how many scan iterations to unroll within a single | ||
iteration of a loop(Supported only on jax backend). | ||
|
||
Returns: | ||
A scanned array and a carry element. | ||
""" | ||
if any_symbolic_tensors((init, xs)): | ||
return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So this should be |
||
return backend.core.scan(f, init, xs, length, reverse, unroll) | ||
|
||
|
||
class WhileLoop(Operation): | ||
def __init__(self, cond, body, maximum_iterations): | ||
super().__init__() | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import jax | ||
import numpy as np | ||
import pytest | ||
|
||
|
@@ -215,6 +216,48 @@ def test_slice_update(self): | |
outputs = core.slice_update(inputs, start_indices, updates) | ||
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) | ||
|
||
def test_scan(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's use There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I have added two tests one with int and one with float. The float is not passing due to rounding errors. Should I round it to remove such errors . Also I am not sure about what do you mean by static shape inference sorry. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
This means testing the output shape and dtypes obtained when calling the op on a KerasTensor. There is a test case for this in this file. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The test class is called |
||
def f(carry, x): | ||
if type(carry) is list or type(carry) is tuple: | ||
x += carry[0] | ||
carry = carry[0] | ||
else: | ||
|
||
x += carry | ||
return carry, x | ||
|
||
init_carr = np.array(1) | ||
xs = np.array([0, 1, 2, 3, 4, 5, 6]) | ||
|
||
carry_op, ys_op = core.scan( | ||
f, init_carr, xs, length=len(xs), reverse=False | ||
) | ||
carry_jax, ys_jax = jax.lax.scan( | ||
f, init_carr, xs, length=len(xs), reverse=False | ||
) | ||
|
||
ys_op = ys_op.tolist() | ||
ys_jax = ys_jax.tolist() | ||
|
||
self.assertEqual(carry_op, carry_jax) | ||
self.assertListEqual(ys_op, ys_jax) | ||
|
||
init_carr = np.array(1.1) | ||
xs = np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]) | ||
|
||
carry_op, ys_op = core.scan( | ||
f, init_carr, xs, length=len(xs), reverse=True | ||
) | ||
carry_jax, ys_jax = jax.lax.scan( | ||
f, init_carr, xs, length=len(xs), reverse=True | ||
) | ||
|
||
ys_op = ys_op.tolist() | ||
ys_jax = ys_jax.tolist() | ||
|
||
self.assertEqual(carry_op, carry_jax) | ||
self.assertListEqual(ys_op, ys_jax) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Use Please test more input argument configurations. You could use a parameterize test for that. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. New test added. I couldn't use parameterize because I can't pass arguments.I used instead for cycle. |
||
|
||
def test_while_loop(self): | ||
def cond(x, y): | ||
return x[0, 0] < 10 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
#!/bin/bash | ||
set -Eeuo pipefail | ||
set -Eeu pipefail | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can revert this |
||
|
||
base_dir=$(dirname $(dirname $0)) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The unroll arg is unused