zhipeng93 changed the title from "[BUG] Hopper groupgemm example fails for (1638, 6144, 3584)" to "[BUG] Hopper groupgemm example fails for mnk(1638, 6144, 3584)" on Oct 18, 2024.
It's the can_implement check that fails for this problem size.
The layout of C and D is ColumnMajor, so for the TMA-based epilogue to work, the contiguous dimension (M) must be aligned to at least 128 bits.
With a half_t output, 128 bits is 8 elements, so M must be a multiple of 8. M=1638 is not (1638 = 8 × 204 + 6).
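A minimal C++ sketch of the alignment arithmetic above, assuming the rule can be modeled as "the contiguous extent of C/D must fill a whole number of 128-bit chunks". This is a simplified model for illustration, not CUTLASS's actual can_implement code, and the helper names are hypothetical.

```cpp
#include <cstdint>

// Assumed, simplified model of the rule: with a ColumnMajor C/D, the TMA
// epilogue needs the contiguous extent (M) to span a multiple of 128 bits.
constexpr int kTmaAlignmentBits = 128;

// Number of elements that make up 128 bits for a given element width.
constexpr int min_alignment_elements(int element_bits) {
  return kTmaAlignmentBits / element_bits;
}

// True if `extent` elements of `element_bits` bits each fill a whole
// number of 128-bit chunks.
constexpr bool is_tma_aligned(std::int64_t extent, int element_bits) {
  return extent % min_alignment_elements(element_bits) == 0;
}

constexpr int kHalfBits = 16;  // cutlass::half_t is a 16-bit type

// 128 / 16 = 8 elements, and 1638 = 8 * 204 + 6, so M = 1638 fails.
static_assert(min_alignment_elements(kHalfBits) == 8, "half_t needs 8-element alignment");
static_assert(!is_tma_aligned(1638, kHalfBits), "M = 1638 is not 128-bit aligned");
```

The static_asserts are evaluated at compile time, so the snippet only compiles if the arithmetic holds.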
For future reference: if you configure CMake with -DCUTLASS_DEBUG_TRACE_LEVEL=2, you will see something like this:
include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp:432 CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.
So you can either:
- change the layout of C/D to RowMajor (the contiguous dimension then becomes N; for this problem N=6144 is a multiple of 8), or
- try the non-TMA epilogue PtrArrayNoSmemWarpSpecialized (but then you won't have EVT/fusion support, including activations).
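The RowMajor workaround can be checked with the same simplified 128-bit rule, applied per layout and per group. This C++ sketch assumes, rather than quotes from CUTLASS, that the contiguous extent of an M x N output is M for ColumnMajor and N for RowMajor, and that a grouped GEMM can only use the TMA epilogue if every group passes. Layout, ProblemShape and the helper functions are hypothetical names, not CUTLASS APIs.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical types for illustration; these are not CUTLASS APIs.
enum class Layout { RowMajor, ColumnMajor };

struct ProblemShape {
  std::int64_t m, n, k;
};

// For an M x N output, ColumnMajor makes M contiguous; RowMajor makes N contiguous.
constexpr std::int64_t contiguous_extent(ProblemShape p, Layout layout) {
  return layout == Layout::ColumnMajor ? p.m : p.n;
}

// Assumed rule: the contiguous extent must span a multiple of 128 bits.
constexpr bool epilogue_tma_ok(ProblemShape p, Layout layout, int element_bits) {
  return contiguous_extent(p, layout) % (128 / element_bits) == 0;
}

// Assumption: every group of a grouped GEMM must satisfy the rule.
bool all_groups_ok(const std::vector<ProblemShape>& groups, Layout layout, int element_bits) {
  for (const ProblemShape& p : groups) {
    if (!epilogue_tma_ok(p, layout, element_bits)) return false;
  }
  return true;
}

constexpr ProblemShape kReported{1638, 6144, 3584};
static_assert(!epilogue_tma_ok(kReported, Layout::ColumnMajor, 16), "M = 1638 is misaligned");
static_assert(epilogue_tma_ok(kReported, Layout::RowMajor, 16), "N = 6144 is 8-element aligned");
```

Switching to RowMajor only helps if N is aligned for every group, which is why the check loops over all problem shapes.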
Describe the bug
The example code here [1] fails for mnk=(1638, 6144, 3584) with:
Got cutlass error: Invalid status at: 670
Steps/Code to reproduce bug
Expected behavior
The example runs successfully for this problem size.
Environment details
Docker, H800
[1] https://github.com/NVIDIA/cutlass/tree/main/examples/57_hopper_grouped_gemm