Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[EVT] Add support for Row/Col broadcast PtrArray #2033

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

jwfromm
Copy link

@jwfromm jwfromm commented Jan 8, 2025

To enable FP8 grouped gemm with rowwise scaling in the epilogue, we need to be able to provide a list of pointers to the scales for each group. This PR extends Sm90Row/ColBroadcast to support PTRArray to handle this case. Now if the ElementInput type is specified as a pointer, the corresponding input is an array of pointers, enabling grouped gemm. For example in our case we define EVT nodes like this:

  using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue*, // Indicate input is array of pointers.
      ElementComputeEpilogue,
      cute::Stride<cute::Int<1>, cute::Int<0>, cute::Int<0>>>;

  using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
      0,
      TileShape,
      ElementComputeEpilogue*, // Indicate input is array of pointers.
      ElementComputeEpilogue,
      cute::Stride<cute::Int<0>, cute::Int<1>, cute::Int<0>>>;

@ANIKET-SHIVAM ANIKET-SHIVAM self-requested a review January 8, 2025 19:02
@jiawenliu64
Copy link

@ANIKET-SHIVAM @hwu36 Do you have a timeline to review this? We need this feature enabled on cutlass ASAP to unblock our usecases at Meta, e.g., pytorch/FBGEMM#3560

@jwfromm
Copy link
Author

jwfromm commented Jan 13, 2025

Thanks for taking a look @Skylion007, I've incorporated your feedback if you'd like to check again to make sure this all looks good.

Copy link

@Skylion007 Skylion007 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@jiawenliu64
Copy link

Thanks! Can you merge this to cutlass to unblock e.g., pytorch/FBGEMM#3560?

@Skylion007
Copy link

Need a nvidia employee to to do that. @eqy might know who to ping.

@jiawenliu64
Copy link

cc. @hwu36

@hwu36
Copy link
Collaborator

hwu36 commented Jan 15, 2025

We are working on this.

@jwfromm
Copy link
Author

jwfromm commented Jan 23, 2025

@ANIKET-SHIVAM I've refactored this PR so that the functionality is built into the existing Row/Col EVT nodes. Can you take another look?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants