scan added #18515

Closed
wants to merge 12 commits into from
5 changes: 5 additions & 0 deletions keras/backend/jax/core.py
@@ -263,6 +263,11 @@ def slice_update(inputs, start_indices, updates):
return jax.lax.dynamic_update_slice(inputs, updates, start_indices)


def scan(f, init, xs, length=None, reverse=False, unroll=1):

return jax.lax.scan(f, init, xs, length, reverse, unroll)


def while_loop(
cond,
body,
16 changes: 16 additions & 0 deletions keras/backend/numpy/core.py
@@ -204,6 +204,22 @@ def while_loop(
return loop_vars


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

The `unroll` argument is unused.


if xs is None:
xs = [None] * length
carry = init
ys = []
if reverse:
xs.reverse()
for x in xs:
carry, y = f(carry, x)
ys.append(y)
if reverse:
ys.reverse()
return carry, np.stack(ys)
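
For reference, the loop semantics this backend function is aiming for can be sketched without any backend. This is only a minimal pure-Python sketch, not the PR's implementation: `reference_scan` and `cumulative_sum` are hypothetical names, `f` is assumed to return a `(carry, y)` pair, the input is copied instead of being reversed in place, and `unroll` is accepted but ignored because it is treated here as a performance hint that does not change results.

```python
def reference_scan(f, init, xs=None, length=None, reverse=False, unroll=1):
    """Hypothetical pure-Python reference for the scan semantics."""
    if xs is None:
        if length is None:
            raise ValueError("Either xs or length must be provided.")
        xs = [None] * length
    # Copy into a list so the caller's sequence is never mutated.
    items = list(xs)
    if reverse:
        items = items[::-1]
    carry = init
    ys = []
    for x in items:
        carry, y = f(carry, x)
        ys.append(y)
    if reverse:
        # Put each output back at the position of the input that produced it.
        ys = ys[::-1]
    # unroll is intentionally unused in this sketch.
    return carry, ys


def cumulative_sum(carry, x):
    total = carry + x
    return total, total


carry, ys = reference_scan(cumulative_sum, 0, [1, 2, 3, 4])
print(carry, ys)
```

With `reverse=True` the loop starts from the last element, but each `y` is still returned at the index of the `x` it came from.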


def fori_loop(lower, upper, body_fun, init_val):
val = init_val
for i in range(lower, upper):
19 changes: 19 additions & 0 deletions keras/backend/tensorflow/core.py
@@ -194,6 +194,25 @@ def slice_update(inputs, start_indices, updates):
return dynamic_update_slice(inputs, updates, start_indices)


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

Same here: the `unroll` argument is unused.


if xs is None:
xs = [None] * length
if reverse:
np.flip(xs)
Collaborator

Use a TF op for this


init = (init, np.array(0, dtype=init.dtype))
Collaborator

You can use tf.zeros_like instead of np.array


carry, ys = tf.scan(f, xs, initializer=init)

ys = ys.numpy()
Collaborator

Don't convert to numpy, keep TF tensors throughout.

carry = carry.numpy()
if reverse:
np.flip(ys)

return carry[0], ys


def while_loop(
cond,
body,
15 changes: 15 additions & 0 deletions keras/backend/torch/core.py
@@ -335,6 +335,21 @@ def slice_update(inputs, start_indices, updates):
return outputs


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

Same here: the `unroll` argument is unused.

if xs is None:
xs = [None] * length
carry = init
ys = []
if reverse:
xs.reverse()
for x in xs:
carry, y = f(carry, x)
ys.append(y)
if reverse:
ys.reverse()
return carry, np.stack(ys)
Collaborator

Use torch ops, not np



def while_loop(
cond,
body,
48 changes: 48 additions & 0 deletions keras/ops/core.py
@@ -168,6 +168,14 @@ def compute_output_spec(self, inputs, start_indices, updates):
return KerasTensor(inputs.shape, dtype=inputs.dtype)


class Scan(Operation):
def call(f, init, xs, length, reverse, unroll):
return backend.core.scan(f, init, xs, length, reverse, unroll)
fchollet marked this conversation as resolved.

def compute_output_spec(self, f, init, xs, length, reverse, unroll):
return KerasTensor(xs.shape, dtype=xs.dtype)
Collaborator

This doesn't look correct -- it should match the shape of the output tensors. To start with, there are 2 tensors in this case.

Author

Okay, I have added two tensors to the return statement, but I am not sure they are correct.
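
To illustrate the point about the output spec having two entries, here is a minimal backend-free sketch. It is not the Keras API: `Spec` is a hypothetical stand-in for a symbolic tensor, `scan_output_specs` is a hypothetical helper, and the per-step shape of `y` is assumed to be known and passed in as `y_step` rather than inferred from `f`.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass(frozen=True)
class Spec:
    """Hypothetical stand-in for a symbolic tensor: shape and dtype only."""

    shape: Tuple[Optional[int], ...]
    dtype: str


def scan_output_specs(init, y_step, xs=None, length=None):
    """Return (carry_spec, ys_spec) for a scan, under the stated assumptions."""
    # The number of steps comes from the leading axis of xs, else from length.
    steps = xs.shape[0] if xs is not None else length
    # The final carry has the same shape and dtype as the initial carry.
    carry = Spec(tuple(init.shape), init.dtype)
    # The stacked outputs gain a leading axis of size `steps`.
    ys = Spec((steps,) + tuple(y_step.shape), y_step.dtype)
    return carry, ys


carry_spec, ys_spec = scan_output_specs(
    init=Spec((), "float32"),
    y_step=Spec((3,), "float32"),
    xs=Spec((7, 3), "float32"),
)
print(carry_spec, ys_spec)
```

The carry spec follows `init`, while the `ys` spec depends on the number of steps, so returning a single tensor shaped like `xs` cannot describe both outputs.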



@keras_export("keras.ops.slice_update")
def slice_update(inputs, start_indices, updates):
"""Update an input by slicing in a tensor of updated values.
@@ -204,6 +212,46 @@ def slice_update(inputs, start_indices, updates):
return backend.core.slice_update(inputs, start_indices, updates)


@keras_export("keras.ops.scan")
Collaborator

There should be a corresponding symbolic op class.

Also, please add a docstring for this function.

Author

Done. I am not sure about the corresponding symbolic class.

def scan(f, init, xs, length=None, reverse=False, unroll=1):
"""Scan a function over leading array axes while carrying along state.

At a high level, this operation computes
`carry, y = f(carry, x)` for each `x` in `xs`, appends each `y` to an
array, and returns that array together with the final carry. The
initial carry is set by the `init` argument. If `xs` is `None`, a
placeholder sequence is created from the `length` argument as
`xs = [None] * length`.

Example:

```python
def f(carry, x):
x += 1
carry = x
return carry, x

carry, ys = keras.ops.scan(f, 1, [0, 1, 2, 3], length=None, reverse=False, unroll=1)
```

Args:
f: The function that will be used to scan over the array xs.
init: The initial state of the carry argument for the scan function.
xs: The array that will be scanned over.
length: The number of scan steps to run when `xs` is `None`.
reverse: If `True`, `xs` is reversed before scanning and
the `ys` array is reversed before it is returned.
unroll: Optional positive int specifying how many scan iterations
to unroll within a single iteration of the underlying loop
(supported only on the JAX backend).

Returns:
A tuple `(carry, ys)`: the final carry and the array of stacked outputs.
"""
if any_symbolic_tensors((init, xs)):
return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll)
Collaborator

This should be `Scan`, and there should be a `Scan` class definition, similar to the other operation classes in this file.

return backend.core.scan(f, init, xs, length, reverse, unroll)
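
The dispatch pattern requested above can be sketched in isolation. This is a sketch under stated assumptions, not the Keras implementation: `SymbolicTensor`, `eager_scan`, and this simplified `Scan` are hypothetical stand-ins, the callable is stored on the instance in the same way `WhileLoop` takes `cond` and `body` in `__init__`, and each step's `y` is assumed to have the shape of one slice of `xs`.

```python
class SymbolicTensor:
    """Hypothetical placeholder carrying only a shape and a dtype."""

    def __init__(self, shape, dtype="float32"):
        self.shape = tuple(shape)
        self.dtype = dtype


def eager_scan(f, init, xs, reverse=False):
    """Hypothetical eager backend: a plain Python loop."""
    items = list(xs)[::-1] if reverse else list(xs)
    carry, ys = init, []
    for x in items:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, (ys[::-1] if reverse else ys)


class Scan:
    """Op object: non-tensor configuration lives on the instance."""

    def __init__(self, f, reverse=False):
        self.f = f
        self.reverse = reverse

    def call(self, init, xs):  # instance method, so `self` comes first
        return eager_scan(self.f, init, xs, self.reverse)

    def compute_output_spec(self, init, xs):
        # Two outputs: the carry mirrors init, ys mirrors xs (assumption).
        carry = SymbolicTensor(init.shape, init.dtype)
        ys = SymbolicTensor(xs.shape, xs.dtype)
        return carry, ys


def scan(f, init, xs, reverse=False):
    op = Scan(f, reverse)
    # Symbolic inputs go through shape inference; concrete ones run eagerly.
    if any(isinstance(t, SymbolicTensor) for t in (init, xs)):
        return op.compute_output_spec(init, xs)
    return op.call(init, xs)


def add_step(carry, x):
    total = carry + x
    return total, total


print(scan(add_step, 0, [1, 2, 3]))
```

Calling `scan` with `SymbolicTensor` inputs returns two shape-only placeholders instead of running the loop.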


class WhileLoop(Operation):
def __init__(self, cond, body, maximum_iterations):
super().__init__()
43 changes: 43 additions & 0 deletions keras/ops/core_test.py
@@ -1,3 +1,4 @@
import jax
import numpy as np
import pytest

@@ -215,6 +216,48 @@ def test_slice_update(self):
outputs = core.slice_update(inputs, start_indices, updates)
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))

def test_scan(self):
Collaborator

Let's use jax.scan as the reference and let's test consistency with the reference across multiple inputs and dtypes. Let's also test static shape inference.

Author

I have added two tests, one with ints and one with floats. The float test is not passing because of rounding errors. Should I round the values to remove such errors? Also, I am not sure what you mean by static shape inference, sorry.

Collaborator

> Also I am not sure about what do you mean by static shape inference sorry.

This means testing the output shape and dtypes obtained when calling the op on a KerasTensor. There is a test case for this in this file.

Collaborator

The test class is called CoreOpsStaticShapeTest and is right above this one. You can take a look at some of the test cases for the class to get an idea of how static shape inference testing works.

def f(carry, x):
if type(carry) is list or type(carry) is tuple:
x += carry[0]
carry = carry[0]
else:

x += carry
return carry, x

init_carr = np.array(1)
xs = np.array([0, 1, 2, 3, 4, 5, 6])

carry_op, ys_op = core.scan(
f, init_carr, xs, length=len(xs), reverse=False
)
carry_jax, ys_jax = jax.lax.scan(
f, init_carr, xs, length=len(xs), reverse=False
)

ys_op = ys_op.tolist()
ys_jax = ys_jax.tolist()

self.assertEqual(carry_op, carry_jax)
self.assertListEqual(ys_op, ys_jax)

init_carr = np.array(1.1)
xs = np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7])

carry_op, ys_op = core.scan(
f, init_carr, xs, length=len(xs), reverse=True
)
carry_jax, ys_jax = jax.lax.scan(
f, init_carr, xs, length=len(xs), reverse=True
)

ys_op = ys_op.tolist()
ys_jax = ys_jax.tolist()

self.assertEqual(carry_op, carry_jax)
self.assertListEqual(ys_op, ys_jax)
Collaborator

Use assertAllClose for numerical checks.

Please test more input argument configurations. You could use a parameterize test for that.

Author

New test added. I couldn't use a parameterized test because I couldn't pass the arguments, so I used a for loop instead.
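
The two requests in this thread, tolerance-based comparison and several argument configurations, can be sketched with the standard library alone. This is a hypothetical illustration rather than the Keras test suite: `reference_scan` stands in for the op under test, `assert_all_close` is a hand-written stand-in for an `assertAllClose`-style helper, and `unittest`'s `subTest` is used here in place of a parameterized-test decorator.

```python
import math
import unittest
from itertools import accumulate


def reference_scan(f, init, xs, reverse=False):
    """Hypothetical plain-Python scan used as the function under test."""
    items = list(xs)[::-1] if reverse else list(xs)
    carry, ys = init, []
    for x in items:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, (ys[::-1] if reverse else ys)


def running_sum(carry, x):
    total = carry + x
    return total, total


class ScanConsistencyTest(unittest.TestCase):
    # Each case is (init, xs, reverse): ints and floats, both directions.
    CASES = [
        (1, [0, 1, 2, 3, 4, 5, 6], False),
        (1, [0, 1, 2, 3, 4, 5, 6], True),
        (1.1, [0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7], False),
        (1.1, [0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7], True),
    ]

    def assert_all_close(self, actual, expected, rel_tol=1e-6):
        """Compare element-wise within a relative tolerance."""
        self.assertEqual(len(actual), len(expected))
        for a, e in zip(actual, expected):
            self.assertTrue(math.isclose(a, e, rel_tol=rel_tol), (a, e))

    def test_running_sum(self):
        for init, xs, reverse in self.CASES:
            # subTest reports each configuration separately on failure.
            with self.subTest(init=init, reverse=reverse):
                carry, ys = reference_scan(running_sum, init, xs, reverse)
                # Independent expectation built with itertools.accumulate.
                order = xs[::-1] if reverse else xs
                expected = list(accumulate(order, initial=init))[1:]
                if reverse:
                    expected.reverse()
                self.assert_all_close(ys, expected)
                # Summation order differs here, so exact equality is brittle.
                self.assert_all_close([carry], [init + sum(xs)])


if __name__ == "__main__":
    unittest.main()
```

Comparing within a tolerance avoids rounding the values by hand, and the case list makes it cheap to add further dtypes or argument combinations.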


def test_while_loop(self):
def cond(x, y):
return x[0, 0] < 10
2 changes: 1 addition & 1 deletion shell/format.sh
@@ -1,5 +1,5 @@
#!/bin/bash
-set -Eeuo pipefail
+set -Eeu pipefail
Collaborator

You can revert this


base_dir=$(dirname $(dirname $0))
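
As background on why dropping the `o` matters: without `-o`, bash treats `pipefail` as a positional parameter instead of an option name. The sketch below checks this from Python by shelling out to bash. `run_bash` is a hypothetical helper, and the snippet assumes a `bash` executable is on the `PATH` (it returns `None` otherwise).

```python
import shutil
import subprocess


def run_bash(script):
    """Run a bash snippet; return (exit code, stdout), or None without bash."""
    bash = shutil.which("bash")
    if bash is None:
        return None
    proc = subprocess.run([bash, "-c", script], capture_output=True, text=True)
    return proc.returncode, proc.stdout.strip()


# With -o, pipefail is enabled: the failing `false` fails the whole pipeline,
# and -e then aborts the script before `echo` runs.
with_o = run_bash("set -Eeuo pipefail; false | true; echo reached")

# Without -o, "pipefail" is only stored as the positional parameter $1,
# so the pipeline's status is that of `true` and the script continues.
without_o = run_bash('set -Eeu pipefail; false | true; echo "reached $1"')

print(with_o, without_o)
```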
