Generalize the PyTorch avg_pool linalg lowering algorithm for the case where count_include_pad = false. #4010
+473
−151
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Currently the avg_pool2d PyTorch operation supports the cases where count_include_pad is true and false, but the avg_pool1d and avg_pool3d only the true case is supported (which is simpler).
The count_include_pad = false support for avg_pool2d was added by @AmosLewis in this change (reviewed by @rsuderman and @nirvedhmeshram) : #3235
In this change I generalized the logic added above. I also did some refactoring to the original code to reduce the size of the functions and to avoid redundancy when possible.
@sahas3 @dixinzhou @rafaelubal