Groupwise scaling along M for FP8 gemm #2037
base: main
Conversation
Hi @hwu36, this PR is from the DeepSeek team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3, and this PR is essential for us. Thanks!
Hi @zhyncs, this PR looks like an example demo. Has the integration with SGLang been done? Could you post a PR with the SGLang integration code?
@ll2088
Has the CUTLASS-based version developed in SGLang been submitted as a PR? Could you post it here?
Not yet.
Force-pushed from 9d997ce to a08ef31.
And why does ScaleMsPerTile = 128 not work? @soundOfDestiny
/workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336 Setting smem size to 234496
Force-pushed from a08ef31 to 0c08d7c.
The incorrect shared memory size calculation has been present since #1932.
Force-pushed from 0c08d7c to df73dd0.
Force-pushed from df73dd0 to 3197c81.
Background (copied from #1932)
As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with $D = \alpha \cdot (A B) + \beta \cdot C$, but narrower datatypes necessitate finer-grained scaling techniques. A glossary of the various scaling methods can be found in #1932.
Summary
As #1932 adds a blockwise scaling strategy, this PR is a patch on top of #1932 that adds a groupwise scaling strategy along M for the A tensor. The scaling granularity along M is made independent of the CTA block configuration; however, the scaling granularities along N and K remain blockwise (i.e., one scaling value per CTA block).
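For intuition, the following is a minimal reference sketch (plain PyTorch, not code from this PR) of the scaled GEMM described above: the A scales are groupwise along M with a granularity independent of the CTA tile, the B scales remain blockwise along N and K, and each K-block partial product is scaled during accumulation. All names, shapes, and granularity values are illustrative assumptions.

```python
import torch

# Illustrative shapes and granularities (assumptions, not values fixed by this PR).
M, N, K = 256, 256, 256
GROUP_M, BLOCK_N, BLOCK_K = 64, 128, 128   # M: groupwise, N/K: blockwise (one scale per CTA block)

A = torch.randn(M, K).to(torch.float8_e4m3fn)      # quantized A operand
B = torch.randn(N, K).to(torch.float8_e4m3fn)      # quantized B operand (stored N x K here)
scale_A = torch.rand(M // GROUP_M, K // BLOCK_K)   # one scale per (M group, K block)
scale_B = torch.rand(N // BLOCK_N, K // BLOCK_K)   # one scale per (N block, K block)

# Reference: scale each K-block partial product by the matching A-group and
# B-block scales and accumulate in higher precision.
D = torch.zeros(M, N, dtype=torch.float32)
for kb in range(K // BLOCK_K):
    a = A[:, kb * BLOCK_K:(kb + 1) * BLOCK_K].to(torch.float32)
    b = B[:, kb * BLOCK_K:(kb + 1) * BLOCK_K].to(torch.float32)
    row_scale = scale_A[:, kb].repeat_interleave(GROUP_M)   # shape (M,)
    col_scale = scale_B[:, kb].repeat_interleave(BLOCK_N)   # shape (N,)
    D += row_scale[:, None] * col_scale[None, :] * (a @ b.t())
```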
This PR restricts the scaling granularity along M to a factor of TILE_SHAPE_M in the CTA block configuration. To simulate a scaling granularity that is a multiple of TILE_SHAPE_M, one can set the GEMM scaling granularity along M to exactly TILE_SHAPE_M (i.e., fall back to the blockwise scaling strategy) and call the repeat_interleave method on the input tensor ScaleA, as illustrated in the sketch below.
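A concrete PyTorch illustration of that fallback follows; the tile size, granularity, and the assumed ScaleA layout (one scale per M group per K block) are examples, not values fixed by this PR.

```python
import torch

TILE_SHAPE_M = 128           # CTA tile size along M (example value)
GROUP_M = 2 * TILE_SHAPE_M   # desired granularity: a multiple of TILE_SHAPE_M
M, num_k_blocks = 1024, 8    # illustrative problem shape

# One scale per (GROUP_M rows, K block).
scale_A_coarse = torch.rand(M // GROUP_M, num_k_blocks)

# Expand so the kernel sees one scale per TILE_SHAPE_M rows; the duplicated
# values reproduce the coarser granularity exactly.
ScaleA = scale_A_coarse.repeat_interleave(GROUP_M // TILE_SHAPE_M, dim=0)
assert ScaleA.shape == (M // TILE_SHAPE_M, num_k_blocks)
```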
Groupwise Scaling
In this implementation, we load scaling tensors with more elements than in #1932 into shared memory, since the scale may vary along M within a CTA block. However, each thread only needs to load at most two scale values for the A tensor and exactly one scale value for the B tensor from shared memory to registers per iteration, because the WGMMA accumulators of each thread cover only two rows of the result tensor (see the sketch below).
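The "at most two A scales per thread" property follows from the accumulator layout. Below is a small Python sketch; the WGMMA row mapping it assumes (each thread of a 128-thread warpgroup owning rows lane/4 and lane/4 + 8 of its warp's 16-row slice of the 64-row tile) and the granularity value are assumptions for illustration, not taken from this PR.

```python
# Sketch of why at most two A scales are needed per thread, assuming the usual
# SM90 WGMMA accumulator layout: the 64 output rows of a warpgroup tile are
# split 16 rows per warp, and each thread owns rows lane//4 and lane//4 + 8
# of its warp's slice.
SCALE_GRANULARITY_M = 8      # deliberately small, purely illustrative

def thread_rows(thread_idx: int) -> tuple[int, int]:
    """Rows of the 64-row accumulator tile held by one warpgroup thread."""
    warp, lane = divmod(thread_idx, 32)
    base = warp * 16 + lane // 4
    return base, base + 8

def a_scale_groups(thread_idx: int, tile_m_offset: int = 0) -> set[int]:
    """Distinct A-scale groups a thread touches in one tile."""
    return {(tile_m_offset + r) // SCALE_GRANULARITY_M for r in thread_rows(thread_idx)}

# Every thread of the 128-thread warpgroup needs at most two A scales.
assert all(len(a_scale_groups(t)) <= 2 for t in range(128))
```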
Performance
I haven't observed any performance degradation compared with #1932.
Benchmark screenshots (not reproduced here): blockwise scaling vs. groupwise scaling (this PR, with the scaling granularity along M set to 64).