-
Notifications
You must be signed in to change notification settings - Fork 1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[EVT] Add support for Row/Col broadcast PtrArray #2033
base: main
Are you sure you want to change the base?
Conversation
@ANIKET-SHIVAM @hwu36 Do you have a timeline to review this? We need this feature enabled on cutlass ASAP to unblock our usecases at Meta, e.g., pytorch/FBGEMM#3560 |
include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
Outdated
Show resolved
Hide resolved
include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp
Outdated
Show resolved
Hide resolved
Thanks for taking a look @Skylion007, I've incorporated your feedback if you'd like to check again to make sure this all looks good. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
Thanks! Can you merge this to cutlass to unblock e.g., pytorch/FBGEMM#3560? |
Need a nvidia employee to to do that. @eqy might know who to ping. |
cc. @hwu36 |
We are working on this. |
bfba683
to
dadd7c2
Compare
dadd7c2
to
2d14744
Compare
@ANIKET-SHIVAM I've refactored this PR so that the functionality is built into the existing Row/Col EVT nodes. Can you take another look? |
To enable FP8 grouped gemm with rowwise scaling in the epilogue, we need to be able to provide a list of pointers to the scales for each group. This PR extends Sm90Row/ColBroadcast to support PTRArray to handle this case. Now if the ElementInput type is specified as a pointer, the corresponding input is an array of pointers, enabling grouped gemm. For example in our case we define EVT nodes like this: