scan added #18515

Closed
wants to merge 12 commits into from

vulkomilev

scan functionality added

Collaborator

@fchollet fchollet left a comment

Thank you for the PR! 👍

length=None,
reverse=False,
unroll=1):
return tf.scan(
Collaborator

Please format the code via sh shell/format.sh

Author

done

@@ -204,6 +204,16 @@ def slice_update(inputs, start_indices, updates):
return backend.core.slice_update(inputs, start_indices, updates)


@keras_export("keras.ops.scan")
Collaborator

There should be a corresponding symbolic op class.

Also, please add a docstring for this function.

Author

Done. I am not sure about the corresponding symbolic class.

@@ -215,6 +215,17 @@ def test_slice_update(self):
outputs = core.slice_update(inputs, start_indices, updates)
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))


def test_scan(self):
Collaborator

Let's use jax.lax.scan as the reference, and let's test consistency with the reference across multiple inputs and dtypes. Let's also test static shape inference.
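
For readers following the thread, the reference semantics can be modeled without any backend. The sketch below is a pure-Python approximation of how a scan threads a carry through `f` and collects the per-step outputs. `reference_scan` and `cumulative_sum` are illustrative names, not part of this PR or of JAX, and the sketch deliberately ignores nested structures, dtypes, and `unroll`.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


def reference_scan(
    f: Callable[[Any, Any], Tuple[Any, Any]],
    init: Any,
    xs: Optional[Sequence[Any]] = None,
    length: Optional[int] = None,
    reverse: bool = False,
) -> Tuple[Any, List[Any]]:
    """Pure-Python model of scan (an assumption-laden sketch, not a backend)."""
    if xs is None:
        if length is None:
            raise ValueError("Either xs or length must be provided.")
        xs = [None] * length
    elif length is not None and length != len(xs):
        raise ValueError("length must match the number of elements in xs.")

    # With reverse=True the elements are visited back to front, but every
    # output is written at its original index so ys stays aligned with xs.
    indices = range(len(xs) - 1, -1, -1) if reverse else range(len(xs))
    ys: List[Any] = [None] * len(xs)
    carry = init
    for i in indices:
        carry, ys[i] = f(carry, xs[i])
    return carry, ys


def cumulative_sum(carry, x):
    """Step function: the new carry is also emitted as the step output."""
    new_carry = carry + x
    return new_carry, new_carry


# Exercise both an integer and a float configuration, forward and reversed.
for init, xs in [(0, [1, 2, 3, 4]), (0.0, [0.5, 1.5, 2.0])]:
    print(reference_scan(cumulative_sum, init, xs))
    print(reference_scan(cumulative_sum, init, xs, reverse=True))
```

A backend implementation could then be checked against a model like this for each combination of `xs`, `length`, and `reverse`.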

Author

I have added two tests, one with int and one with float. The float test is not passing due to rounding errors. Should I round the values to remove such errors? Also, I am not sure what you mean by static shape inference, sorry.

Collaborator

Also, I am not sure what you mean by static shape inference, sorry.

This means testing the output shape and dtypes obtained when calling the op on a KerasTensor. There is a test case for this in this file.

Collaborator

The test class is called CoreOpsStaticShapeTest and is right above this one. You can take a look at some of the test cases for the class to get an idea of how static shape inference testing works.

Collaborator

@fchollet fchollet left a comment

Thanks for the updates!

if reverse:
np.flip(xs)

init = (init, np.array(0, dtype=init.dtype))
Collaborator

You can use tf.zeros_like instead of np.array


carry, ys = tf.scan(f, xs, initializer=init)

ys = ys.numpy()
Collaborator

Don't convert to numpy, keep TF tensors throughout.

if xs is None:
xs = [None] * length
if reverse:
np.flip(xs)
Collaborator

Use a TF op for this

ys.append(y)
if reverse:
ys.reverse()
return carry, np.stack(ys)
Collaborator

Use torch ops, not np

@codecov-commenter
codecov-commenter commented Oct 10, 2023

Codecov Report

Attention: 51 lines in your changes are missing coverage. Please review.

Comparison is base (10252a9) 63.86% compared to head (71d3173) 55.65%.
Report is 84 commits behind head on master.

Files                              Patch %   Lines
keras/backend/tensorflow/core.py   5.26%     18 Missing ⚠️
keras/backend/torch/core.py        0.00%     17 Missing ⚠️
keras/backend/numpy/core.py        52.17%    5 Missing and 6 partials ⚠️
keras/ops/core.py                  60.00%    3 Missing and 1 partial ⚠️
keras/backend/jax/core.py          50.00%    1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18515      +/-   ##
==========================================
- Coverage   63.86%   55.65%   -8.21%     
==========================================
  Files         336      339       +3     
  Lines       34848    35689     +841     
  Branches     6855     7032     +177     
==========================================
- Hits        22255    19864    -2391     
- Misses      11121    14335    +3214     
- Partials     1472     1490      +18     
Flag          Coverage Δ
keras         55.65% <28.16%> (-8.21%) ⬇️
keras-numpy   55.65% <28.16%> (?)
keras-torch   ?


@vulkomilev
Author

@fchollet I have made some changes

Collaborator

@fchollet fchollet left a comment

Thanks for the updates!

@@ -204,6 +204,22 @@ def while_loop(
return loop_vars


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

The unroll arg is unused

@@ -194,6 +194,22 @@ def slice_update(inputs, start_indices, updates):
return dynamic_update_slice(inputs, updates, start_indices)


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

Same here

@@ -335,6 +335,22 @@ def slice_update(inputs, start_indices, updates):
return outputs


def scan(f, init, xs, length=None, reverse=False, unroll=1):
Collaborator

Same here

shell/format.sh Outdated
@@ -1,5 +1,5 @@
#!/bin/bash
-set -Eeuo pipefail
+set -Eeu pipefail
Collaborator

You can revert this

ys_op[i] = round(ys_op[0],2)
ys_jax[i] = round(ys_jax[0], 2)

self.assertListEqual(ys_op, ys_jax)
Collaborator

Use assertAllClose for numerical checks.

Please test more input argument configurations. You could use a parameterized test for that.
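
As an illustration of the two suggestions above, here is a minimal sketch that compares floats with a tolerance instead of rounding, and runs several configurations through one test method. It uses only the standard library so that it runs anywhere; in the Keras test suite the `assertAllClose` method mentioned above and a parameterized-test decorator would be the more natural tools. `running_sum`, `assert_all_close`, and `run_tests` are hypothetical stand-ins, not names from this PR.

```python
import math
import unittest


def running_sum(xs):
    """Toy stand-in for the op under test: cumulative sums of xs."""
    total, out = 0.0, []
    for x in xs:
        total += x
        out.append(total)
    return out


class RunningSumTest(unittest.TestCase):
    def assert_all_close(self, actual, expected, rel_tol=1e-6, abs_tol=1e-6):
        # Tolerance-based comparison, so no rounding of either side is needed.
        self.assertEqual(len(actual), len(expected))
        for a, e in zip(actual, expected):
            self.assertTrue(
                math.isclose(a, e, rel_tol=rel_tol, abs_tol=abs_tol),
                msg=f"{a} != {e} within tolerance",
            )

    def test_configurations(self):
        # Each tuple is one configuration: (name, inputs, expected outputs).
        cases = [
            ("ints", [1, 2, 3], [1, 3, 6]),
            ("floats", [0.1, 0.2, 0.3], [0.1, 0.3, 0.6]),
            ("empty", [], []),
        ]
        for name, xs, expected in cases:
            # subTest reports each failing configuration separately.
            with self.subTest(case=name):
                self.assert_all_close(running_sum(xs), expected)


def run_tests():
    """Run the test case programmatically and return the result object."""
    suite = unittest.defaultTestLoader.loadTestsFromTestCase(RunningSumTest)
    return unittest.TextTestRunner(verbosity=0).run(suite)
```

The float case would fail under exact list equality because of accumulated floating-point error, which is the situation described earlier in this thread.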

Author

New test added. I couldn't use a parameterized test because I couldn't pass arguments, so I used a for loop instead.

@vulkomilev
Author

New commit

@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Oct 25, 2023
@fchollet
Collaborator

fchollet commented Nov 2, 2023

Thanks for the update. Can you get the tests to pass?

@vulkomilev
Author

I think the tests are passing now, @fchollet.

return backend.core.scan(f, init, xs, length, reverse, unroll)

def compute_output_spec(self, f, init, xs, length, reverse, unroll):
return KerasTensor(xs.shape, dtype=xs.dtype)
Collaborator

This doesn't look correct -- it should match the shape of the output tensors. To start with, there are 2 tensors in this case.
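
To make the expected output structure concrete, here is a rough sketch of the shape rule for a scan: the carry keeps the shape and dtype of `init`, and the stacked outputs gain a leading dimension equal to the number of steps. `TensorSpec` and `scan_output_spec` are hypothetical names used only for illustration; this is not the Keras `KerasTensor` API, and the per-step output spec is assumed to be known in advance.

```python
from dataclasses import dataclass
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]


@dataclass(frozen=True)
class TensorSpec:
    """Hypothetical stand-in for a symbolic tensor: shape and dtype only."""

    shape: Shape
    dtype: str


def scan_output_spec(
    init: TensorSpec,
    step_output: TensorSpec,
    xs: Optional[TensorSpec] = None,
    length: Optional[int] = None,
) -> Tuple[TensorSpec, TensorSpec]:
    """Return (carry_spec, ys_spec) for a scan, without running anything."""
    if length is not None:
        n = length
    elif xs is not None and xs.shape:
        # The number of steps is the leading dimension of xs.
        n = xs.shape[0]
    else:
        raise ValueError("Either xs (with a leading axis) or length is required.")

    # The carry has the same spec as init; ys stacks n per-step outputs.
    carry_spec = TensorSpec(init.shape, init.dtype)
    ys_spec = TensorSpec((n,) + step_output.shape, step_output.dtype)
    return carry_spec, ys_spec


carry, ys = scan_output_spec(
    init=TensorSpec((3,), "float32"),
    step_output=TensorSpec((3,), "float32"),
    xs=TensorSpec((5, 3), "float32"),
)
print(carry.shape, ys.shape)
```

A static shape inference test would assert on exactly these two results when the op is called on symbolic inputs.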

Author

Okay, I have added two tensors to the return statement, but I am not sure whether they are correct.

keras/ops/core.py
@@ -204,6 +212,46 @@ def slice_update(inputs, start_indices, updates):
return backend.core.slice_update(inputs, start_indices, updates)


@keras_export("keras.ops.scan")
def scan(f, init, xs, length=None, reverse=False):
Collaborator

You can just add unroll everywhere for consistency, even if it's not used by some backends.

@gbaned
Collaborator

gbaned commented Nov 15, 2023

Hi @vulkomilev, can you please check @fchollet's comments and resolve the conflicts? Thank you!

@gbaned gbaned added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Nov 15, 2023
@vulkomilev
Author

Yes, I will do it today.

@vulkomilev
Author

@fchollet I think I now have a scan test in CoreOpsStaticShapeTest.

@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Nov 15, 2023
@gbaned
Collaborator

gbaned commented Dec 1, 2023

Hi @vulkomilev, can you please resolve the conflicts? Thank you!

@gbaned gbaned added stat:awaiting response from contributor and removed stat:awaiting keras-eng Awaiting response from Keras engineer labels Dec 1, 2023
@vulkomilev
Author

@gbaned done

Collaborator

@fchollet fchollet left a comment

Congrats on getting the tests to pass! This would still need a Scan class and tests for static shape inference.


carry, ys = tf.scan(f, xs, initializer=init)

# if carry[0].dtype is tf.float64:
Collaborator

If the code is no longer useful, please remove it

A scanned array and a carry element.
"""
if any_symbolic_tensors((init, xs)):
return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll)
Collaborator

So this should be Scan. Then there should be a class Scan object definition, similar to the other classes in this file
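
The pattern being asked for can be sketched roughly as follows: a dedicated op class with an eager `call`, a `compute_output_spec` for symbolic inputs, and a public function that dispatches between the two. Everything here (`SymbolicTensor`, `Operation`, `any_symbolic`, and this `Scan`) is a self-contained mock written for illustration, assuming `f` can be traced on shape-only placeholders; it is not the code in keras/ops/core.py.

```python
class SymbolicTensor:
    """Hypothetical placeholder that carries only a shape and a dtype."""

    def __init__(self, shape, dtype="float32"):
        self.shape = tuple(shape)
        self.dtype = dtype


def any_symbolic(args):
    """Return True if any argument is a symbolic placeholder."""
    return any(isinstance(a, SymbolicTensor) for a in args)


class Operation:
    """Hypothetical minimal base class for op objects."""

    def symbolic_call(self, *args, **kwargs):
        return self.compute_output_spec(*args, **kwargs)


class Scan(Operation):
    def call(self, f, init, xs, length=None, reverse=False, unroll=1):
        # Eager path: a plain loop. unroll is accepted for signature
        # consistency but has no effect in this mock.
        items = list(xs) if xs is not None else [None] * length
        n = len(items)
        order = range(n - 1, -1, -1) if reverse else range(n)
        ys = [None] * n
        carry = init
        for i in order:
            carry, ys[i] = f(carry, items[i])
        return carry, ys

    def compute_output_spec(
        self, f, init, xs, length=None, reverse=False, unroll=1
    ):
        # Symbolic path: trace f once on a single-step slice, then prepend
        # the number of steps to the per-step output shape.
        n = length if length is not None else xs.shape[0]
        x_slice = None if xs is None else SymbolicTensor(xs.shape[1:], xs.dtype)
        carry, y = f(init, x_slice)
        return carry, SymbolicTensor((n,) + y.shape, y.dtype)


def scan(f, init, xs, length=None, reverse=False, unroll=1):
    """Dispatch to the symbolic or eager path of the same Scan op."""
    if any_symbolic((init, xs)):
        return Scan().symbolic_call(f, init, xs, length, reverse, unroll)
    return Scan().call(f, init, xs, length, reverse, unroll)
```

Calling `scan` with `SymbolicTensor` arguments returns two specs (carry and stacked outputs), while calling it with concrete values runs the loop.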

@gbaned
Collaborator

gbaned commented Dec 29, 2023

Hi @vulkomilev, can you please check @fchollet's comments and keep us posted? Thank you!

@vulkomilev
Author

@gbaned done

@gbaned gbaned requested a review from fchollet January 9, 2024 00:42
@sachinprasadhs sachinprasadhs added the stat:awaiting keras-eng Awaiting response from Keras engineer label Jan 17, 2024
@vulkomilev vulkomilev closed this by deleting the head repository Apr 18, 2024
Labels
size:M stat:awaiting keras-eng Awaiting response from Keras engineer
Projects
Status: Closed/Rejected
5 participants