
Add tf.SparseTensor and tf.IndexedSlices support to a number of TensorFlow ops #18633

Merged
merged 1 commit into keras-team:master from indexed_slices_ops on Oct 25, 2023

Conversation

@hertschuh (Collaborator) commented Oct 17, 2023


The following element-wise unary ops now support tf.SparseTensor and tf.IndexedSlices. The output is of the same type as the input.

  • abs
  • absolute
  • arcsin
  • arcsinh
  • arctan
  • arctanh
  • ceil
  • conj
  • conjugate
  • copy
  • expm1
  • floor
  • imag
  • log1p
  • negative
  • real
  • round
  • sign
  • sin
  • sinh
  • sqrt
  • square
  • tan
  • tanh

The following element-wise unary ops now support tf.SparseTensor and tf.IndexedSlices. The output is dense.

  • arccos
  • arccosh
  • cos
  • cosh
  • exp
  • log
  • log10
  • log2
  • reciprocal

The following element-wise binary ops now support tf.SparseTensor and tf.IndexedSlices. The output type depends on the two inputs and the op.

  • add (already supported tf.SparseTensor)
  • subtract (already supported tf.SparseTensor)
  • maximum (already supported tf.SparseTensor)
  • minimum (already supported tf.SparseTensor)
  • multiply (already supported tf.SparseTensor)
  • mod
  • divide
  • true_divide
  • floor_divide

The following reduction op now supports tf.IndexedSlices. The output is a tf.IndexedSlices unless dimension 0 is reduced or the rank of the output is 1 or less.

  • mean

This is in preparation for supporting sparse gradients in optimizers.
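
A minimal usage sketch of the behavior described above (assuming Keras 3 with the TensorFlow backend, which is the default; the resulting types reflect this PR's behavior and may differ in other versions):

```python
import tensorflow as tf
import keras  # assumes the TensorFlow backend

# Element-wise unary ops: a sparse input stays sparse when the op preserves zeros.
st = tf.SparseTensor(indices=[[0, 0], [1, 2]], values=[1.0, 4.0], dense_shape=[2, 3])
print(type(keras.ops.sin(st)))   # SparseTensor: sin(0) == 0, sparsity is kept
print(type(keras.ops.exp(st)))   # dense Tensor: exp(0) != 0, so the result densifies

# tf.IndexedSlices (e.g. the gradient of an embedding lookup) are supported too.
slices = tf.IndexedSlices(
    values=tf.constant([[1.0, 2.0], [3.0, 4.0]]),
    indices=tf.constant([0, 2]),
    dense_shape=tf.constant([4, 2]),
)
print(type(keras.ops.square(slices)))  # IndexedSlices
```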

codecov-commenter commented Oct 17, 2023

Codecov Report

Attention: 7 lines in your changes are missing coverage. Please review.

Comparison is base (b7152af) 78.48% compared to head (64ef001) 78.65%.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #18633      +/-   ##
==========================================
+ Coverage   78.48%   78.65%   +0.16%     
==========================================
  Files         335      336       +1     
  Lines       33090    33369     +279     
  Branches     6486     6529      +43     
==========================================
+ Hits        25972    26246     +274     
- Misses       5546     5550       +4     
- Partials     1572     1573       +1     
| Flag | Coverage Δ |
|---|---|
| keras | 78.55% <98.05%> (+0.16%) ⬆️ |
| keras-jax | 62.21% <45.68%> (-1.18%) ⬇️ |
| keras-numpy | 56.55% <45.68%> (-1.13%) ⬇️ |
| keras-tensorflow | 63.79% <98.05%> (-0.73%) ⬇️ |
| keras-torch | 64.97% <45.68%> (-0.19%) ⬇️ |

Flags with carried forward coverage won't be shown.

| Files | Coverage Δ |
|---|---|
| keras/backend/tensorflow/core.py | 96.29% <100.00%> (+0.05%) ⬆️ |
| keras/ops/numpy.py | 95.84% <100.00%> (+0.06%) ⬆️ |
| keras/backend/tensorflow/numpy.py | 95.79% <97.26%> (+0.24%) ⬆️ |
| keras/backend/tensorflow/sparse.py | 97.56% <97.56%> (ø) |


@hertschuh marked this pull request as draft October 17, 2023 01:20
@fchollet (Collaborator) left a comment


Thanks for the PR!

from keras.backend.tensorflow.core import convert_to_tensor


@sparse.element_wise_binary_on_union(tf.sparse.add)
fchollet (Collaborator):

Let's apply the following replacements:

  • `element_wise` -> `elementwise` (everywhere)
  • `on_union` -> just `union` (to be consistent with `intersection`)

fchollet (Collaborator):

Question: is this new system really better than tf.sparse.add? Does it have performance implications? The previous "if + use tf.sparse" setup was very readable and easy to maintain. The new system is fairly high in abstraction and complexity. Are the benefits commensurate?

hertschuh (Collaborator Author):

Will do the replacements.

> is this new system really better than tf.sparse.add?

So there are 2 × 3 cases: {SparseTensor, IndexedSlices} × {union (add, maximum, ...), intersection (multiply), division (divide, true_divide, mod, floor_divide)}. tf.sparse happens to support the "union" operators mostly out of the box, so that took very little code. But it doesn't support the other use cases, and IndexedSlices isn't supported for any operation. There is a lot of code for the 5 other use cases (for instance, this CL moved the long chunk of code that gets multiply working for sparse from numpy.py to sparse.py).

One thing we could do differently is have the elementwise_binary_union decorator only add IndexedSlices support, not SparseTensor support, and keep the tf.sparse.add code inline. I initially had it that way, but then I decided to make it consistent.

> Does it have performance implications?

For tf.sparse.add specifically, the only performance impact is the call to the decorator, which we need anyway for the IndexedSlices support. It does use tf.sparse.add directly; it is not re-implemented.

> The new system is fairly high in abstraction and complexity.

I agree. After all was done to handle all the possible combinations, it was a lot more complex than initially expected.

> Are the benefits commensurate?

Well, the benefit is that with one decorator, you make an op SparseTensor and IndexedSlices compatible. And it's self-documenting: you know which ops support SparseTensor and IndexedSlices (we could even generate documentation based on the decorator). I could actually have decorated more ops (e.g. the logical operations).
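
For illustration, a rough sketch of what such a dispatch decorator could look like. The name elementwise_binary_union, the dense fallback, and the overall structure are assumptions for this sketch, not the PR's actual implementation, which also covers tf.IndexedSlices and the intersection/division cases:

```python
import functools

import tensorflow as tf


def elementwise_binary_union(sparse_op):
    """Illustrative sketch of a dispatch decorator (not the PR's actual code).

    `sparse_op` handles the SparseTensor/SparseTensor case (e.g. tf.sparse.add);
    the decorated function handles dense inputs.
    """
    def decorator(dense_op):
        @functools.wraps(dense_op)
        def wrapper(x1, x2):
            if isinstance(x1, tf.SparseTensor) and isinstance(x2, tf.SparseTensor):
                return sparse_op(x1, x2)
            # Fallback for this sketch: densify any single sparse input.
            if isinstance(x1, tf.SparseTensor):
                x1 = tf.sparse.to_dense(x1)
            if isinstance(x2, tf.SparseTensor):
                x2 = tf.sparse.to_dense(x2)
            return dense_op(x1, x2)
        return wrapper
    return decorator


@elementwise_binary_union(tf.sparse.add)
def add(x1, x2):
    return tf.add(x1, x2)
```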

I want to look at JAX and see what approach makes sense for JAX. Maybe it will be similar, maybe it will make me rethink this whole thing.

fchollet (Collaborator):

Thanks for the clarifications, this all sounds reasonable. Please let me know what you think about the JAX situation, and then we can finalize the design.

hertschuh (Collaborator Author):

Thanks!

One other thing is that implementing support in the ops directly adds complexity. For instance, in practice, this union / intersection business is never needed for IndexedSlices: in optimizers, whenever two IndexedSlices are combined, they always have the same indices. But in the ops, we can't just make that assumption.

@fchollet (Collaborator) commented:

To note, cond ops can be quite expensive. We should check what performance looks like.

@hertschuh (Collaborator Author) commented:

> To note, cond ops can be quite expensive. We should check what performance looks like.

Interesting. Some of the conds are optimizations (a shortcut when the indices are the same, which is a very common case with sparse gradients). But some are workarounds for TF bugs. For instance, if nothing is left after computing the intersection (an empty SparseTensor with no indices), a bunch of trivial operations fail (e.g. reshape) even though they're perfectly valid, so I had to special-case them.

@hertschuh force-pushed the indexed_slices_ops branch 4 times, most recently from ab111e2 to 68c42aa on October 23, 2023 19:17
@hertschuh (Collaborator Author) commented:

Looking into this a bit more, I found a way to check that the indices are the same at trace time. So, in the case of optimizers for instance, the conds should no longer be in the graph because when we combine two IndexedSlices, they come from the same gradient and have the same indices.
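
A minimal sketch of how such a trace-time check could work (illustrative only; the helper and the fallback are assumptions, not necessarily the code in this PR):

```python
import tensorflow as tf


def add_indexed_slices(x1, x2):
    """Illustrative sketch: combine two tf.IndexedSlices without a runtime cond.

    Checking object identity of the indices tensors happens at trace time; when
    both slices come from the same gradient they share the same indices tensor,
    so the fast path is taken and no tf.cond ends up in the graph.
    """
    if x1.indices is x2.indices:
        # Fast path: same indices, just add the values.
        return tf.IndexedSlices(x1.values + x2.values, x1.indices, x1.dense_shape)
    # Fallback for this sketch: densify and add (a real implementation could
    # instead compute the union of the two index sets).
    return tf.convert_to_tensor(x1) + tf.convert_to_tensor(x2)
```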

@hertschuh marked this pull request as ready for review October 23, 2023 19:33
@fchollet (Collaborator) left a comment


LGTM, thank you for the contribution!

google-ml-butler bot added the kokoro:force-run and ready to pull labels Oct 25, 2023
@fchollet merged commit fc2829d into keras-team:master Oct 25, 2023
6 checks passed
google-ml-butler bot removed the ready to pull and kokoro:force-run labels Oct 25, 2023
@hertschuh deleted the indexed_slices_ops branch October 25, 2023 16:58