Groupwise scaling along M for FP8 gemm #2037

Open · wants to merge 1 commit into main from f8_groupwise_scaling_pr_branch

Conversation

soundOfDestiny

Background (copied from #1932)

As we adopt narrower datatypes, traditional scaling methods struggle to maintain accuracy, particularly with 8-bit floating-point types (e.g., e5m2_t, e4m3_t). The typical GEMM operation uses tensorwise scaling with $D = \alpha (AB) + \beta C$, but narrower datatypes necessitate finer-grained scaling techniques. Before we dive deep into groupwise scaling, below is a glossary of the various scaling methods (a shape sketch follows the list):

  1. Tensorwise Scaling: Uses a single scaling factor per tensor, applied in the epilogue.
  2. Rowwise Scaling: Uses a row vector for scaling, with dimensions Mx1 for operand A and 1xN for operand B, avoiding the scaling along the reduction dimension. This can also be handled in the epilogue with EpilogueVisitorTree.
  3. Blockwise Scaling (#1932): Introduces a 2D scaling tensor, assigning one scaling value per CTA Block. Since this scaling involves the reduction dimension (M, N, K), it must be applied during the mainloop, impacting performance. #1932 implements blockwise scaling for CUTLASS FP8 GEMM, staging scaling tensors via shared memory and preparing for future support of groupwise scaling.
  4. Groupwise Scaling (along M in A tensor, this PR): Uses a 2D scaling tensor with multiple scaling values per CTA Block. Scaling granularity is independent of CTA Block configuration, allowing greater flexibility for future implementations.
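For orientation, a minimal Python sketch of the scale-tensor shapes each method implies; the names TILE_M/TILE_N/TILE_K and GRAN_M and the (rows, K-blocks) layout are illustrative assumptions, not the kernel's actual parameter names or layouts:

```python
import math

M, N, K = 1024, 512, 1024                # problem size used in the benchmark below
TILE_M, TILE_N, TILE_K = 128, 128, 128   # hypothetical CTA tile shape
GRAN_M = 64                              # groupwise granularity along M (this PR)

# 1. Tensorwise: one scalar per tensor.
scale_shape_tensorwise = (1,)
# 2. Rowwise: one value per row of A / per column of B.
scale_a_rowwise, scale_b_rowwise = (M, 1), (1, N)
# 3. Blockwise (#1932): one value per CTA block of A and of B.
scale_a_blockwise = (math.ceil(M / TILE_M), math.ceil(K / TILE_K))
scale_b_blockwise = (math.ceil(N / TILE_N), math.ceil(K / TILE_K))
# 4. Groupwise along M (this PR): several values per CTA block along M.
scale_a_groupwise = (math.ceil(M / GRAN_M), math.ceil(K / TILE_K))
scale_b_groupwise = scale_b_blockwise    # N and K granularity stay blockwise
```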

Summary

As #1932 adds a blockwise scaling strategy, this PR is a patch on top of #1932 that adds a groupwise scaling strategy along M in the A tensor. The scaling granularity along M is made independent of the CTA Block configuration; however, the scaling granularities along N and K are still blockwise (i.e., one scaling value per CTA Block).

This PR restricts the scaling granularity along M to a factor (divisor) of TILE_SHAPE_M in the CTA Block configuration. For a coarser granularity, one can set the GEMM scaling granularity along M to exactly TILE_SHAPE_M (i.e., fall back to the blockwise scaling strategy) and call repeat_interleave on the input tensor ScaleA to emulate a scaling granularity that is a multiple of TILE_SHAPE_M, as sketched below.
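A minimal PyTorch sketch of that fallback, assuming a (M-groups, K-blocks) layout for ScaleA; the concrete sizes are illustrative, not the exact argument format expected by the example:

```python
import torch

M, TILE_M = 1024, 128
coarse_gran_m = 256                  # desired granularity: a multiple of TILE_M
num_k_blocks = 8                     # K / TILE_K, illustrative

# Coarse scales: one value per coarse_gran_m rows of A, per K block.
scale_a_coarse = torch.rand(M // coarse_gran_m, num_k_blocks, dtype=torch.float32)

# Run the kernel with granularity == TILE_M (the blockwise fallback) and expand
# the coarse scales so every TILE_M-row tile receives its group's value.
scale_a = scale_a_coarse.repeat_interleave(coarse_gran_m // TILE_M, dim=0)
assert scale_a.shape == (M // TILE_M, num_k_blocks)
```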

Groupwise Scaling

In this implementation, we load scaling tensors with more elements than #1932 into shared memory, since there may be several scale values along M per CTA Block. However, each thread only needs to load at most two scale values for the A tensor and exactly one scale value for the B tensor from shared memory into registers per iteration, because the WGMMA accumulators of each thread cover only two rows of the result tensor.
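For reference, a NumPy sketch of the scaling semantics the mainloop realizes, under simplifying assumptions (a single N block, float32 standing in for FP8 storage, and an illustrative scale layout):

```python
import numpy as np

M, N, K = 256, 128, 256
TILE_K = 128      # blockwise granularity along K: one scale per K block
GRAN_M = 64       # groupwise granularity along M (this PR)

A = np.random.randn(M, K).astype(np.float32)   # float32 stands in for FP8 data
B = np.random.randn(K, N).astype(np.float32)
scale_a = np.random.rand(M // GRAN_M, K // TILE_K).astype(np.float32)
scale_b = np.random.rand(K // TILE_K).astype(np.float32)  # single N block assumed

D = np.zeros((M, N), dtype=np.float32)
for kb in range(K // TILE_K):
    partial = A[:, kb * TILE_K:(kb + 1) * TILE_K] @ B[kb * TILE_K:(kb + 1) * TILE_K, :]
    # Every group of GRAN_M rows uses its own A scale for this K block.
    row_scale = np.repeat(scale_a[:, kb], GRAN_M)[:, None]
    D += row_scale * scale_b[kb] * partial
```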

Performance

I haven't observed any performance degradation compared with #1932.
blockwise scaling

./64_hopper_fp8_warp_specialized_gemm_with_blockwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112583 ms
  GFLOPS: 95373.3

groupwise scaling (this PR, setting scaling granularity along M to 64)

./64_hopper_fp8_warp_specialized_gemm_with_groupwise_scaling 
  Disposition: Passed
  Problem Size: 1024x512x1024x1
  Rasterization: Heuristic with a maximum CTA swizzle of 1
  Avg runtime: 0.0112435 ms
  GFLOPS: 95499.3

@zhyncs

zhyncs commented Jan 17, 2025

Hi @hwu36 This PR is from the DeepSeek Team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3. This PR is essential for us. Thanks!

@ll2088

ll2088 commented Jan 21, 2025

> Hi @hwu36 This PR is from the DeepSeek Team. Could you help review and merge it? The SGLang team wants to implement block-wise FP8 using CUTLASS for DeepSeek V3. This PR is essential for us. Thanks!

Hi @zhyncs, this PR looks like an example demo. Has the integration with SGLang been done? Could you post a PR with the SGLang integration code?

@zhyncs

zhyncs commented Jan 21, 2025

@ll2088
Our current open-source version, https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py, has been referenced and adapted by other projects, including vLLM and LightLLM.
The version developed on top of CUTLASS is currently based on this branch: https://github.com/soundOfDestiny/cutlass/tree/f8_groupwise_scaling_pr_branch.
We hope the official CUTLASS will review and merge this PR soon so we can use the official version. Currently, v3.7.0 includes block-wise but not per-token-per-128-channel support.

@ll2088

ll2088 commented Jan 21, 2025

> @ll2088 Our current open-source version https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py has been referenced and adapted by other projects, including vLLM and LightLLM. The version developed on top of CUTLASS is currently based on this branch: https://github.com/soundOfDestiny/cutlass/tree/f8_groupwise_scaling_pr_branch. We hope the official CUTLASS will review and merge this PR soon so we can use the official version. Currently, v3.7.0 includes block-wise but not per-token-per-128-channel support.

Has the CUTLASS-based version been submitted as a PR to SGLang? Could you post it here?

@zhyncs

zhyncs commented Jan 21, 2025

Not yet.

@soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from 9d997ce to a08ef31 on January 21, 2025 06:57
@ll2088

ll2088 commented Jan 21, 2025

@soundOfDestiny With using TileShape = Shape<_1,_128,_128>;, why does it not work? A compile error occurs.

@ll2088

ll2088 commented Jan 21, 2025

> @soundOfDestiny With using TileShape = Shape<_1,_128,_128>;, why does it not work? A compile error occurs.

And why does ScaleMsPerTile = 128 not work? @soundOfDestiny

@ll2088

ll2088 commented Jan 21, 2025

@zhyncs ScaleMsPerTile=128 is not supported here; the shared memory is not enough.

/workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336 Setting smem size to 234496
/workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:343 cudaFuncSetAttribute() returned error: invalid argument
Got cutlass error: Error Internal at: 673
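For reference, a quick arithmetic check (assuming the Hopper/SM90 limit of 227 KB of dynamic shared memory per thread block, which is not stated in this thread) shows why cudaFuncSetAttribute rejects the request:

```python
smem_requested = 234496                  # bytes, from the log above
smem_limit_sm90 = 227 * 1024             # 232448 bytes: assumed per-block dynamic smem limit on SM90
print(smem_requested - smem_limit_sm90)  # 2048 bytes over budget -> invalid argument
```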

@soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from a08ef31 to 0c08d7c on January 21, 2025 14:40
@soundOfDestiny
Author

> @zhyncs ScaleMsPerTile=128 is not supported here; the shared memory is not enough.
>
> /workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:336 Setting smem size to 234496 /workspace/applied-ai/kernels/cuda/cutlass_gemm/cutlass/include/cutlass/gemm/device/gemm_universal_adapter.h:343 cudaFuncSetAttribute() returned error: invalid argument Got cutlass error: Error Internal at: 673

The incorrect calculation of the shared memory size has been present since #1932.
It has been fixed in the latest commit.

@soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from 0c08d7c to df73dd0 on January 21, 2025 14:50
@soundOfDestiny force-pushed the f8_groupwise_scaling_pr_branch branch from df73dd0 to 3197c81 on January 21, 2025 14:57