[BUG] Hopper groupgemm example fails for mnk(1638, 6144, 3584) #1884

Closed
zhipeng93 opened this issue Oct 18, 2024 · 4 comments
Labels: ? - Needs Triage, bug


zhipeng93 commented Oct 18, 2024

Describe the bug

The example code [1] fails for mnk = (1638, 6144, 3584) with the error "Got cutlass error: Invalid status at: 670".

Steps/Code to reproduce bug

cd cutlass/examples/57_hopper_grouped_gemm

nvcc -arch=sm_90a -I ../../include -I ../common/ -I ../../tools/util/include --expt-relaxed-constexpr -DNDEBUG 57_hopper_grouped_gemm.cu

./a.out --m=1638 --n=6144 --k=3584 --groups=5

Expected behavior
The example runs successfully for this problem size.

Environment details:
Docker, H800

[1] https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm

zhipeng93 added the "? - Needs Triage" and "bug" labels on Oct 18, 2024
zhipeng93 changed the title from "[BUG] Hopper groupgemm example fails for (1638, 6144, 3584)" to "[BUG] Hopper groupgemm example fails for mnk(1638, 6144, 3584)" on Oct 18, 2024
jiawenliu64 commented:

Hi folks, we also hit the same issue with grouped GEMM. Has it been fixed?

cc. @hwu36 @Junkai-Wu

hwu36 (Collaborator) commented Oct 31, 2024

@ANIKET-SHIVAM

ANIKET-SHIVAM self-assigned this Oct 31, 2024
ANIKET-SHIVAM (Collaborator) commented Oct 31, 2024

It's the can_implement check that fails for this problem size.
The layout for C and D is ColumnMajor, so M is their contiguous dimension. For the TMA-based epilogue to work, that dimension needs at least 128-bit alignment, i.e. a multiple of 8 elements for a 16-bit type.
M = 1638 doesn't satisfy that condition when the output is half_t, because 1638 is not a multiple of 8.
For future reference: if you had enabled -DCUTLASS_DEBUG_TRACE_LEVEL=2 during CMake configuration, you would have seen something like this:
include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp:432 CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.
So you can either:

  • Change the layout of C/D to RowMajor.
  • Or try the non-TMA epilogue PtrArrayNoSmemWarpSpecialized (but then you won't have EVT/fusion support, including activations).

ANIKET-SHIVAM (Collaborator) commented:

Closing this. Hope the above resolved the issue. Feel free to reopen if needed.
