Skip to content

Commit

Permalink
Add further tests/docs to vectorized_map.
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Nov 2, 2023
1 parent 7f29a09 commit 9197591
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 7 deletions.
31 changes: 24 additions & 7 deletions keras/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -596,16 +596,33 @@ def cond(pred, true_fn, false_fn):

# TODO: also create an Op subclass VectorizedMap.
@keras_export("keras.ops.vectorized_map")
def vectorized_map(function, x):
"""Parallel map of `function` on axis 0 of tensor `x`.
def vectorized_map(function, elements):
"""Parallel map of `function` on axis 0 of tensor(s) `elements`.
Schematically, `vectorized_map` implements the following:
Schematically, `vectorized_map` implements the following,
in the case of a single tensor input `elements`:
```python
def vectorized_map(function, x)
def vectorized_map(function, elements)
outputs = []
for element in x:
outputs.append(function(element))
for e in elements:
outputs.append(function(e))
return stack(outputs)
```
In the case of an iterable of tensors `elements`,
it implements the following:
```python
def vectorized_map(function, elements)
batch_size = elements[0].shape[0]
outputs = []
for index in range(batch_size):
outputs.append(function([e[index] for e in elements]))
return np.stack(outputs)
```
In this case, `function` is expected to take as input
a single list of tensor arguments.
"""
return backend.core.vectorized_map(function, x)
return backend.core.vectorized_map(function, elements)
10 changes: 10 additions & 0 deletions keras/ops/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -402,3 +402,13 @@ def fn(x):
self.assertAllClose(
backend.convert_to_numpy(output), np.zeros((2, 2, 3))
)

# Case: multiple args
def fn(elems):
x, y = elems
return x + y

output = ops.vectorized_map(fn, [ops.ones((2, 3)), ops.ones((2, 3))])
self.assertAllClose(
backend.convert_to_numpy(output), 2 * np.ones((2, 3))
)

0 comments on commit 9197591

Please sign in to comment.