Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added onnx friendly merging #43

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

NatanBagrov
Copy link

Added onnx friendly merging
Also, together with a test for correctness. NOTE: this only supports 'mean' and 'sum' reductions
fixes #31

…TE: this only supports 'mean' and 'sum' reductions
@dbolya
Copy link
Owner

dbolya commented May 31, 2023

Thanks for writing this! Do you know what the speed difference might be between this and the native pytorch solution (w/o onnx)?

Essentially, this is a native pytorch implementation of scatter_reduce, not an onnx-specific implementation (kinda like a polyfill). This might be useful for more than just onnx (e.g., any platform where scatter_reduce is not implemented, maybe like directml). If I accept this, I'd probably re-write it to test if scatter_reduce is available, and then if not use this implementation instead, rather than use a manual flag. (Also it looks like you left some autolinter changes in the commits, not sure if you meant to do that or not.)

@NatanBagrov
Copy link
Author

Hi, I'll check the numbers and will update.
Re testing availability - I'm not sure if this can be done without actually converting to ONNX.
See this and this.

@dbolya
Copy link
Owner

dbolya commented Jun 5, 2023

Hmm, would it be possible to just attempt to run the function and then fall back to the polyfill if there's an error? I see you suggested a try-except over the entire conversion process, but is it possible to just try-except the function itself?

Alternatively, maybe there's a way to detect whether an onnx / jit trace is currently being performed instead of a normal forward pass, and set the flag then.

@ravitejaroyal
Copy link

Added onnx friendly merging Also, together with a test for correctness. NOTE: this only supports 'mean' and 'sum' reductions fixes #31

@NatanBagrov, did you able to export model to ONNX format with ONNX friendly code? I am getting assertion error due to below line.
dst = dst.scatter(-2, dst_idx.expand(n, r, c), src, reduce="add")
AssertionError: A mismatch between the number of arguments (5) and their descriptors (4) was found at symbolic function 'scatter'. If you believe this is not due to custom symbolic implementation within your code or an external library, please file an issue at https://github.com/pytorch/pytorch/issues/new?template=bug-report.yml to report this bug.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Request: Add compatiblity with ONNX (OnnxStableDiffusionPipeline)
3 participants