From 559450176a759c98efb67e7bad0dcedd9d3df14f Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Wed, 27 Sep 2023 20:25:57 +0300 Subject: [PATCH 01/11] scan added --- keras/backend/jax/core.py | 9 +++++++++ keras/backend/numpy/core.py | 20 ++++++++++++++++++++ keras/backend/tensorflow/core.py | 13 +++++++++++++ keras/backend/torch/core.py | 20 ++++++++++++++++++++ keras/ops/core.py | 10 ++++++++++ keras/ops/core_test.py | 11 +++++++++++ 6 files changed, 83 insertions(+) diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py index d493dae6ca3..e472ccfc1da 100644 --- a/keras/backend/jax/core.py +++ b/keras/backend/jax/core.py @@ -263,6 +263,15 @@ def slice_update(inputs, start_indices, updates): return jax.lax.dynamic_update_slice(inputs, updates, start_indices) +def scan(f, + init, + xs, + length=None, + reverse=False, + unroll=1): + return jax.lax.scan(f, init, xs, length, reverse, unroll) + + def while_loop( cond, body, diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 1b3e8e4086f..fe1b5c56b95 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -204,6 +204,26 @@ def while_loop( return loop_vars +def scan(f, + init, + xs, + length=None, + reverse=False, + unroll=1): + if xs is None: + xs = [None] * length + carry = init + ys = [] + if reverse: + xs.reverse() + for x in xs: + carry, y = f(carry, x) + ys.append(y) + if reverse: + ys.reverse() + return carry, np.stack(ys) + + def fori_loop(lower, upper, body_fun, init_val): val = init_val for i in range(lower, upper): diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index cbf2030b810..7c98ad17d3c 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -194,6 +194,19 @@ def slice_update(inputs, start_indices, updates): return dynamic_update_slice(inputs, updates, start_indices) +def scan(f, + init, + xs, + length=None, + reverse=False, + unroll=1): + return tf.scan( + f, + xs, + 
initializer=init, + reverse=reverse) + + def while_loop( cond, body, diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py index 47bf0674811..2c717771071 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -335,6 +335,26 @@ def slice_update(inputs, start_indices, updates): return outputs +def scan(f, + init, + xs, + length=None, + reverse=False, + unroll=1): + if xs is None: + xs = [None] * length + carry = init + ys = [] + if reverse: + xs.reverse() + for x in xs: + carry, y = f(carry, x) + ys.append(y) + if reverse: + ys.reverse() + return carry, np.stack(ys) + + def while_loop( cond, body, diff --git a/keras/ops/core.py b/keras/ops/core.py index 2ba549d4ae5..a7789958a8e 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -204,6 +204,16 @@ def slice_update(inputs, start_indices, updates): return backend.core.slice_update(inputs, start_indices, updates) +@keras_export("keras.ops.scan") +def scan(f, + init, + xs, + length=None, + reverse=False, + unroll=1): + return backend.core.scan(f,init,xs,length,reverse,unroll) + + class WhileLoop(Operation): def __init__(self, cond, body, maximum_iterations): super().__init__() diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 07c79aee6cb..af751efe9b0 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -215,6 +215,17 @@ def test_slice_update(self): outputs = core.slice_update(inputs, start_indices, updates) self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) + + def test_scan(self): + x = [0,1,2,3,4,5,6] + x_target = [1, 2, 3, 4, 5, 6, 7] + def f(carry,x): + x +=carry + return carry,x + carr , y = core.scan(f,1,x) + self.assertEqual(x_target, y) + + def test_while_loop(self): def cond(x, y): return x[0, 0] < 10 From 405147204ded4ca589fcc82e16f3722c9b26f74c Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Sat, 30 Sep 2023 10:18:41 +0300 Subject: [PATCH 02/11] work in progress --- keras/backend/jax/core.py | 8 +--- 
keras/backend/numpy/core.py | 8 +--- keras/backend/tensorflow/core.py | 27 +++++++------ keras/backend/torch/core.py | 7 +--- keras/ops/core.py | 51 ++++++++++++++++++++---- keras/ops/core_test.py | 67 ++++++++++++++++++++++++++++---- shell/format.sh | 2 +- 7 files changed, 125 insertions(+), 45 deletions(-) diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py index e472ccfc1da..007a5762165 100644 --- a/keras/backend/jax/core.py +++ b/keras/backend/jax/core.py @@ -263,12 +263,8 @@ def slice_update(inputs, start_indices, updates): return jax.lax.dynamic_update_slice(inputs, updates, start_indices) -def scan(f, - init, - xs, - length=None, - reverse=False, - unroll=1): +def scan(f, init, xs, length=None, reverse=False, unroll=1): + print('here jax') return jax.lax.scan(f, init, xs, length, reverse, unroll) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index fe1b5c56b95..9411ed2a9e5 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -204,12 +204,8 @@ def while_loop( return loop_vars -def scan(f, - init, - xs, - length=None, - reverse=False, - unroll=1): +def scan(f, init, xs, length=None, reverse=False, unroll=1): + print('here numpy') if xs is None: xs = [None] * length carry = init diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 7c98ad17d3c..cdb9f2bdc47 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -194,17 +194,22 @@ def slice_update(inputs, start_indices, updates): return dynamic_update_slice(inputs, updates, start_indices) -def scan(f, - init, - xs, - length=None, - reverse=False, - unroll=1): - return tf.scan( - f, - xs, - initializer=init, - reverse=reverse) +def scan(f, init, xs, length=None, reverse=False, unroll=1): + + if xs is None: + xs = [None] * length + if reverse: + np.flip(xs) + + init = (init, np.array(0,dtype=init.dtype)) + + carry,ys = tf.scan(f, xs, initializer=init) + + ys = ys.numpy() + if reverse: + 
np.flip(ys) + + return carry.numpy()[-1], ys def while_loop( diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py index 2c717771071..c77070fc73f 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -335,12 +335,7 @@ def slice_update(inputs, start_indices, updates): return outputs -def scan(f, - init, - xs, - length=None, - reverse=False, - unroll=1): +def scan(f, init, xs, length=None, reverse=False, unroll=1): if xs is None: xs = [None] * length carry = init diff --git a/keras/ops/core.py b/keras/ops/core.py index a7789958a8e..2b424bc3c4a 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -168,6 +168,13 @@ def compute_output_spec(self, inputs, start_indices, updates): return KerasTensor(inputs.shape, dtype=inputs.dtype) +class Scan(Operation): + def call(f, init, xs, length, reverse, unroll): + return backend.core.scan(f, init, xs, length, reverse, unroll) + + def compute_output_spec(self, f, init, xs, length, reverse, unroll): + return KerasTensor(xs.shape, dtype=xs.dtype) + @keras_export("keras.ops.slice_update") def slice_update(inputs, start_indices, updates): """Update an input by slicing in a tensor of updated values. @@ -205,13 +212,43 @@ def slice_update(inputs, start_indices, updates): @keras_export("keras.ops.scan") -def scan(f, - init, - xs, - length=None, - reverse=False, - unroll=1): - return backend.core.scan(f,init,xs,length,reverse,unroll) +def scan(f, init, xs, length=None, reverse=False, unroll=1): + """Scan a function over leading array axes while carrying along state. + + At a high level, this operation does + carry, y = f(carry, x) and adds the y value to an array which is + returned at the end along with the last carry.The x is taken from + the array of xs. 
The initial state of the carry can be set by using + the init argument. In the case of a None argument for xs, a None array + will be initialized by using the length argument like this: + xs = [None]*length. Example: + + ```python + def f(carry, x): + x += 1 + carry = x + return carry, x + + inputs = keras.ops.scan(f, 1, [0,1,2,3,4,5,6,7,8,9], length=None, reverse=False, unroll=1) + ``` + + Args: + f: The function that will be used to scan over the array xs. + init: The initial state of the carry argument for the scan function. + xs: The array that will be scanned over. + length: The length of the None-filled xs array used when xs is None. + reverse: If set to true, the xs will be reversed at the start and + the ys array will be reversed at the end. + unroll: Optional positive int specifying, in the underlying operation of + the scan primitive, how many scan iterations to unroll within a single + iteration of a loop (supported only on the JAX backend). + + Returns: + A scanned array and a carry element. 
+ """ + if any_symbolic_tensors((init, xs)): + return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) + return backend.core.scan(f, init, xs, length, reverse, unroll) class WhileLoop(Operation): diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index af751efe9b0..d3c127c9cc7 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -10,6 +10,7 @@ from keras import testing from keras.backend.common.keras_tensor import KerasTensor from keras.ops import core +import jax class CoreOpsStaticShapeTest(testing.TestCase): @@ -215,15 +216,65 @@ def test_slice_update(self): outputs = core.slice_update(inputs, start_indices, updates) self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2])) - def test_scan(self): - x = [0,1,2,3,4,5,6] - x_target = [1, 2, 3, 4, 5, 6, 7] - def f(carry,x): - x +=carry - return carry,x - carr , y = core.scan(f,1,x) - self.assertEqual(x_target, y) + def f(carry, x): + if type(carry) is list or type(carry) is tuple: + x += carry[0] + else: + print(x) + print('carry',carry) + x += carry + + + return carry, x + + + init_carr = np.array(1) + xs = np.array([0, 1, 2, 3, 4, 5, 6]) + + carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=False) + carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=len(xs), reverse=False) + + ys_op = ys_op.tolist() + ys_jax = ys_jax.tolist() + + self.assertEqual(carry_op, carry_jax) + self.assertListEqual(ys_op, ys_jax) + + init_carr = np.array(1.1) + xs = np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]) + + carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=True) + carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=len(xs), reverse=True) + + self.assertEqual(carry_op, carry_jax) + self.assertEqual(ys_op, ys_jax) + + init_carr = np.array('a') + xs = np.array(['q', 'w', 'e', 'r', 't', 'y', 'u']) + + carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=False) + carry_jax, ys_jax = jax.lax.scan(f, init_carr, 
xs, length=len(xs), reverse=False) + + self.assertEqual(carry_op, carry_jax) + self.assertEqual(ys_op, ys_jax) + + def f(carry, x): + if x is None: + x = 1 + x += carry + return carry, x + + + init_carr = np.array(1) + xs = None + + carry_op, ys_op = core.scan(f, init_carr, xs, length=25, reverse=False) + carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=25, reverse=False) + + self.assertEqual(carry_op, carry_jax) + self.assertEqual(ys_op, ys_jax) + def test_while_loop(self): diff --git a/shell/format.sh b/shell/format.sh index f2992e44f89..42900132f9b 100755 --- a/shell/format.sh +++ b/shell/format.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -Eeuo pipefail +set -Eeu pipefail base_dir=$(dirname $(dirname $0)) From b83a8d8abd7c1e1498a2feed0aac522d899139e9 Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Tue, 10 Oct 2023 20:10:36 +0300 Subject: [PATCH 03/11] updated documentation formated the code added a tests --- keras/backend/jax/core.py | 2 +- keras/backend/numpy/core.py | 2 +- keras/backend/tensorflow/core.py | 7 ++-- keras/ops/core.py | 3 +- keras/ops/core_test.py | 55 +++++++++++--------------------- 5 files changed, 26 insertions(+), 43 deletions(-) diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py index 007a5762165..fa615315915 100644 --- a/keras/backend/jax/core.py +++ b/keras/backend/jax/core.py @@ -264,7 +264,7 @@ def slice_update(inputs, start_indices, updates): def scan(f, init, xs, length=None, reverse=False, unroll=1): - print('here jax') + return jax.lax.scan(f, init, xs, length, reverse, unroll) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 9411ed2a9e5..0d622f5834a 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -205,7 +205,7 @@ def while_loop( def scan(f, init, xs, length=None, reverse=False, unroll=1): - print('here numpy') + if xs is None: xs = [None] * length carry = init diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 
cdb9f2bdc47..fcacb5bc800 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -201,15 +201,16 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): if reverse: np.flip(xs) - init = (init, np.array(0,dtype=init.dtype)) + init = (init, np.array(0, dtype=init.dtype)) - carry,ys = tf.scan(f, xs, initializer=init) + carry, ys = tf.scan(f, xs, initializer=init) ys = ys.numpy() + carry = carry.numpy() if reverse: np.flip(ys) - return carry.numpy()[-1], ys + return carry[0], ys def while_loop( diff --git a/keras/ops/core.py b/keras/ops/core.py index 2b424bc3c4a..936d8ad2376 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -175,6 +175,7 @@ def call(f, init, xs, length, reverse, unroll): def compute_output_spec(self, f, init, xs, length, reverse, unroll): return KerasTensor(xs.shape, dtype=xs.dtype) + @keras_export("keras.ops.slice_update") def slice_update(inputs, start_indices, updates): """Update an input by slicing in a tensor of updated values. 
@@ -229,7 +230,7 @@ def f(carry, x): carry = x return carry, x - inputs = keras.ops.scan(f, 1, [0,1,2,3,4,5,6,7,8,9], length=None, reverse=False, unroll=1) + inputs = keras.ops.scan(f,1,[0,1,2,3],length=None,reverse=False,unroll=1) ``` Args: diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index d3c127c9cc7..1cdeb97e81f 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -1,3 +1,4 @@ +import jax import numpy as np import pytest @@ -10,7 +11,6 @@ from keras import testing from keras.backend.common.keras_tensor import KerasTensor from keras.ops import core -import jax class CoreOpsStaticShapeTest(testing.TestCase): @@ -220,20 +220,21 @@ def test_scan(self): def f(carry, x): if type(carry) is list or type(carry) is tuple: x += carry[0] + carry = carry[0] else: - print(x) - print('carry',carry) - x += carry - + x += carry return carry, x - init_carr = np.array(1) xs = np.array([0, 1, 2, 3, 4, 5, 6]) - carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=False) - carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=len(xs), reverse=False) + carry_op, ys_op = core.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) + carry_jax, ys_jax = jax.lax.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) ys_op = ys_op.tolist() ys_jax = ys_jax.tolist() @@ -244,38 +245,18 @@ def f(carry, x): init_carr = np.array(1.1) xs = np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]) - carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=True) - carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=len(xs), reverse=True) - - self.assertEqual(carry_op, carry_jax) - self.assertEqual(ys_op, ys_jax) - - init_carr = np.array('a') - xs = np.array(['q', 'w', 'e', 'r', 't', 'y', 'u']) - - carry_op, ys_op = core.scan(f, init_carr, xs, length=len(xs), reverse=False) - carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=len(xs), reverse=False) - - self.assertEqual(carry_op, carry_jax) - self.assertEqual(ys_op, ys_jax) 
- - def f(carry, x): - if x is None: - x = 1 - x += carry - return carry, x - - - init_carr = np.array(1) - xs = None + carry_op, ys_op = core.scan( + f, init_carr, xs, length=len(xs), reverse=True + ) + carry_jax, ys_jax = jax.lax.scan( + f, init_carr, xs, length=len(xs), reverse=True + ) - carry_op, ys_op = core.scan(f, init_carr, xs, length=25, reverse=False) - carry_jax, ys_jax = jax.lax.scan(f, init_carr, xs, length=25, reverse=False) + ys_op = ys_op.tolist() + ys_jax = ys_jax.tolist() self.assertEqual(carry_op, carry_jax) - self.assertEqual(ys_op, ys_jax) - - + self.assertListEqual(ys_op, ys_jax) def test_while_loop(self): def cond(x, y): From aed1f55e52a0dd6fc672e596f03b23cbc51ffc2a Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Tue, 17 Oct 2023 20:51:20 +0300 Subject: [PATCH 04/11] some fixes --- keras/backend/numpy/core.py | 4 ++-- keras/backend/tensorflow/core.py | 13 +++++-------- keras/backend/torch/core.py | 5 +++-- keras/ops/core_test.py | 30 ++++++++++++++++++++++++++++++ 4 files changed, 40 insertions(+), 12 deletions(-) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 0d622f5834a..730b6420e7c 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -211,12 +211,12 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): carry = init ys = [] if reverse: - xs.reverse() + xs = np.flip(xs) for x in xs: carry, y = f(carry, x) ys.append(y) if reverse: - ys.reverse() + ys = np.flip(ys) return carry, np.stack(ys) diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index fcacb5bc800..de855334f5d 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -199,18 +199,15 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): if xs is None: xs = [None] * length if reverse: - np.flip(xs) + tf.reverse(xs,[0]) - init = (init, np.array(0, dtype=init.dtype)) + init = (init, tf.zeros_like(0,dtype=init.dtype)) carry, ys = tf.scan(f, xs, 
initializer=init) - ys = ys.numpy() - carry = carry.numpy() - if reverse: - np.flip(ys) - - return carry[0], ys + if carry[0].dtype is tf.float64: + return tf.cast(carry[0],dtype=tf.float32), ys.numpy() + return carry[0], ys.numpy() def while_loop( diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py index c77070fc73f..afcff54f21b 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -340,14 +340,15 @@ def scan(f, init, xs, length=None, reverse=False, unroll=1): xs = [None] * length carry = init ys = [] + xs = torch.tensor(xs) if reverse: - xs.reverse() + xs = torch.flip(xs,[-1]) for x in xs: carry, y = f(carry, x) ys.append(y) if reverse: ys.reverse() - return carry, np.stack(ys) + return carry, np.array(ys) def while_loop( diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 1cdeb97e81f..342ffadc2df 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -78,6 +78,31 @@ def test_unstack(self): ): core.unstack(x, axis=axis) + def test_scan(self): + def f(carry, x): + if type(carry) is list or type(carry) is tuple: + x += carry[0] + carry = carry[0] + else: + + x += carry + return carry, x + + init_carr = np.array(1) + xs = np.array([0, 1, 2, 3, 4, 5, 6]) + + carry_op, ys_op = core.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) + carry_jax, ys_jax = jax.lax.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) + + ys_op = ys_op.tolist() + ys_jax = ys_jax.tolist() + + self.assertEqual(ys_jax.shape, ys_op.shape) + self.assertEqual(ys_jax.dtype, ys_op.dtype) class CoreOpsCorrectnessTest(testing.TestCase): def test_scatter(self): @@ -256,6 +281,11 @@ def f(carry, x): ys_jax = ys_jax.tolist() self.assertEqual(carry_op, carry_jax) + + for i in range(len(ys_op)): + ys_op[i] = round(ys_op[0],2) + ys_jax[i] = round(ys_jax[0], 2) + self.assertListEqual(ys_op, ys_jax) def test_while_loop(self): From bcfc4565c89bb6498a633e3b0679d8cba0264f34 Mon Sep 17 00:00:00 2001 From: Valko Milev Date: 
Mon, 23 Oct 2023 21:05:27 +0300 Subject: [PATCH 05/11] some fixes --- keras/backend/numpy/core.py | 2 +- keras/backend/tensorflow/core.py | 8 ++-- keras/backend/torch/core.py | 4 +- keras/ops/core.py | 6 +-- keras/ops/core_test.py | 65 +++++++++++++++----------------- shell/format.sh | 2 +- 6 files changed, 42 insertions(+), 45 deletions(-) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 730b6420e7c..8b5cd654b2b 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -204,7 +204,7 @@ def while_loop( return loop_vars -def scan(f, init, xs, length=None, reverse=False, unroll=1): +def scan(f, init, xs, length=None, reverse=False): if xs is None: xs = [None] * length diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index de855334f5d..b5d88616f93 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -194,19 +194,19 @@ def slice_update(inputs, start_indices, updates): return dynamic_update_slice(inputs, updates, start_indices) -def scan(f, init, xs, length=None, reverse=False, unroll=1): +def scan(f, init, xs, length=None, reverse=False): if xs is None: xs = [None] * length if reverse: - tf.reverse(xs,[0]) + tf.reverse(xs, [0]) - init = (init, tf.zeros_like(0,dtype=init.dtype)) + init = (init, tf.zeros_like(0, dtype=init.dtype)) carry, ys = tf.scan(f, xs, initializer=init) if carry[0].dtype is tf.float64: - return tf.cast(carry[0],dtype=tf.float32), ys.numpy() + return tf.cast(carry[0], dtype=tf.float32), ys.numpy() return carry[0], ys.numpy() diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py index afcff54f21b..7e7cdfedc0f 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -335,14 +335,14 @@ def slice_update(inputs, start_indices, updates): return outputs -def scan(f, init, xs, length=None, reverse=False, unroll=1): +def scan(f, init, xs, length=None, reverse=False): if xs is None: xs = [None] * length carry 
= init ys = [] xs = torch.tensor(xs) if reverse: - xs = torch.flip(xs,[-1]) + xs = torch.flip(xs, [-1]) for x in xs: carry, y = f(carry, x) ys.append(y) diff --git a/keras/ops/core.py b/keras/ops/core.py index 936d8ad2376..3a5de5991d0 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -213,7 +213,7 @@ def slice_update(inputs, start_indices, updates): @keras_export("keras.ops.scan") -def scan(f, init, xs, length=None, reverse=False, unroll=1): +def scan(f, init, xs, length=None, reverse=False): """Scan a function over leading array axes while carrying along state. At a high level, this operation does @@ -248,8 +248,8 @@ def f(carry, x): A scanned array and a carry element. """ if any_symbolic_tensors((init, xs)): - return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) - return backend.core.scan(f, init, xs, length, reverse, unroll) + return SliceUpdate().symbolic_call(f, init, xs, length, reverse) + return backend.core.scan(f, init, xs, length, reverse) class WhileLoop(Operation): diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 342ffadc2df..82e702cb794 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -104,6 +104,7 @@ def f(carry, x): self.assertEqual(ys_jax.shape, ys_op.shape) self.assertEqual(ys_jax.dtype, ys_op.dtype) + class CoreOpsCorrectnessTest(testing.TestCase): def test_scatter(self): # Test 1D @@ -251,42 +252,38 @@ def f(carry, x): x += carry return carry, x - init_carr = np.array(1) - xs = np.array([0, 1, 2, 3, 4, 5, 6]) - - carry_op, ys_op = core.scan( - f, init_carr, xs, length=len(xs), reverse=False - ) - carry_jax, ys_jax = jax.lax.scan( - f, init_carr, xs, length=len(xs), reverse=False - ) - - ys_op = ys_op.tolist() - ys_jax = ys_jax.tolist() - - self.assertEqual(carry_op, carry_jax) - self.assertListEqual(ys_op, ys_jax) - - init_carr = np.array(1.1) - xs = np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]) - - carry_op, ys_op = core.scan( - f, init_carr, xs, length=len(xs), reverse=True - ) - 
carry_jax, ys_jax = jax.lax.scan( - f, init_carr, xs, length=len(xs), reverse=True - ) - - ys_op = ys_op.tolist() - ys_jax = ys_jax.tolist() - - self.assertEqual(carry_op, carry_jax) + test_cases = [ + (np.array([0, 1, 2, 3, 4, 5, 6]), 1), + (np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]), 1.1), + (np.array([123, 423, 3, 78, 43, 13]), 1.1), + (np.array([-1, -2, -3, -4, -5, -6]), -2), + (np.array([0, 0, 0, 0, 0, 0, 0]), 0), + (np.array([1.1, 2, 3, 4, 5, 6, 7]), 1), + ] + for test_case in test_cases: + test_input_arr = test_case[0] + test_input_carry = test_case[1] + + carry_op, ys_op = core.scan( + f, + test_input_carry, + test_input_arr, + length=len(test_input_arr), + reverse=False, + ) + carry_jax, ys_jax = jax.lax.scan( + f, + test_input_carry, + test_input_arr, + length=len(test_input_arr), + reverse=False, + ) - for i in range(len(ys_op)): - ys_op[i] = round(ys_op[0],2) - ys_jax[i] = round(ys_jax[0], 2) + ys_op = ys_op.tolist() + ys_jax = ys_jax.tolist() - self.assertListEqual(ys_op, ys_jax) + self.assertEqual(carry_op, carry_jax) + self.assertAllClose(ys_op, ys_jax) def test_while_loop(self): def cond(x, y): diff --git a/shell/format.sh b/shell/format.sh index 42900132f9b..f2992e44f89 100755 --- a/shell/format.sh +++ b/shell/format.sh @@ -1,5 +1,5 @@ #!/bin/bash -set -Eeu pipefail +set -Eeuo pipefail base_dir=$(dirname $(dirname $0)) From e756750e1f1089c6595c8ca7b8f604b1dc5e5992 Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Wed, 8 Nov 2023 21:38:33 +0200 Subject: [PATCH 06/11] cores fixed --- keras/backend/jax/core.py | 1 - keras/backend/numpy/core.py | 10 ++++++++-- keras/backend/tensorflow/core.py | 18 ++++++++++++++---- keras/backend/torch/core.py | 11 ++++++++--- keras/ops/core_test.py | 18 ++++++++++++------ 5 files changed, 42 insertions(+), 16 deletions(-) diff --git a/keras/backend/jax/core.py b/keras/backend/jax/core.py index fa615315915..359a8f81a91 100644 --- a/keras/backend/jax/core.py +++ b/keras/backend/jax/core.py @@ -264,7 +264,6 @@ def 
slice_update(inputs, start_indices, updates): def scan(f, init, xs, length=None, reverse=False, unroll=1): - return jax.lax.scan(f, init, xs, length, reverse, unroll) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 8b5cd654b2b..78577563d92 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -205,7 +205,6 @@ def while_loop( def scan(f, init, xs, length=None, reverse=False): - if xs is None: xs = [None] * length carry = init @@ -215,9 +214,16 @@ def scan(f, init, xs, length=None, reverse=False): for x in xs: carry, y = f(carry, x) ys.append(y) + ys = np.array(ys) if reverse: ys = np.flip(ys) - return carry, np.stack(ys) + + if isinstance(ys, np.integer): + ys = ys.astype(np.int32) + + if isinstance(ys, np.floating): + ys = ys.astype(np.float32) + return carry, ys def fori_loop(lower, upper, body_fun, init_val): diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index b5d88616f93..b4a2df75a58 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -195,14 +195,24 @@ def slice_update(inputs, start_indices, updates): def scan(f, init, xs, length=None, reverse=False): - if xs is None: xs = [None] * length if reverse: tf.reverse(xs, [0]) - - init = (init, tf.zeros_like(0, dtype=init.dtype)) - + if ( + isinstance(init, float) + or isinstance(init, np.floating) + or any(isinstance(x, float) for x in xs) + or isinstance(xs, np.floating) + ): + init = ( + tf.cast(init, dtype=tf.float32), + tf.zeros_like(0, dtype=tf.float32), + ) + xs = tf.cast(xs, tf.float32) + else: + init = (init, tf.zeros_like(0, dtype=tf.int32)) + xs = tf.cast(xs, tf.int32) carry, ys = tf.scan(f, xs, initializer=init) if carry[0].dtype is tf.float64: diff --git a/keras/backend/torch/core.py b/keras/backend/torch/core.py index 7e7cdfedc0f..c865695b4bf 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -346,9 +346,14 @@ def scan(f, init, xs, length=None, 
reverse=False): for x in xs: carry, y = f(carry, x) ys.append(y) - if reverse: - ys.reverse() - return carry, np.array(ys) + + ys = np.array(ys) + if ys.dtype == np.int64: + ys = ys.astype(np.int32) + + if ys.dtype == np.float64: + ys = ys.astype(np.float32) + return carry, ys def while_loop( diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 82e702cb794..66f6e73616b 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -84,7 +84,6 @@ def f(carry, x): x += carry[0] carry = carry[0] else: - x += carry return carry, x @@ -98,8 +97,8 @@ def f(carry, x): f, init_carr, xs, length=len(xs), reverse=False ) - ys_op = ys_op.tolist() - ys_jax = ys_jax.tolist() + ys_op = ys_op + ys_jax = ys_jax self.assertEqual(ys_jax.shape, ys_op.shape) self.assertEqual(ys_jax.dtype, ys_op.dtype) @@ -248,8 +247,16 @@ def f(carry, x): x += carry[0] carry = carry[0] else: - + if ( + isinstance(x, np.floating) + or isinstance(carry, np.floating) + or isinstance(x, float) + or isinstance(carry, float) + ): + x = float(x) # x.astype(np.float32) + carry = float(carry) # carry.astype(np.float32) x += carry + return carry, x test_cases = [ @@ -258,7 +265,7 @@ def f(carry, x): (np.array([123, 423, 3, 78, 43, 13]), 1.1), (np.array([-1, -2, -3, -4, -5, -6]), -2), (np.array([0, 0, 0, 0, 0, 0, 0]), 0), - (np.array([1.1, 2, 3, 4, 5, 6, 7]), 1), + (np.array([1.1, 2, 3, 4, 5, 6, 7]), 9), ] for test_case in test_cases: test_input_arr = test_case[0] @@ -281,7 +288,6 @@ def f(carry, x): ys_op = ys_op.tolist() ys_jax = ys_jax.tolist() - self.assertEqual(carry_op, carry_jax) self.assertAllClose(ys_op, ys_jax) From f40911ac499fe8299a0e89d51d9f7d5532c9f9de Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Thu, 9 Nov 2023 18:56:52 +0200 Subject: [PATCH 07/11] some fixes with regards to CoreOpsStaticShapeTest --- keras/backend/numpy/core.py | 12 +++++++++--- keras/backend/tensorflow/core.py | 2 +- 2 files changed, 10 insertions(+), 4 deletions(-) diff --git a/keras/backend/numpy/core.py 
b/keras/backend/numpy/core.py index 78577563d92..697fda4e8b5 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -218,11 +218,17 @@ def scan(f, init, xs, length=None, reverse=False): if reverse: ys = np.flip(ys) - if isinstance(ys, np.integer): - ys = ys.astype(np.int32) + if len(ys) > 0: + if isinstance(ys[0], np.integer): + ys = ys.astype(np.int32) + if isinstance(ys[0], np.floating): + ys = ys.astype(np.float32) + if isinstance(ys, np.integer): + carry = carry.astype(np.int32) if isinstance(ys, np.floating): - ys = ys.astype(np.float32) + carry = carry.astype(np.float32) + return carry, ys diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index b4a2df75a58..1050884f571 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -211,7 +211,7 @@ def scan(f, init, xs, length=None, reverse=False): ) xs = tf.cast(xs, tf.float32) else: - init = (init, tf.zeros_like(0, dtype=tf.int32)) + init = (tf.cast(init, dtype=tf.int32), tf.zeros_like(0, dtype=tf.int32)) xs = tf.cast(xs, tf.int32) carry, ys = tf.scan(f, xs, initializer=init) From d5b1e8f5ee98c71f801b78d20a3ef665625ef8d3 Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Wed, 15 Nov 2023 21:23:46 +0200 Subject: [PATCH 08/11] fixes according to the comments --- keras/backend/numpy/core.py | 3 +- keras/backend/tensorflow/core.py | 32 ++++++++++------- keras/backend/torch/core.py | 8 +++-- keras/ops/core.py | 11 +++--- keras/ops/core_test.py | 61 ++++++++++++++++++++------------ 5 files changed, 70 insertions(+), 45 deletions(-) diff --git a/keras/backend/numpy/core.py b/keras/backend/numpy/core.py index 697fda4e8b5..8f693a3bdf8 100644 --- a/keras/backend/numpy/core.py +++ b/keras/backend/numpy/core.py @@ -204,7 +204,7 @@ def while_loop( return loop_vars -def scan(f, init, xs, length=None, reverse=False): +def scan(f, init, xs, length=None, reverse=False, unroll=False): if xs is None: xs = [None] * length carry = init @@ -221,7 +221,6 
@@ def scan(f, init, xs, length=None, reverse=False): if len(ys) > 0: if isinstance(ys[0], np.integer): ys = ys.astype(np.int32) - if isinstance(ys[0], np.floating): ys = ys.astype(np.float32) if isinstance(ys, np.integer): diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 1050884f571..45396c4381c 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -194,29 +194,35 @@ def slice_update(inputs, start_indices, updates): return dynamic_update_slice(inputs, updates, start_indices) -def scan(f, init, xs, length=None, reverse=False): +def scan(f, init, xs, length=None, reverse=False, unroll=False): if xs is None: xs = [None] * length if reverse: tf.reverse(xs, [0]) - if ( - isinstance(init, float) - or isinstance(init, np.floating) - or any(isinstance(x, float) for x in xs) - or isinstance(xs, np.floating) + for x in xs: + print(type(x)) + print("any(type(x) is float for x in xs)") + print(any(np.issubdtype(type(x), np.floating) for x in xs)) + if type(init) is float or any( + np.issubdtype(type(x), np.floating) for x in xs ): init = ( - tf.cast(init, dtype=tf.float32), - tf.zeros_like(0, dtype=tf.float32), + tf.cast(init, dtype=tf.double), + tf.zeros_like(0, dtype=tf.double), ) - xs = tf.cast(xs, tf.float32) + xs = tf.cast(xs, dtype=tf.double) else: - init = (tf.cast(init, dtype=tf.int32), tf.zeros_like(0, dtype=tf.int32)) - xs = tf.cast(xs, tf.int32) + init = (tf.cast(init, dtype=tf.int64), tf.zeros_like(0, dtype=tf.int64)) + carry, ys = tf.scan(f, xs, initializer=init) - if carry[0].dtype is tf.float64: - return tf.cast(carry[0], dtype=tf.float32), ys.numpy() + # if carry[0].dtype is tf.float64: + # return tf.cast(carry[0], dtype=tf.float32), ys.numpy() + # carry = tf.cast(carry[0], dtype=tf.int64) + if ys.dtype == tf.int64: + ys = tf.cast(ys, dtype=tf.int32) + if ys.dtype == tf.double: + ys = tf.cast(ys, dtype=tf.float32) return carry[0], ys.numpy() diff --git a/keras/backend/torch/core.py 
b/keras/backend/torch/core.py index c865695b4bf..ff186d016fe 100644 --- a/keras/backend/torch/core.py +++ b/keras/backend/torch/core.py @@ -335,7 +335,7 @@ def slice_update(inputs, start_indices, updates): return outputs -def scan(f, init, xs, length=None, reverse=False): +def scan(f, init, xs, length=None, reverse=False, unroll=False): if xs is None: xs = [None] * length carry = init @@ -343,16 +343,18 @@ def scan(f, init, xs, length=None, reverse=False): xs = torch.tensor(xs) if reverse: xs = torch.flip(xs, [-1]) + for x in xs: carry, y = f(carry, x) ys.append(y) ys = np.array(ys) - if ys.dtype == np.int64: - ys = ys.astype(np.int32) if ys.dtype == np.float64: ys = ys.astype(np.float32) + if ys.dtype == np.int64: + ys = ys.astype(np.int32) + return carry, ys diff --git a/keras/ops/core.py b/keras/ops/core.py index 3a5de5991d0..7d7a74ab95a 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -173,7 +173,10 @@ def call(f, init, xs, length, reverse, unroll): return backend.core.scan(f, init, xs, length, reverse, unroll) def compute_output_spec(self, f, init, xs, length, reverse, unroll): - return KerasTensor(xs.shape, dtype=xs.dtype) + return [ + KerasTensor(xs.shape, dtype=xs.dtype), + KerasTensor(init.shape, dtype=init.dtype), + ] @keras_export("keras.ops.slice_update") @@ -213,7 +216,7 @@ def slice_update(inputs, start_indices, updates): @keras_export("keras.ops.scan") -def scan(f, init, xs, length=None, reverse=False): +def scan(f, init, xs, length=None, reverse=False, unroll=False): """Scan a function over leading array axes while carrying along state. At a high level, this operation does @@ -248,8 +251,8 @@ def f(carry, x): A scanned array and a carry element. 
""" if any_symbolic_tensors((init, xs)): - return SliceUpdate().symbolic_call(f, init, xs, length, reverse) - return backend.core.scan(f, init, xs, length, reverse) + return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) + return backend.core.scan(f, init, xs, length, reverse, unroll) class WhileLoop(Operation): diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 66f6e73616b..e8f61208c7a 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -81,27 +81,42 @@ def test_unstack(self): def test_scan(self): def f(carry, x): if type(carry) is list or type(carry) is tuple: - x += carry[0] carry = carry[0] + if type(carry) is int or type(x) is int: + carry = int(carry) + + x += carry + + elif type(carry) is float or type(x) is float: + x = float(x) + carry = float(carry) + x += carry else: x += carry - return carry, x - init_carr = np.array(1) - xs = np.array([0, 1, 2, 3, 4, 5, 6]) + return carry, x - carry_op, ys_op = core.scan( - f, init_carr, xs, length=len(xs), reverse=False - ) - carry_jax, ys_jax = jax.lax.scan( - f, init_carr, xs, length=len(xs), reverse=False - ) + test_cases = [ + (np.array([0, 1, 2, 3, 4, 5, 6]), 1), + (np.array([0.1, 1.2, 2.3, 3.4, 4.5, 5.6, 6.7]), 1.1), + (np.array([123, 423, 3, 78, 43, 13]), 1.1), + (np.array([-1, -2, -3, -4, -5, -6]), -2), + (np.array([0, 0, 0, 0, 0, 0, 0]), 0), + (np.array([1.1, 2, 3, 4, 5, 6, 7]), 9), + ] + for test_case in test_cases: + init_carr = test_case[1] + xs = test_case[0] - ys_op = ys_op - ys_jax = ys_jax + carry_jax, ys_jax = jax.lax.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) + carry_op, ys_op = core.scan( + f, init_carr, xs, length=len(xs), reverse=False + ) - self.assertEqual(ys_jax.shape, ys_op.shape) - self.assertEqual(ys_jax.dtype, ys_op.dtype) + self.assertEqual(ys_jax.shape, ys_op.shape) + self.assertEqual(ys_jax.dtype, ys_op.dtype) class CoreOpsCorrectnessTest(testing.TestCase): @@ -244,17 +259,17 @@ def test_slice_update(self): def 
test_scan(self): def f(carry, x): if type(carry) is list or type(carry) is tuple: - x += carry[0] carry = carry[0] + if type(carry) is int or type(x) is int: + carry = int(carry) + + x += carry + + elif type(carry) is float or type(x) is float: + x = float(x) + carry = float(carry) + x += carry else: - if ( - isinstance(x, np.floating) - or isinstance(carry, np.floating) - or isinstance(x, float) - or isinstance(carry, float) - ): - x = float(x) # x.astype(np.float32) - carry = float(carry) # carry.astype(np.float32) x += carry return carry, x From ef216c4c6d114f5e7ba2130e5b4aa3ae4d83a4a3 Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Tue, 5 Dec 2023 19:08:15 +0200 Subject: [PATCH 09/11] format fixed --- keras/ops/core_test.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/keras/ops/core_test.py b/keras/ops/core_test.py index 085ca346e67..efa5040ff62 100644 --- a/keras/ops/core_test.py +++ b/keras/ops/core_test.py @@ -1,5 +1,6 @@ -import jax import contextlib + +import jax import numpy as np import pytest from absl.testing import parameterized From 611e974ad7512db4550af6aca1e6128f904c332c Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Sat, 9 Dec 2023 18:45:00 +0200 Subject: [PATCH 10/11] fixed default argument --- keras/ops/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/ops/core.py b/keras/ops/core.py index 2147610b7ca..0f736ae1007 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -217,7 +217,7 @@ def slice_update(inputs, start_indices, updates): @keras_export("keras.ops.scan") -def scan(f, init, xs, length=None, reverse=False, unroll=False): +def scan(f, init, xs, length=None, reverse=False, unroll=1): """Scan a function over leading array axes while carrying along state. 
At a high level, this operation does From 71d3173b9386f3e610e643ece9e7687ff645f33e Mon Sep 17 00:00:00 2001 From: Valko Milev Date: Mon, 8 Jan 2024 23:34:44 +0200 Subject: [PATCH 11/11] fixes according to comments --- keras/backend/tensorflow/core.py | 3 --- keras/ops/core.py | 2 +- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/keras/backend/tensorflow/core.py b/keras/backend/tensorflow/core.py index 6ecaf9ddb98..6cf10863fde 100644 --- a/keras/backend/tensorflow/core.py +++ b/keras/backend/tensorflow/core.py @@ -248,9 +248,6 @@ def scan(f, init, xs, length=None, reverse=False, unroll=False): carry, ys = tf.scan(f, xs, initializer=init) - # if carry[0].dtype is tf.float64: - # return tf.cast(carry[0], dtype=tf.float32), ys.numpy() - # carry = tf.cast(carry[0], dtype=tf.int64) if ys.dtype == tf.int64: ys = tf.cast(ys, dtype=tf.int32) if ys.dtype == tf.double: diff --git a/keras/ops/core.py b/keras/ops/core.py index 0f736ae1007..4d22f6915f2 100644 --- a/keras/ops/core.py +++ b/keras/ops/core.py @@ -252,7 +252,7 @@ def f(carry, x): A scanned array and a carry element. """ if any_symbolic_tensors((init, xs)): - return SliceUpdate().symbolic_call(f, init, xs, length, reverse, unroll) + return Scan().symbolic_call(f, init, xs, length, reverse, unroll) return backend.core.scan(f, init, xs, length, reverse, unroll)