Skip to content

Commit

Permalink
Add tf.SparseTensor and tf.IndexedSlices support to a number of T…
Browse files Browse the repository at this point in the history
…ensorFlow ops.

The following element-wise unary ops now support `tf.SparseTensor` and `tf.IndexedSlices`. The output is of the same type as the input.
- abs
- absolute
- arcsin
- arcsinh
- arctan
- arctanh
- ceil
- conj
- conjugate
- copy
- expm1
- floor
- imag
- log1p
- negative
- real
- round
- sign
- sin
- sinh
- sqrt
- square
- tan
- tanh

The following element-wise unary ops now support `tf.SparseTensor` and `tf.IndexedSlices`. The output is dense.
- arccos
- arccosh
- cos
- cosh
- exp
- log
- log10
- log2
- reciprocal

The following element-wise binary ops now support `tf.SparseTensor` and `tf.IndexedSlices`. The output type depends on the two inputs and the op.
- add
- subtract
- maximum
- minimum
- multiply
- mod
- divide
- true_divide
- floor_divide

The following reduction op now supports `tf.IndexedSlices`. The output is an `tf.IndexedSlices` unless dimension 0 is reduced or the rank of the output is 1 or less.
- mean

This is in preparation for supporting sparse gradients in optimizers.
  • Loading branch information
hertschuh committed Oct 23, 2023
1 parent 742a2d0 commit 68c42aa
Show file tree
Hide file tree
Showing 6 changed files with 1,380 additions and 276 deletions.
3 changes: 2 additions & 1 deletion keras/backend/common/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
from keras import ops
from keras.backend.common import dtypes
from keras.backend.common.variables import ALLOWED_DTYPES
from keras.backend.torch.core import to_torch_dtype
from keras.testing import test_case
from keras.testing.test_utils import named_product

Expand All @@ -13,6 +12,8 @@ class DtypesTest(test_case.TestCase, parameterized.TestCase):
"""Test the dtype to verify that the behavior matches JAX."""

if backend.backend() == "torch":
from keras.backend.torch.core import to_torch_dtype

# TODO: torch doesn't support uint64.
ALL_DTYPES = []
for x in ALLOWED_DTYPES:
Expand Down
2 changes: 2 additions & 0 deletions keras/backend/tensorflow/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ def convert_to_tensor(x, dtype=None, sparse=True):
def convert_to_numpy(x):
if isinstance(x, tf.SparseTensor):
x = tf.sparse.to_dense(x)
elif isinstance(x, tf.IndexedSlices):
x = tf.convert_to_tensor(x)
return np.array(x)


Expand Down
Loading

0 comments on commit 68c42aa

Please sign in to comment.