From bdd361a2e36c66c2112f75ab84607fc48e804b58 Mon Sep 17 00:00:00 2001
From: DefTruth <31974251+DefTruth@users.noreply.github.com>
Date: Wed, 25 Dec 2024 13:46:35 +0800
Subject: [PATCH] =?UTF-8?q?[FA2]=20fa2/hgemm=20manually=20smem=20swizzle?=
=?UTF-8?q?=F0=9F=8E=89=20(#185)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
* Update flash_attn_mma.py
* Create makefile
* Create README.md
* Update and rename matrix_trans_swizzle.cu to mat_trans_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update mat_trans_swizzle.cu
* Update and rename flash_attn_mma_swizzle_qkv.cu to flash_attn_mma_share_kv_swizzle.cu
* Create flash_attn_mma_share_qkv_swizzle.cu
* Create flash_attn_mma_split_q_swizzle.cu
* Create flash_attn_mma_split_kv_swizzle.cu
* Create flash_attn_mma_tiling_qk_swizzle.cu
* Create flash_attn_mma_tiling_qkv_swizzle.cu
* Update flash_attn_mma_share_qkv_swizzle.cu
* Update flash_attn_mma_split_kv_swizzle.cu
* Update flash_attn_mma_split_q_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_tiling_qkv_swizzle.cu
* Update README.md
* Update hgemm_mma_swizzle.cu
* Update makefile
* Update README.md
* Update README.md
* Update mat_trans_swizzle.cu
* Update makefile
* Update hgemm_mma_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update README.md
* Update hgemm_mma_stage.cu
* Update hgemm_mma.cu
* Update makefile
* Update utils.h
* Create mma_simple_swizzle.cu
* Update makefile
* Update mma_simple_swizzle.cu
* Update hgemm_mma_swizzle.cu
* Update makefile
* Update utils.py
* Update makefile
* Create hgemm_mma_stage_swizzle.cu
* Update hgemm.py
* Update hgemm.cc
* Update mat_trans_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn.cc
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update flash_attn_mma.py
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_share_kv_swizzle.cu
* Update README.md
* Update README.md
* Create print_swizzle_layout.py
* Update flash_attn_mma_tiling_qk_swizzle.cu
* Update flash_attn_mma_share_kv_swizzle.cu
* Update README.md
* Update hgemm_mma_stage_swizzle.cu
* Update README.md
* Update README.md
* Update README.md
* Update mma_simple_swizzle.cu
* Create print_swizzle_layout.py
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
* Update README.md
---
README.md | 64 +-
kernels/flash-attn/README.md | 2 +-
kernels/flash-attn/flash_attn_mma.py | 15 +-
.../mma/flash_attn_mma_share_kv_swizzle.cu | 934 ++++++++++++++
...cu => flash_attn_mma_share_qkv_swizzle.cu} | 4 +-
.../mma/flash_attn_mma_split_kv_swizzle.cu | 2 +
.../mma/flash_attn_mma_split_q_swizzle.cu | 2 +
.../mma/flash_attn_mma_tiling_qk_swizzle.cu | 1085 +++++++++++++++++
.../mma/flash_attn_mma_tiling_qkv_swizzle.cu | 2 +
kernels/flash-attn/pybind/flash_attn.cc | 8 +
.../flash-attn/tools/print_swizzle_layout.py | 46 +
kernels/hgemm/README.md | 5 +-
kernels/hgemm/hgemm.py | 10 +-
kernels/hgemm/makefile | 7 +
kernels/hgemm/mma/hgemm_mma.cu | 7 +-
kernels/hgemm/mma/hgemm_mma_stage.cu | 100 +-
kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu | 853 +++++++++++++
kernels/hgemm/pybind/hgemm.cc | 4 +
kernels/hgemm/tools/utils.py | 1 +
kernels/hgemm/utils/utils.h | 8 +-
kernels/swizzle/README.md | 125 ++
kernels/swizzle/hgemm_mma_swizzle.cu | 427 ++++++-
kernels/swizzle/makefile | 17 +
kernels/swizzle/mat_trans_swizzle.cu | 108 ++
kernels/swizzle/matrix_trans_swizzle.cu | 35 -
kernels/swizzle/mma_simple_swizzle.cu | 202 +++
kernels/swizzle/print_swizzle_layout.py | 46 +
27 files changed, 3940 insertions(+), 179 deletions(-)
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu
rename kernels/flash-attn/mma/{flash_attn_mma_swizzle_qkv.cu => flash_attn_mma_share_qkv_swizzle.cu} (98%)
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu
create mode 100644 kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu
create mode 100644 kernels/flash-attn/tools/print_swizzle_layout.py
create mode 100644 kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu
create mode 100644 kernels/swizzle/README.md
create mode 100644 kernels/swizzle/makefile
create mode 100644 kernels/swizzle/mat_trans_swizzle.cu
delete mode 100644 kernels/swizzle/matrix_trans_swizzle.cu
create mode 100644 kernels/swizzle/mma_simple_swizzle.cu
create mode 100644 kernels/swizzle/print_swizzle_layout.py
diff --git a/README.md b/README.md
index ed849bcf..09f48c0e 100644
--- a/README.md
+++ b/README.md
@@ -35,9 +35,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|✔️|✔️|✔️|✔️|
|Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)|
|✔️|✔️|✔️|✔️|
-|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)|
+|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe/MMA)|
|✔️|✔️|✔️|✔️|
-|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
+|Collective Store (Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
@@ -48,7 +48,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
-|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)|
+|Pack LDST (128 bits)|SMEM **Swizzle**/Padding |Copy Async|Tile MMA (More Threads)|
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
@@ -160,7 +160,6 @@ The kernels listed here will guide you through a step-by-step progression, rangi
|📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level |
|:---|:---|:---|:---|:---|
-| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️|
| ✔️ [elementwise_f32](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f32x4](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️|
| ✔️ [elementwise_f16](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️|
@@ -205,27 +204,27 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [mat_trans_f32_diagonal2d](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
| ✔️ [mat_trans_f32x4_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️|
-| ✔️ [warp_reduce_[all]](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
-| ✔️ [reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [warp_reduce_{all}](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️|
+| ✔️ [block_all_reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
+| ✔️ [block_all_reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️|
| ✔️ [dot_product_f32](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f32x4](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️|
| ✔️ [dot_product_f16_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️|
@@ -262,7 +261,8 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [rms_norm_f16x8_pack_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [rms_norm_f16_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️|
| ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️|
-| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️|
+| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️⭐️|
+| ✔️ [How to profile with nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️⭐️|
### 📚 Hard ⭐⭐⭐️ ([©️back👆🏻](#cuda-kernel))
@@ -284,7 +284,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_t_8x8_sliced_k16...async](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [sgemm_wmma_m16n16k8...stages*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
-| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
+| ✔️ [sgemm_wmma_m16n16k8...swizzle{+block}*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_naive_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
| ✔️ [hgemm_sliced_k_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_t_8x8_sliced_k_f16x4](./kernels/hgemm/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
@@ -299,12 +299,13 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [hgemm_wmma_m16n16k16...dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_wmma_m16n16k16...swizzle{+block}*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
-| ✔️ [hgemm_mma_stages{swizzle}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{+block}*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_m16n8k16...swizzle{+smem}*](./kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
+| ✔️ [hgemm_mma_stages_swizzle{+smem}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️|
| ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️|
### 📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ ([©️back👆🏻](#cuda-kernel))
@@ -318,11 +319,14 @@ The kernels listed here will guide you through a step-by-step progression, rangi
| ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [flash_attn_mma...tiling_qk_swizzle{+smem}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages_split_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages_split_q{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_q_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...shared_qkv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
| ? [flash_attn_mma_stages...tiling_qk{f32}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️|
+| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️⭐️|
+
## 📖 博客目录
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md
index 20c3ee42..171849ff 100644
--- a/kernels/flash-attn/README.md
+++ b/kernels/flash-attn/README.md
@@ -5,7 +5,7 @@
|Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)|
|:---:|:---:|:---:|:---:|
|✔️|✔️|✔️|✔️|
-|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
+|Pack LDST (pack 128 bits)|SMEM **Swizzle**/Padding |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)
|✔️|✔️|✔️|✔️|
|Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**|
|✔️|✔️|✔️|✔️|
diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py
index ebf3ef71..369cd9ec 100644
--- a/kernels/flash-attn/flash_attn_mma.py
+++ b/kernels/flash-attn/flash_attn_mma.py
@@ -83,6 +83,7 @@ def get_args():
'./mma/flash_attn_mma_share_kv.cu',
'./mma/flash_attn_mma_share_qkv.cu',
'./mma/flash_attn_mma_tiling_qk.cu',
+ './mma/flash_attn_mma_tiling_qk_swizzle.cu',
'./pybind/flash_attn.cc'
],
extra_cuda_cflags=[
@@ -218,11 +219,11 @@ def run_benchmark(perf_func: callable,
else:
improve = 0
MAX_TFLOPS = TFLOPS
- print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
+ print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
f"TFLOPS:{TFLOPS:<6.2f}(+{improve:.2f}%)")
else:
- if not only_show_improved or "flash" in tag:
- print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, "
+ if not only_show_improved or "flash" in tag or "sdpa" in tag:
+ print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, "
f"TFLOPS:{TFLOPS:<6.2f}")
if show_matrix: print(out)
@@ -296,7 +297,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
diff = torch.abs(out_flash_or_sdpa.float() - out_mma.float())
all_close = str(torch.allclose(out_flash_or_sdpa.float(), out_mma.float(), atol=1e-2))
pretty_print_line(
- f"{true_tag} vs {tag:<18}, all close: {all_close:<6}, "
+ f"{true_tag} vs {tag:<22}, all close: {all_close:<6}, "
f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, "
f"mean diff: {diff.mean().item():.6f}"
)
@@ -340,6 +341,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2)
out_mma_tiling_qk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage1)", o, stages=1)
out_mma_tiling_qk2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage2)", o, stages=2)
+ out_mma_tiling_qk_sw1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage1)", o, stages=1)
+ out_mma_tiling_qk_sw2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage2)", o, stages=2)
if D <= 256:
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
@@ -360,9 +363,13 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all)
check_all_close(out_flash, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all)
+ check_all_close(out_flash, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all)
+ check_all_close(out_flash, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all)
pretty_print_line()
elif args.run_torch_sdpa:
pretty_print_line()
check_all_close(out_sdpa, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all, False)
check_all_close(out_sdpa, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all, False)
+ check_all_close(out_sdpa, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all, False)
+ check_all_close(out_sdpa, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all, False)
pretty_print_line()
diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu
new file mode 100644
index 00000000..10730cb9
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu
@@ -0,0 +1,934 @@
+#include "utils.h"
+
+// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instructions.
+// The inputs are Q,K,V, 4D tensors with shape [batch_size, num_heads, seq_len, head_dim].
+// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].
+
+// The FlashAttention-2 algorithm is described in the following paper:
+// https://arxiv.org/pdf/2307.08691
+
+// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
+// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
+
+// Split Q across MMAs (warps) and keep KV accessible to all MMAs (warps),
+// in order to reduce communication between warps via smem and warp shuffle.
+
+// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+// | 64x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+// | 128x128 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps
+// | 128x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x8) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x8) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |
+
+// Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
+
+// i: row index; j: col index.
+// e.g. kColStride = 64, kStep = 8 -> load 8 halves as one 128-bit memory transaction.
+template<const int kColStride, const int kStep>
+static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) {
+ // -------------------------------------------
+ // --------------swizzle layout---------------
+ // -------------col 0~64, step 8--------------
+ // -------------------------------------------
+ // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // -------------------------------------------
+ // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // -------------------------------------------
+ // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // -------------------------------------------
+ // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // -------------------------------------------
+ // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
+ static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
+ static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
+ if constexpr (kStep == 8) {
+ return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
+ } else {
+ static_assert(kStep == 4);
+ return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
+ }
+}
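+
+// Worked example (illustrative): with kColStride = 64 and kStep = 8, take
+// row i = 5 and column j = 16. Then (i >> 2) = 1 and (j >> 3) = 2, so the
+// permuted column is ((2 ^ 1) % 8) << 3 = 24, which matches the row 4~7
+// pattern (8, 0, 24, 16, 40, 32, 56, 48) in the layout table above.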
+
+template<const int kColStride, const int kStep>
+static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) {
+  return swizzle_permuted_j<kColStride, kStep>(i, j);
+}
+
+template<const int kColStride, const int kStep>
+static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) {
+  return swizzle_permuted_j<kColStride, kStep>(i, j);
+}
+
+template<const int kColStride, const int kStep>
+static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) {
+  return swizzle_permuted_j<kColStride, kStep>(i, j);
+}
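+
+// Hedged usage sketch (not necessarily the exact call sites of the swizzled
+// kernels in this patch): with swizzling, a padded smem column offset such as
+//   row * (kHeadDim + kPad) + col
+// is replaced by a permuted physical column with no padding, e.g.
+//   row * kHeadDim + swizzle_permuted_K_j(row, col)
+// (with column-stride/step template arguments chosen for the tile config),
+// so the 8-half (128-bit) chunks read by a warp land in distinct bank groups
+// without spending extra shared memory on padding.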
+
+template<
+ const int kHeadDim, // Headdim, 32,64,128
+ const int kMmaAtomM, // MMA Atom M, 16
+ const int kMmaAtomN, // MMA Atom N, 8
+ const int kMmaAtomK, // MMA Atom K, 16
+ const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
+ const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
+ const int kStage,
+ const int kPad
+ >
+__global__ void __launch_bounds__(
+ WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK)
+flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen,
+ int QKV_head) {
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
+ static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
+ static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
+ static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
+ static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
+ // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc.
+ // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
+ static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
+ kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
+ static_assert(kStage < 3 && kStage > 0);
+ static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
+ static_assert(Br >= Bc); // for shared memory reuse.
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
+  // Now, N must be a multiple of Bc (32/64) for KV tiling across seqlen.
+ const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d]
+ const float scale = 1.0f / sqrt((float) kHeadDim);
+
+ // grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head), (x,y,z)
+ const int QKV_batch_id = blockIdx.y / QKV_head; // Batch size
+ const int QKV_head_id = blockIdx.y % QKV_head; // Head num
+ const int Q_tile_id = blockIdx.x; // Q tile_id, range [0, Tr]
+ const int O_tile_id = Q_tile_id; // O tile_id, same as Q.
+ const int tid = threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_QP = warp_id; // 0,1,2,3 or 0~7
+ const int warp_KV = 0; // 0
+ // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+ // | 128x128 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+ // | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+ // | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+ // | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+ // | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+ const int Q_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
+ const int K_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
+ const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
+ const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
+
+ // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads.
+ int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
+ int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,...
+ // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
+ // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads.
+ int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,...
+ // global Q row of current head for tile [Br,d] per block.
+ int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
+ if (load_gmem_Q_Br >= QKV_seqlen) return;
+ // KV tile gmem load index starts from 0 and increments with
+ // each iteration as we loop over seqlen.
+ int load_gmem_K_Bc_offset = 0;
+ int load_gmem_V_Bc_offset = 0;
+
+  // Shared memory for Q,K,V; we do not need additional smem for the O
+  // collective store, which is performed via register reuse and warp shuffle.
+ extern __shared__ half smem[];
+  constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096 halves, ~8192 bytes=8KB
+  constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d]
+  half* Q_tile_smem = smem; // 8KB/16KB
+  half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8KB/16KB
+  half* V_tile_smem = K_tile_smem; // KV share the same smem
+  // NOTE: K and V may share the same smem to reduce smem usage for kStage 1.
+  // stage 1, w shared KV smem, Br=Bc=64, d=64: 8KB+(8KB) =16KB, +Pad(2KB) = 18KB
+  // stage 1, w shared KV smem, Br=Bc=128, d=64: 16KB+16KB =32KB, +Pad(4KB) = 36KB
+  // stage 1, w shared KV smem, Br=Bc=64, d=128: 16KB+16KB =32KB, +Pad(2KB) = 34KB
+  // stage 1, w shared KV smem, Br=Bc=64, d=256: 32KB+32KB =64KB, +Pad(2KB) = 66KB
+  // stage 1, w shared KV smem, Br=64,Bc=32, d=256: 32KB+16KB =48KB, +Pad(2KB) = 50KB
+  // stage 1, w shared KV smem, Br=128,Bc=16,d=256: 64KB+16KB =80KB, +Pad(2KB) = 82KB
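+  // Worked size check (illustrative; matches the launcher's smem formula
+  // (Br*(kHeadDim+kPad) + kStage*Bc*(kHeadDim+kPad)) * sizeof(half)):
+  // for Br=Bc=64, d=64, kPad=8, kStage=1 this is (64*72 + 64*72) * 2 bytes
+  // = 18432 bytes = 18KB, i.e. the "16KB, +Pad(2KB) = 18KB" entry above.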
+
+ uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem);
+ uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
+ uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
+
+ // --------------------- Registers/SMEM for thread block -------------------------
+ // block m_old, l_old, store in lane, use float to keep precision.
+ float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_block_row_max_old, -INFINITY);
+ fill_2D_regs(lane_block_row_sum_old, 0.0f);
+
+ // ---------------------- Registers for S=Q@K^T/O=P@V ----------------------------
+ // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d].
+ // Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs.
+ // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store.
+ // Then we can load Q from smem only once and reuse it across all
+ // KV tiles. This removes a large amount of smem I/O for Q when N is large.
+ // FIXME(DefTruth): why can't we get good performance for headdim >= 64?
+ // Will enable it once the performance issue has been figured out.
+ constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64);
+ constexpr bool kCanPrefetchKVg2s = (kStage == 2); // whether prefetch KV g2s.
+ constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0.
+ constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1.
+ constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1;
+ uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4]
+ uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
+ uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2]
+ // registers for the current tile_K_seqlen, [64,64] = S_tile[Br,Bc]
+ // = Q_tile[Br,d] * K[Bc,d], each thread holds 2x 32-bit regs.
+ uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2]
+ // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs.
+ uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
+ uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ fill_3D_regs(R_S, 0);
+ fill_3D_regs(R_D, 0);
+ fill_3D_regs(R_O, 0);
+
+ // load Q from gmem -> smem, only load once.
+ {
+ int load_gmem_Q_d = load_smem_Q_d;
+ int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d);
+ uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + (
+ load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half));
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Br)); i += 8) {
+ CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
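+  // Worked example (illustrative): with kHeadDim = 64, Br = 64 and
+  // kNumThreads = 128, each thread owns kHeadDim / (kNumThreads / Br) = 32
+  // contiguous halves, issued as four 16-byte cp.async copies; the "+ i * 2"
+  // converts the half offset i into a byte offset for the smem pointer.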
+
+  // Loop over seqlen for K^T[d,seqlen] with K^T_tile[d,Bc];
+  // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
+ #pragma unroll 1
+ for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
+ // TODO: process last tile_K_seqlen ? pad to multiple of 8.
+
+    // Load K tile from gmem -> smem, always use smem part 0; issue the g2s
+    // copies before prefetching Q s2r.
+ if constexpr (kCanPrefetchKVg2s) {
+ if (tile_K_seqlen == 0) {
+ load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+
+      // Now, we have to wait for the current K tile to be ready for the Q@K^T MMA.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+      // Load V tile async from gmem -> smem 1, before Q@K^T
+ {
+ load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPad) +
+ load_smem_V_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+ } else {
+ load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+      // Now, we have to wait for the current K tile to be ready for the Q@K^T MMA.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+
+    // Load Q tile from smem -> regs, before Q@K^T.
+ if constexpr (kCanPrefetchQs2r) {
+ // Wait Q ready and let K copy async, then prefetch Q from smem -> regs.
+ // NOTE: we only need to load Q once from smem -> regs, and then reuse it.
+ if (tile_K_seqlen == 0) {
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4]
+ }
+ }
+ __syncthreads(); // wait all warps ready.
+ } // end if tile_K_seqlen == 0
+ } // end if kCanPrefetchQs2r
+
+    // Loop over d in kMmaAtomK(=16) steps via tile_K_d, K_tile_d[kMmaAtomK,Bc]
+ // Matmul with NT layout, Q row major, K^T col major.
+ // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d]
+ //
+ fill_3D_regs(R_S, 0);
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // smem -> reg, load m16k16 smem Q, offset d according tile_K_d.
+ // ldmatrix.x4 for Q_tile_smem.
+ if constexpr (!kCanPrefetchQs2r) {
+ // load Q from smem -> regs in each loop w/o prefetch Q s2r.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) +
+ lane_smem_Q_d) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3],
+ lane_smem_Q_ptr); // now, R_Q[1][1][4]
+ }
+ }
+
+ // smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] with row major means K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8
+ uint32_t lane_smem_K_ptr = (
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ lane_smem_K_Bc * (kHeadDim + kPad) +
+ lane_smem_K_d) * sizeof(half)
+ );
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ } // end for kWarpTileSeqLenK
+
+ if constexpr (kCanPrefetchQs2r) {
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1],
+ R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
+ }
+ } else {
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
+ }
+ }
+ } // end loop over d, S=Q@K^T
+ __syncthreads();
+
+    // If kCanPrefetchKVg2s is not enabled,
+    // we load V g2s here, before rowmax and rowsum.
+ if constexpr (!kCanPrefetchKVg2s) {
+ load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc;
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPad) +
+ load_smem_V_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
+    // Load next K tile from gmem -> smem 0, before P@V.
+ if constexpr (kCanPrefetchKVg2s) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...)
+ int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc;
+ int load_gmem_K_d = load_smem_K_d;
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size +
+ load_smem_K_Bc * (kHeadDim + kPad) +
+ load_smem_K_d) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+ }
+
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+ // Online safe softmax, warp/block reduce max/sum, row wise
+ float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_row_max_new, -INFINITY);
+ fill_2D_regs(lane_row_sum_new, 0.0f);
+
+ // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for C. (m16n8k16)
+ // Row\Col 0 1 2 3 4 5 6 7
+ // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1}
+ // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1}
+ // 2 ...
+ // ...
+ // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1}
+ // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3}
+ // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3}
+ // 10 ...
+ // ...
+ // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3}
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // This should be the row max after S = (Q @ K^T) / sqrt(d)
+ float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale;
+ float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale;
+ lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0);
+ lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce max, warp_size = 4
+ // Each thread contains the maximum of 2 rows of Br,
+ // and only the values of T0, T4, ..., T28 are used.
+ lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]);
+ lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+ __syncthreads();
+
+ // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Use latest global row max without update.
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55;
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ // Br 1, row_id, 8~15, 24~31, 40~47, 56~63;
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // Apply m_new = max(m_old, m_new) here.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z;
+ t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0));
+ t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0));
+ t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1));
+ t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1));
+ lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y);
+ lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y);
+ // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
+ HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0);
+ HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce sum, warp_size = 4
+ lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]);
+ lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+ __syncthreads();
+
+    // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partial attention.
+    // Here, we have to wait for V to be ready before computing O = P @ V.
+ if constexpr (kCanPrefetchKVg2s) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ CP_ASYNC_WAIT_GROUP(1); // we have send V & K g2s, wait V and let K async.
+ } else {
+ CP_ASYNC_WAIT_GROUP(0); // we have only send V g2s.
+ }
+ } else {
+ CP_ASYNC_WAIT_GROUP(0);
+ }
+ __syncthreads();
+
+    // Compute P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partial attention.
+ // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major.
+ // Make sure to clear the states in R_O before MMA for P@V for each step.
+
+    // NOTE: The values of P[Br,Bc] are already in the R_S registers; we can use
+    // these registers as the P (A) matrix operand directly, since they follow
+    // the A matrix layout of the MMA m16n8k16 instruction shown below.
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for A matrix with .f16.
+ // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5}
+ // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5}
+ // 2 (dashed arrow pointing right)
+ // ...
+ // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5}
+ // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7}
+ // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7}
+ // 10 (dashed arrow pointing right)
+ // ...
+ // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
+
+ //
+ fill_3D_regs(R_O, 0);
+ #pragma unroll
+ for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
+ // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) {
+ int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N
+ int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K
+ int lane_smem_V_d = warp_smem_V_d; // 0
+ uint32_t lane_smem_V_ptr = (
+ smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size +
+ lane_smem_V_Bc * (kHeadDim + kPad) +
+ lane_smem_V_d) * sizeof(half)
+ );
+ LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V
+ }
+
+ // For R_S[1][8][2], mapping the layout below of P matrix.
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs.
+ // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs.
+ // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs.
+ // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs.
+ int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6
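+      // Illustrative check: for tile_V_Bc = 1, w = 2, so the HMMA below consumes
+      // R_S[i][2] and R_S[i][3], i.e. the P[:, 16:32] slice from the table above,
+      // as the 16-wide K-dimension operand against this V sub-tile.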
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ HMMA16816(R_O[i][j][0], R_O[i][j][1],
+ R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1],
+ R_V[j][0], R_V[j][1],
+ R_O[i][j][0], R_O[i][j][1]);
+ }
+ }
+ } // end for V Bc.
+ __syncthreads();
+
+ // Rescale O -> Update row sum Exp -> then, Update row max.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1
+ // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper)
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+ float block_row_sum_new_0 = lane_row_sum_new[i][0];
+ float block_row_sum_new_1 = lane_row_sum_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // NOTE: max(-inf, val) = val.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+ // Avoid inf value while using m_old for rescaling O.
+ block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 :
+ block_row_max_new_0);
+ block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 :
+ block_row_max_new_1);
+
+ // rescale factor for O and l, exp(m_old - m)
+ float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0);
+ float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1);
+ // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old.
+ // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3}
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+      // Note that the formula in the FA2 paper is incorrect; here,
+      // the inverse of the exp function should not be taken, as it
+      // would result in an error during rescaling; namely, you have to
+      // use exp(m_old - m_new), not 1/exp(m_old - m_new).
+ // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V
+ t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x);
+ t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y);
+ t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x);
+ t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y);
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ } // end for kWarpTileHeadDimV.
+
+ // Now, we can update m, l after O has been scaled.
+ // 1. First, update block row sum Exp for each lane which
+ // need both m_new and m_old.
+ float block_row_sum_old_0 = lane_block_row_sum_old[i][0];
+ float block_row_sum_old_1 = lane_block_row_sum_old[i][1];
+ // Update l = exp(m_old - m_new) * l_old + row_sum(P).
+ lane_block_row_sum_old[i][0] = (__fmaf_rn(
+ rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0));
+ lane_block_row_sum_old[i][1] = (__fmaf_rn(
+ rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1));
+ // 2. Then, update block row max for each lane.
+ lane_block_row_max_old[i][0] = block_row_max_new_0;
+ lane_block_row_max_old[i][1] = block_row_max_new_1;
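+      // Hedged numeric check of this online update (illustrative values only):
+      // with m_old = 1.0, l_old = 2.0 and a new tile giving m_tile = 3.0,
+      // rowsum(P) = 4.0: m_new = max(1.0, 3.0) = 3.0, rescale = exp(1.0 - 3.0)
+      // ~= 0.135, so l_new = 0.135 * 2.0 + 4.0 ~= 4.27, and R_D above was
+      // rescaled by the same factor before accumulating the new P@V partial sums.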
+ }
+
+ if constexpr (kCanPrefetchKVg2s) {
+ if ((tile_K_seqlen + 1) < Tc) {
+ // now, we have to wait next K tile ready in smem.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+ }
+
+ } // end loop over N
+ __syncthreads();
+
+  // Finally, we still have to rescale O once more.
+ // O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
+ // NOTE: Here, we choose to reuse R_O as final output
+ // in order to reduce regs usage.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]);
+ float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]);
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+ t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x;
+ t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y;
+ t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x;
+ t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y;
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ }
+ }
+
+ // Store O(D): Write O[Br,d] from regs -> gmem, collective store
+ // with reg reuse & warp shuffle. need R_Z[2][4].
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
+
+ if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) {
+ // reuse R_Q[4/8][1][4] for collective store.
+ R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4
+ R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]);
+ }
+ } else {
+ // we have to use new R_Z regs for collective store.
+ uint32_t R_Z[2][4];
+ R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
+ R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]);
+ }
+ } // end if kCanPrefetchQs2r
+ } // end for kWarpTileHeadDimV
+ } // end for kWarpTileSeqLenQ
+}
+
+// Launch kernel for flash_attn_mma_stages_split_q_shared_kv
+template<const int kHeadDim, const int kStage>
+void launch_flash_attn_mma_stages_split_q_shared_kv(
+ torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Now: fixed tile BrxBc=128x64 for d < 128, 128x16 for d >= 128
+ // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
+ constexpr int kMmaAtomM = 16;
+ constexpr int kMmaAtomN = 8;
+ constexpr int kMmaAtomK = 16;
+ constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 8 : 8;
+ constexpr int kMmaTileSeqLenK = 1;
+ constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 8 : 8;
+ constexpr int kMmaTileHeadDimV = 1;
+ constexpr int kWarpTileSeqLenQ = 1;
+ constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 2;
+ constexpr int kWarpTileSeqLenP = 1;
+ constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,....
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*8*1=128
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 (d<128), 8*1*2=16 (d>=128)
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*8*1=256, num threads
+ constexpr int kPad = 8;
+
+ // static int kMaxSramPerBlock;
+ // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+ // Calculate SRAM size needed per block: Q and K/V smem sizes; K/V share the same smem.
+ constexpr int KV_tile_size = (Bc * (kHeadDim + kPad));
+ const int smem_max_size = ((Br * (kHeadDim + kPad)) +
+ (kStage * KV_tile_size)) * sizeof(half);
+
+ const int QKV_batch = Q.size(0);
+ const int QKV_head = Q.size(1);
+ const int QKV_seqlen = Q.size(2); // QKV_seqlen
+ assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)
+
+ // TODO: How to apply block swizzle to improve L2 Cache hit rate?
+ // NOTE: reordering (B,H,Tr) -> (Tr,B*H) seems to improve L2 Cache hit rate.
+ // This might be because SM schedules blocks starting from the x-dimension.
+ // Placing Tr at the forefront ensures that identical KV pairs are placed
+ // in consecutive scheduling queues, thereby improving L2 Cache hit rates.
+ // Tr(=N/Br), batch_size x num_heads
+ dim3 grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head);
+ dim3 block(kNumThreads); // 8 warps per block
+
+ cudaFuncSetAttribute(
+ flash_attn_mma_stages_split_q_shared_kv_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPad
+ >,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ // kMaxSramPerBlock
+ 98304
+ );
+
+ flash_attn_mma_stages_split_q_shared_kv_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPad
+ ><<<grid, block, smem_max_size>>>(
+ reinterpret_cast<half*>(Q.data_ptr()),
+ reinterpret_cast<half*>(K.data_ptr()),
+ reinterpret_cast<half*>(V.data_ptr()),
+ reinterpret_cast<half*>(O.data_ptr()),
+ QKV_seqlen,
+ QKV_head
+ );
+}
+
+void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages) {
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ const int d = Q.size(3); // B, H, N, d
+
+ if (stages > 1) {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not supported!");
+ break;
+ }
+ } else {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_shared_kv<32, 1>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_shared_kv<64, 1>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_shared_kv<96, 1>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_shared_kv<128, 1>(Q, K, V, O);
+ break;
+ case 256:
+ launch_flash_attn_mma_stages_split_q_shared_kv<256, 1>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not supported!");
+ break;
+ }
+ }
+}
diff --git a/kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu
similarity index 98%
rename from kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu
rename to kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu
index 16d179ba..af846429 100644
--- a/kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu
+++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu
@@ -1,2 +1,2 @@
-// TODO: Manually apply SMEM swizzling instead of padding in
-// Split-Q kernels to reduce bank conflicts.
+// TODO: Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu
new file mode 100644
index 00000000..af846429
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu
@@ -0,0 +1,2 @@
+// TODO: Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu
new file mode 100644
index 00000000..af846429
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu
@@ -0,0 +1,2 @@
+// TODO: Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu
new file mode 100644
index 00000000..d73fc339
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu
@@ -0,0 +1,1085 @@
+#include "utils.h"
+
+// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instructions.
+// The inputs Q,K,V are 4D tensors with shape [batch_size, num_heads, seq_len, head_dim].
+// The output O is a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].
+
+// The FlashAttention-2 algorithm is described in the following paper:
+// https://arxiv.org/pdf/2307.08691
+
+// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
+// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d]
+
+// Split Q across MMAs (warps) while all MMAs (warps) access the full KV,
+// in order to reduce the communication between warps via smem and warp shuffle.
+
+// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+// | 64x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+// | 128x128 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps
+// | 128x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x8) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x8) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |
+
+// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
+// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
+// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
+// extend D (head dimension) up to 1024. Performance optimizations are ongoing.
+// Stay tuned for updates ~
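+// Illustrative arithmetic (assuming Br=Bc=64, kMmaAtomK=16, half = 2 bytes):
+// the Q and K tiles take 64*16*2 B = 2 KB each per stage, while the V tile takes
+// 16*(d + pad)*2 B, i.e. roughly 32 KB for d = 1024, still well below the 96 KB
+// of dynamic SMEM that the launcher below requests via cudaFuncSetAttribute.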
+
+// Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
+
+// i: row index; j: col index.
+// e.g. kColStride = 64, kStep = 8 -> load 8 halves per 128-bit memory transaction.
+template<const int kColStride, const int kStep>
+static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) {
+ // -------------------------------------------
+ // --------------swizzle layout---------------
+ // -------------col 0~64, step 8--------------
+ // -------------------------------------------
+ // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // -------------------------------------------
+ // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // -------------------------------------------
+ // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // -------------------------------------------
+ // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // -------------------------------------------
+ // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
+ static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
+ static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
+ if constexpr (kStep == 8) {
+ return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
+ } else {
+ static_assert(kStep == 4);
+ return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
+ }
+}
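+
+// A minimal host-side reference of the same mapping (illustrative only; the name
+// swizzle_permuted_j_ref and the checks below are not used by the kernel). It
+// mirrors the formula above so a few (row, col) pairs can be verified at compile time.
+constexpr int swizzle_permuted_j_ref(int i, int j, int col_stride = 64, int step = 8) {
+  return (((j / step) ^ (i / 4)) % (col_stride / step)) * step;
+}
+static_assert(swizzle_permuted_j_ref(0, 8)  == 8, "row 0 keeps the identity mapping");
+static_assert(swizzle_permuted_j_ref(4, 0)  == 8, "rows 4~7 swap adjacent 8-half chunks");
+static_assert(swizzle_permuted_j_ref(8, 16) == 0, "rows 8~11 XOR the chunk index with 2");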
+
+// i: row index; j: col index
+// e.g. kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 halves per 128-bit memory transaction.
+template<const int kMmaAtomK = 16, const int kStep = 8>
+static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) {
+ // -------------------
+ // --swizzle layout---
+ // -col 0~16, step 8--
+ // -------------------
+ // | row 0 | (0, 8) |
+ // | row 1 | (0, 8) |
+ // | row 2 | (0, 8) |
+ // | row 3 | (0, 8) |
+ // -------------------
+ // | row 4 | (8, 0) |
+ // | row 5 | (8, 0) |
+ // | row 6 | (8, 0) |
+ // | row 7 | (8, 0) |
+ // -------------------
+ // | row 8 | (0, 8) |
+ // | row 9 | (0, 8) |
+ // | row 10 | (0, 8) |
+ // | row 11 | (0, 8) |
+ // -------------------
+ // | row 12 | (8, 0) |
+ // | row 13 | (8, 0) |
+ // | row 14 | (8, 0) |
+ // | row 15 | (8, 0) |
+ // -------------------
+ return swizzle_permuted_j<kMmaAtomK, kStep>(i, j);
+}
+
+// i: row index; j: col index
+// e.g. kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 halves per 128-bit memory transaction.
+template<const int kMmaAtomK = 16, const int kStep = 8>
+static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) {
+ // -------------------
+ // --swizzle layout---
+ // -col 0~16, step 8--
+ // -------------------
+ // | row 0 | (0, 8) |
+ // | row 1 | (0, 8) |
+ // | row 2 | (0, 8) |
+ // | row 3 | (0, 8) |
+ // -------------------
+ // | row 4 | (8, 0) |
+ // | row 5 | (8, 0) |
+ // | row 6 | (8, 0) |
+ // | row 7 | (8, 0) |
+ // -------------------
+ // | row 8 | (0, 8) |
+ // | row 9 | (0, 8) |
+ // | row 10 | (0, 8) |
+ // | row 11 | (0, 8) |
+ // -------------------
+ // | row 12 | (8, 0) |
+ // | row 13 | (8, 0) |
+ // | row 14 | (8, 0) |
+ // | row 15 | (8, 0) |
+ // -------------------
+ return swizzle_permuted_j<kMmaAtomK, kStep>(i, j);
+}
+
+// i: row index; j: col index.
+// e.g. kColStride = kHeadDim = 64, kStep = 8 -> load 8 halves per 128-bit memory transaction.
+template<const int kHeadDim = 64, const int kStep = 8>
+static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) {
+ // -------------------------------------------
+ // --------------swizzle layout---------------
+ // -------------col 0~64, step 8--------------
+ // -------------------------------------------
+ // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // -------------------------------------------
+ // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // -------------------------------------------
+ // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // -------------------------------------------
+ // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // -------------------------------------------
+ return swizzle_permuted_j<kHeadDim, kStep>(i, j);
+}
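+
+// NOTE: in this kernel the V tile is kept in a padded (kHeadDim + kPadV) smem
+// layout rather than a swizzled one (see the V g2s/s2r paths below), so
+// swizzle_permuted_V_j is currently provided for reference only.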
+
+template<
+ const int kHeadDim, // Headdim, 32,64,128
+ const int kMmaAtomM, // MMA Atom M, 16
+ const int kMmaAtomN, // MMA Atom N, 8
+ const int kMmaAtomK, // MMA Atom K, 16
+ const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)]
+ const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ]
+ const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N
+ const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M
+ const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
+ const int kStage,
+ const int kPadQ,
+ const int kPadK,
+ const int kPadV
+ >
+__global__ void __launch_bounds__(
+ WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK)
+flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel(half* Q,
+ half* K,
+ half* V,
+ half* O,
+ int QKV_seqlen,
+ int QKV_head) {
+ // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN.
+ // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
+ static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16
+ static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T
+ static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V
+ static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T
+ // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc.
+ // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128.
+ static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == (
+ kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V
+ static_assert(kStage < 3 && kStage > 0);
+ static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16
+ static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16
+ static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64
+ static_assert(Br >= Bc); // for shared memory reuse.
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads
+ // Now, N must be a multiple of Bc (32/64) for KV tiling across seqlen.
+ const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d]
+ const float scale = 1.0f / sqrt((float) kHeadDim);
+
+ // grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head), (x,y,z)
+ const int QKV_batch_id = blockIdx.y / QKV_head; // Batch size
+ const int QKV_head_id = blockIdx.y % QKV_head; // Head num
+ const int Q_tile_id = blockIdx.x; // Q tile_id, range [0, Tr]
+ const int O_tile_id = Q_tile_id; // O tile_id, same as Q.
+ const int tid = threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_QP = warp_id; // 0,1,2,3 or 0~7
+ const int warp_KV = 0; // 0
+ // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+ // | 128x128 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+ // | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+ // | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+ // | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+ // | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+ const int Q_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d]
+ const int K_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) +
+ (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d]
+ const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d]
+ const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d]
+
+ // Mapping Q gmem -> tid -> smem, Q[Br,kMmaAtomK]=[64/128,16], 128/256 threads.
+ int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64
+ int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kMmaAtomK / (kNumThreads / Br)); // (tid % 2) * 8, 0,8,...
+ // Mapping K gmem -> tid -> smem, K[Bc,kMmaAtomK]=[64/128,16], 128 threads.
+ int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64
+ int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kMmaAtomK / (kNumThreads / Bc)); // (tid % 2) * 8, 0,8,...
+ // TODO: Mapping V gmem -> tid -> smem, V[kMmaAtomK,kMmaAtomN]=[16,64/128], 128 threads.
+ // Mapping V gmem -> tid -> smem, V[kMmaAtomK,d]=[16,64/128], 128 threads.
+ int load_smem_V_Bc = (tid / (kNumThreads / kMmaAtomK)); // kMmaAtomK 16, tid / 8, row 0~15
+ int load_smem_V_d = (tid % (kNumThreads / kMmaAtomK)) * (kHeadDim / (kNumThreads / kMmaAtomK)); // (tid % 8) * 8, 0,8,56...
+ // global Q row of current head for tile [Br,d] per block.
+ int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br;
+ if (load_gmem_Q_Br >= QKV_seqlen) return;
+ constexpr bool kIsVCanLoadIn128b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0;
+ constexpr bool kIsVCanLoadIn64b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0;
+ static_assert(kIsVCanLoadIn128b || kIsVCanLoadIn64b, "V can't load in 128b or 64b."); // 32,64,128,192,256,...
+
+ // Shared memory for Q,K,V; we do not need additional smem for O,
+ // whose collective store is performed via register reuse and warp shuffle.
+ extern __shared__ half smem[];
+ // Split Q + shared KV SMEM + fine-grained tiling: SRAM usage is constant w.r.t. seqlen.
+ constexpr int Q_tile_size = Br * (kMmaAtomK + kPadQ); // Q[Br,16], 64*16*2=2048 bytes, 2KB
+ constexpr int K_tile_size = Bc * (kMmaAtomK + kPadK); // K[Bc,16], 2KB
+ constexpr int V_tile_size = kMmaAtomK * (kHeadDim + kPadV); // V[16,d], 2KB (d=64)
+ // TODO: optimize QKV kStage smem store layout as in HGEMM.
+ half* Q_tile_smem = smem; // 8KB/16KB
+ half* K_tile_smem = Q_tile_smem + kStage * Q_tile_size; // 8KB/16KB
+ half* V_tile_smem = Q_tile_smem; // V may reuse all Q+K smem after Q@K^T.
+ // stage 1, Q/K smem = 64*16*2/1024=2KB, V smem = 16*d(64|128|...)*2/1024 = 2KB/4KB/...
+ // stage 1, total smem = max(QK_smem, V_smem) = 4KB if d <= 64 else V_smem.
+ // stage 1, V shares QK smem, Br=Bc=64, d=64: 2KB+(2KB) = 4KB, +Pad(2KB) = 6KB
+ // stage 1, V shares QK smem, Br=Bc=128, d=64: 4KB+4KB = 8KB, +Pad(2KB) = 10KB
+ // stage 2, V shares QK smem, Br=Bc=64, d=64: 4KB+(4KB) = 8KB, +Pad(2KB) = 10KB
+ // stage 2, V shares QK smem, Br=Bc=128, d=64: 8KB+8KB = 16KB, +Pad(2KB) = 18KB
+ uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem);
+ uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem);
+ uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem);
+
+ // --------------------- Registers/SMEM for thread block -------------------------
+ // block m_old, l_old, store in lane, use float to keep precision.
+ float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_block_row_max_old, -INFINITY);
+ fill_2D_regs(lane_block_row_sum_old, 0.0f);
+
+ // ---------------------- Registers for S=Q@K^T/O=P@V ----------------------------
+ // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d].
+ uint32_t R_Q[kWarpTileSeqLenQ][ 4]; // [1][4]
+ uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2]
+ uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2]
+ // NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV will increase with d.
+ // So, for large d, R_V needs more registers and can degrade performance.
+ // We have to find a way to apply MMA level tiling for V(R_V) for large d.
+ // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc]
+ // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs.
+ uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2]
+ // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs.
+ uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs.
+ uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2]
+ fill_3D_regs(R_S, 0);
+ fill_3D_regs(R_D, 0);
+ fill_3D_regs(R_O, 0);
+
+ // Loop over K seqlen: for K^T[d,seqlen] with K^T_tile[d,Bc]
+ // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc]
+ #pragma unroll 1
+ for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) {
+ // TODO: process last tile_K_seqlen ? pad to multiple of 8.
+
+ // Q/K g2s
+ if constexpr (kStage > 1) {
+ #pragma unroll
+ for (int stage = 0; stage < (kStage - 1); ++stage) {
+ // Q g2s
+ int load_gmem_Q_d = (stage * kMmaAtomK) + load_smem_Q_d; // 0,8
+ int load_gmem_Q_addr = (
+ Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d);
+ uint32_t load_smem_Q_ptr = (
+ smem_Q_base_ptr + (stage * Q_tile_size +
+ load_smem_Q_Br * (kMmaAtomK + kPadQ) +
+ swizzle_permuted_Q_j(
+ load_smem_Q_Br, load_smem_Q_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) {
+ CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+
+ // K g2s
+ int load_gmem_K_Bc = (tile_K_seqlen * Bc) + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_d = (stage * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (stage * K_tile_size +
+ load_smem_K_Bc * (kMmaAtomK + kPadK) +
+ swizzle_permuted_K_j(
+ load_smem_K_Bc, load_smem_K_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ } // end for stage
+
+ CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2
+ __syncthreads();
+ } // end if kStage > 1
+
+ // Loop over K d: tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc]
+ // Matmul with NT layout, Q row major, K^T col major.
+ // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major.
+ // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d]
+ //
+ fill_3D_regs(R_S, 0);
+ #pragma unroll
+ for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) {
+ // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
+ int smem_sel = (tile_K_d) % kStage;
+ // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
+ int smem_sel_next = (tile_K_d + (kStage - 1)) % kStage;
+
+ // stages for Q, K
+ if constexpr (kStage > 1) {
+ if ((tile_K_d + 1) < (kHeadDim / kMmaAtomK)) {
+ // next Q tile g2s
+ int load_gmem_Q_d = ((tile_K_d + 1) * kMmaAtomK) + load_smem_Q_d;
+ int load_gmem_Q_addr = (
+ Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d);
+ uint32_t load_smem_Q_ptr = (
+ smem_Q_base_ptr + (smem_sel_next * Q_tile_size +
+ load_smem_Q_Br * (kMmaAtomK + kPadQ) +
+ swizzle_permuted_Q_j(
+ load_smem_Q_Br, load_smem_Q_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) {
+ CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+
+ // next K tile g2s
+ int load_gmem_K_Bc = tile_K_seqlen * Bc + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_d = ((tile_K_d + 1) * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel_next * K_tile_size +
+ load_smem_K_Bc * (kMmaAtomK + kPadK) +
+ swizzle_permuted_K_j(
+ load_smem_K_Bc, load_smem_K_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+ } else {
+ // sync load curr Q, K g2s
+ // curr Q tile g2s
+ int load_gmem_Q_d = (tile_K_d * kMmaAtomK) + load_smem_Q_d;
+ int load_gmem_Q_addr = (
+ Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d);
+ uint32_t load_smem_Q_ptr = (
+ smem_Q_base_ptr + (smem_sel * Q_tile_size +
+ load_smem_Q_Br * (kMmaAtomK + kPadQ) +
+ swizzle_permuted_Q_j(
+ load_smem_Q_Br, load_smem_Q_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) {
+ CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+
+ // curr K tile g2s
+ int load_gmem_K_Bc = (tile_K_seqlen * Bc) + load_smem_K_Bc; // < seqlen
+ int load_gmem_K_d = (tile_K_d * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d]
+ int load_gmem_K_addr = (
+ K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d);
+ uint32_t load_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * K_tile_size +
+ load_smem_K_Bc * (kMmaAtomK + kPadK) +
+ swizzle_permuted_K_j(
+ load_smem_K_Bc, load_smem_K_d)) * sizeof(half)
+ );
+ #pragma unroll
+ for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) {
+ CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16);
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ // Wait curr Q, K tile ready.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ } // end if kStage > 1
+
+ // Q s2r
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K]
+ int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM;
+ int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15
+ int lane_smem_Q_d = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_Q_ptr = (
+ smem_Q_base_ptr + (smem_sel * Q_tile_size +
+ lane_smem_Q_Br * (kMmaAtomK + kPadQ) +
+ swizzle_permuted_Q_j(
+ lane_smem_Q_Br, lane_smem_Q_d)) * sizeof(half)
+ );
+ LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3],
+ lane_smem_Q_ptr); // now, R_Q[1][4]
+ }
+
+ // smem -> reg, load k16n8 from smem K, offset d according tile_K_d.
+ // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N]
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d].
+ // K[Bc,d] with row major means K^T[d,Bc] in col major.
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN;
+ int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7
+ int lane_smem_K_d = ((lane_id / 8) % 2) * 8; // 0,8
+ uint32_t lane_smem_K_ptr = (
+ smem_K_base_ptr + (smem_sel * K_tile_size +
+ lane_smem_K_Bc * (kMmaAtomK + kPadK) +
+ swizzle_permuted_K_j(
+ lane_smem_K_Bc, lane_smem_K_d)) * sizeof(half)
+ );
+ LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K
+ } // end for kWarpTileSeqLenK
+ if constexpr (kStage < 2) {
+ // Wait Q, K s2r ready if kStage < 2 in order to avoid
+ // the next Q, K tile g2s overwrite.
+ __syncthreads();
+ }
+
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ HMMA16816(R_S[i][j][0], R_S[i][j][1],
+ R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3],
+ R_K[j][0], R_K[j][1],
+ R_S[i][j][0], R_S[i][j][1]);
+ }
+ }
+
+ if constexpr (kStage > 1) {
+ // Wait next Q, K tile g2s ready.
+ CP_ASYNC_WAIT_GROUP(kStage - 2);
+ __syncthreads();
+ }
+
+ } // end loop over d, S=Q@K^T
+ __syncthreads();
+
+ // V g2s stages. (reuse Q+K smem) load [16,d] from [Bc,d]
+ if constexpr (kStage > 1) {
+ #pragma unroll
+ for (int stage = 0; stage < (kStage - 1); ++stage) {
+ // V g2s
+ int load_gmem_V_Bc = (
+ (tile_K_seqlen * Bc) + (stage * kMmaAtomK) + load_smem_V_Bc); // 0~15
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (stage * V_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPadV) +
+ load_smem_V_d) * sizeof(half)
+ );
+ // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld.
+ if constexpr (kIsVCanLoadIn128b) {
+ // 64,128,192,256,...
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ } else {
+ // 32,96,160,224
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) {
+ CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8);
+ }
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ } // end for stage
+ }
+
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+ // Online safe softmax, warp/block reduce max/sum, row wise
+ float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2]
+ float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2]
+ fill_2D_regs(lane_row_max_new, -INFINITY);
+ fill_2D_regs(lane_row_sum_new, 0.0f);
+
+ // Row max for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for C. (m16n8k16)
+ // Row\Col 0 1 2 3 4 5 6 7
+ // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1}
+ // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1}
+ // 2 ...
+ // ...
+ // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1}
+ // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3}
+ // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3}
+ // 10 ...
+ // ...
+ // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3}
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // This should be the row max after S = (Q @ K^T) / sqrt(d)
+ float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale;
+ float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale;
+ lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0);
+ lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce max, warp_size = 4
+ // Each thread contains the maximum of 2 rows of Br,
+ // and only the values of T0, T4, ..., T28 are used.
+ lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]);
+ lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+ __syncthreads();
+
+ // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenQ; ++i) {
+ // Use latest global row max without update.
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55;
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ // Br 1, row_id, 8~15, 24~31, 40~47, 56~63;
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // Apply m_new = max(m_old, m_new) here.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+
+ #pragma unroll
+ for (int j = 0; j < kWarpTileSeqLenK; ++j) {
+ float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3}
+ // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z;
+ t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0));
+ t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0));
+ t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1));
+ t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1));
+ lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y);
+ lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y);
+ // Update R_S for P[Br,Bc] = Exp(S-m), point wise.
+ HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0);
+ HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1);
+ } // end for kWarpTileSeqLenK
+
+ // Warp level reduce sum, warp_size = 4
+ lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]);
+ lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]);
+ } // end for kWarpTileSeqLenQ
+ __syncthreads();
+
+ // Loop over V Bc: P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partitioned attention.
+ // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major.
+ // Make sure to clear the states in R_O before MMA for P@V for each step.
+
+ // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these
+ // registers for P(A) matrix directly ? How to do that ?
+ // according to the A matrix layout for MMA m16n8k16 instruction.
+ // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // The layout of the fragments held by different threads for A matrix with .f16.
+ // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
+ // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5}
+ // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5}
+ // 2 (dashed arrow pointing right)
+ // ...
+ // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5}
+ // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7}
+ // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7}
+ // 10 (dashed arrow pointing right)
+ // ...
+ // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7}
+
+ // Wait V g2s stages ready.
+ if constexpr (kStage > 1) {
+ CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2
+ __syncthreads();
+ }
+
+ // P@V=[Br,Bc]@[Bc,d]
+ fill_3D_regs(R_O, 0);
+ #pragma unroll
+ for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) {
+ // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0;
+ int smem_sel = (tile_V_Bc) % kStage;
+ // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2;
+ int smem_sel_next = (tile_V_Bc + (kStage - 1)) % kStage;
+
+ // stages for V
+ if constexpr (kStage > 1) {
+ if ((tile_V_Bc + 1) < (Bc / kMmaAtomK)) {
+ // next V tile g2s
+ int load_gmem_V_Bc = (
+ (tile_K_seqlen * Bc) + (tile_V_Bc + 1) * kMmaAtomK + load_smem_V_Bc); // 0~15
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (smem_sel_next * V_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPadV) +
+ load_smem_V_d) * sizeof(half)
+ );
+ // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld.
+ if constexpr (kIsVCanLoadIn128b) {
+ // 64,128,192,256,...
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ } else {
+ // 32,96,160,224
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) {
+ CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8);
+ }
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ }
+ } else {
+ // sync load curr V g2s
+ int load_gmem_V_Bc = (
+ (tile_K_seqlen * Bc) + (tile_V_Bc * kMmaAtomK) + load_smem_V_Bc); // 0~15
+ int load_gmem_V_d = load_smem_V_d;
+ int load_gmem_V_addr = (
+ V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d);
+ uint32_t load_smem_V_ptr = (
+ smem_V_base_ptr + (smem_sel * V_tile_size +
+ load_smem_V_Bc * (kHeadDim + kPadV) +
+ load_smem_V_d) * sizeof(half)
+ );
+ // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld.
+ if constexpr (kIsVCanLoadIn128b) {
+ // 64,128,192,256,...
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) {
+ CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16);
+ }
+ } else {
+ // 32,96,160,224
+ #pragma unroll
+ for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) {
+ CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8);
+ }
+ }
+ CP_ASYNC_COMMIT_GROUP();
+ // Wait curr V tile ready.
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+
+ // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans.
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) {
+ int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N
+ int lane_smem_V_Bc = lane_id % 16; // 0~15; Bc, matmul K
+ int lane_smem_V_d = warp_smem_V_d; // 0
+ uint32_t lane_smem_V_ptr = (
+ smem_V_base_ptr + (smem_sel * V_tile_size +
+ lane_smem_V_Bc * (kHeadDim + kPadV) +
+ lane_smem_V_d) * sizeof(half)
+ );
+ LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V
+ }
+ if constexpr (kStage < 2) {
+ // Wait V s2r ready if kStage < 2 in order to avoid
+ // the next V tile g2s overwrite.
+ __syncthreads();
+ }
+
+ // For R_S[1][8][2], mapping the layout below of P matrix.
+ // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+ // | 64x64 | warp_KV 0 |
+ // | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+ // | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+ // | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+ // | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+ // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs.
+ // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs.
+ // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs.
+ // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs.
+ int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ HMMA16816(R_O[i][j][0], R_O[i][j][1],
+ R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1],
+ R_V[j][0], R_V[j][1],
+ R_O[i][j][0], R_O[i][j][1]);
+ }
+ }
+
+ if constexpr (kStage > 1) {
+ // Wait next V tile g2s ready.
+ CP_ASYNC_WAIT_GROUP(kStage - 2);
+ __syncthreads();
+ }
+
+ } // end for V Bc.
+ __syncthreads();
+
+ // Rescale O -> Update row sum Exp -> then, Update row max.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1
+ // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper)
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63
+ float block_row_max_new_0 = lane_row_max_new[i][0];
+ float block_row_max_new_1 = lane_row_max_new[i][1];
+ float block_row_sum_new_0 = lane_row_sum_new[i][0];
+ float block_row_sum_new_1 = lane_row_sum_new[i][1];
+
+ float block_row_max_old_0 = lane_block_row_max_old[i][0];
+ float block_row_max_old_1 = lane_block_row_max_old[i][1];
+ // NOTE: max(-inf, val) = val.
+ block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0);
+ block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1);
+ // Avoid inf value while using m_old for rescaling O.
+ block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 :
+ block_row_max_new_0);
+ block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 :
+ block_row_max_new_1);
+
+ // rescale factor for O and l, exp(m_old - m)
+ float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0);
+ float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1);
+ // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old.
+ // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V
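+ // e.g. (illustrative numbers) if m_old = 1.0 and the new tile max gives m_new = 3.0,
+ // then m = 3.0 and the old accumulator is scaled by exp(1.0 - 3.0) ~= 0.135 before
+ // the current P@V partial result is accumulated on top.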
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3}
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+ // Note that the formula in the FA2 paper is incorrect; here,
+ // the inverse of the exp function should not be taken, as it
+ // would result in an error during rescaling; namely, you have
+ // to use exp(m_old - m_new), not 1/exp(m_old - m_new).
+ // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V
+ t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x);
+ t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y);
+ t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x);
+ t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y);
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ } // end for kWarpTileHeadDimV.
+
+ // Now, we can update m, l after O has been scaled.
+ // 1. First, update block row sum Exp for each lane which
+ // need both m_new and m_old.
+ float block_row_sum_old_0 = lane_block_row_sum_old[i][0];
+ float block_row_sum_old_1 = lane_block_row_sum_old[i][1];
+ // Update l = exp(m_old - m_new) * l_old + row_sum(P).
+ lane_block_row_sum_old[i][0] = (__fmaf_rn(
+ rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0));
+ lane_block_row_sum_old[i][1] = (__fmaf_rn(
+ rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1));
+ // 2. Then, update block row max for each lane.
+ lane_block_row_max_old[i][0] = block_row_max_new_0;
+ lane_block_row_max_old[i][1] = block_row_max_new_1;
+ }
+ } // end loop over N
+ __syncthreads();
+
+ // Finally, we still have to rescale O once more.
+ // O_output(D) = ( 1/l_final ) * O_final (FA2 paper)
+ // NOTE: Here, we choose to reuse R_O as final output
+ // in order to reduce regs usage.
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]);
+ float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]);
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ...
+ float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1}
+ float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3}
+ t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x;
+ t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y;
+ t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x;
+ t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y;
+ HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0);
+ HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1);
+ }
+ }
+
+ // Store O(D): Write O[Br,d] from regs -> gmem, collective store
+ // with reg reuse & warp shuffle. need R_Z[2][4].
+ // TODO: reuse Q smem for collective store: regs -> smem -> gmem
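+ // Roughly (illustrative description): each thread of an m16n8 accumulator owns 2
+ // halfs per fragment; the width-4 __shfl_sync calls below let lanes 0,4,...,28
+ // gather the fragments of their 3 neighbors into R_Z[0][0..3]/R_Z[1][0..3], so each
+ // of those lanes issues two 128-bit (8 x half) stores covering 8 contiguous d columns.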
+ #pragma unroll
+ for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1
+ #pragma unroll
+ for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8
+ // we have to use new R_Z regs for collective store.
+ uint32_t R_Z[2][4];
+ R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4
+ R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4);
+ R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4);
+ R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4);
+ R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4);
+ R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4);
+ R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4);
+ // st.global.v4 128 bits. [Br,d]
+ if (lane_id % 4 == 0) {
+ // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56
+ int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM;
+ int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7
+ // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56)
+ int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN;
+ int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8)
+ int store_gmem_O_addr_0 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d);
+ int store_gmem_O_addr_1 = (
+ O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d);
+ LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]);
+ LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]);
+ }
+ } // end for kWarpTileHeadDimV
+ } // end for kWarpTileSeqLenQ
+}
+
+// Launch kernel for flash_attn_mma_stages_split_q_tiling_qk_swizzle
+template<const int kHeadDim, const int kStage>
+void launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle(
+ torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) {
+ // Now: fixed tile BrxBc=128x128 for d>= 128, 64x64 for d<128.
+ // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size.
+ constexpr int kMmaAtomM = 16;
+ constexpr int kMmaAtomN = 8;
+ constexpr int kMmaAtomK = 16;
+ constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8;
+ constexpr int kMmaTileSeqLenK = 1;
+ constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 4 : 8;
+ constexpr int kMmaTileHeadDimV = 1;
+ constexpr int kWarpTileSeqLenQ = 1;
+ constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16;
+ constexpr int kWarpTileSeqLenP = 1;
+ constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // (d=64)8,(d=128)16,32,....
+ constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 (d<128), 16*8*1=128 (d>=128)
+ constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 (d<128), 8*1*16=128 (d>=128)
+ constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128 (d<128), 32*8*1=256 (d>=128)
+ constexpr int kPadQ = 0;
+ constexpr int kPadK = 0;
+ constexpr int kPadV = 8;
+
+ // static int kMaxSramPerBlock;
+ // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0);
+ // Calculate SRAM size needed per block: Q,K,V smem sizes; V shares the QK smem.
+ constexpr int QK_smem_size = (kStage * (Br * (kMmaAtomK + kPadQ)) +
+ kStage * (Bc * (kMmaAtomK + kPadK)));
+ // Now, for V_smem_size with kStage=2: d=64 -> 4KB, 16 regs; d=128 -> 8KB, 32 regs;
+ // d=256 -> 16KB, 64 regs; d=512 -> 32KB, 128 regs; d=1024 -> 64KB, 256 regs;
+ // TODO: sub-tiling for d while perform P@V, kMmaAtomK * (kMmaAtomN)
+ constexpr int V_smem_size = (kStage * (kMmaAtomK * (kHeadDim + kPadV)));
+ // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage.
+ const int smem_max_size = max(QK_smem_size, V_smem_size) * sizeof(half);
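+ // e.g. (illustrative) d=64, kStage=2, kPadQ=kPadK=0, kPadV=8:
+ // QK_smem_size = 2*(64*16) + 2*(64*16) = 4096 halfs (8 KB),
+ // V_smem_size  = 2*(16*(64+8)) = 2304 halfs (4.5 KB),
+ // so smem_max_size = 8192 bytes, well under the 96 KB requested below.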
+
+ const int QKV_batch = Q.size(0);
+ const int QKV_head = Q.size(1);
+ const int QKV_seqlen = Q.size(2); // QKV_seqlen
+ assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc)
+
+ // TODO: How to apply block swizzle to improve L2 Cache hit rate?
+ // NOTE: reordering (B,H,Tr) -> (Tr,B*H) seems to improve L2 Cache hit rate.
+ // This might be because SM schedules blocks starting from the x-dimension.
+ // Placing Tr at the forefront ensures that identical KV pairs are placed
+ // in consecutive scheduling queues, thereby improving L2 Cache hit rates.
+ // Tr(=N/Br), batch_size x num_heads
+ dim3 grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head);
+ dim3 block(kNumThreads); // 4/8 warps per block
+ // when N >= 6016, stage 1 will have precision gap, why?
+
+ cudaFuncSetAttribute(
+ flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPadQ,
+ kPadK,
+ kPadV
+ >,
+ cudaFuncAttributeMaxDynamicSharedMemorySize,
+ // kMaxSramPerBlock
+ 98304
+ );
+
+ flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel<
+ kHeadDim,
+ kMmaAtomM,
+ kMmaAtomN,
+ kMmaAtomK,
+ kMmaTileSeqLenQ,
+ kMmaTileSeqLenK,
+ kMmaTileSeqLenP,
+ kMmaTileHeadDimV,
+ kWarpTileSeqLenQ,
+ kWarpTileSeqLenK,
+ kWarpTileSeqLenP,
+ kWarpTileHeadDimV,
+ kStage,
+ kPadQ,
+ kPadK,
+ kPadV
+ ><<<grid, block, smem_max_size>>>(
+ reinterpret_cast<half*>(Q.data_ptr()),
+ reinterpret_cast<half*>(K.data_ptr()),
+ reinterpret_cast<half*>(V.data_ptr()),
+ reinterpret_cast<half*>(O.data_ptr()),
+ QKV_seqlen,
+ QKV_head
+ );
+}
+
+void flash_attn_mma_stages_split_q_tiling_qk_swizzle(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages) {
+ CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D]
+ CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D]
+ const int d = Q.size(3); // B, H, N, d
+
+ if (stages > 1) {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<32, 2>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<64, 2>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<96, 2>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<128, 2>(Q, K, V, O);
+ break;
+ case 256:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<256, 2>(Q, K, V, O);
+ break;
+ case 512:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<512, 2>(Q, K, V, O);
+ break;
+ case 1024:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<1024, 2>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not supported!");
+ break;
+ }
+ } else {
+ switch (d)
+ {
+ case 32:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<32, 1>(Q, K, V, O);
+ break;
+ case 64:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<64, 1>(Q, K, V, O);
+ break;
+ case 96:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<96, 1>(Q, K, V, O);
+ break;
+ case 128:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<128, 1>(Q, K, V, O);
+ break;
+ case 256:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<256, 1>(Q, K, V, O);
+ break;
+ case 512:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<512, 1>(Q, K, V, O);
+ break;
+ case 1024:
+ launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<1024, 1>(Q, K, V, O);
+ break;
+ default:
+ throw std::runtime_error("headdim not supported!");
+ break;
+ }
+ }
+}
diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu
new file mode 100644
index 00000000..af846429
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu
@@ -0,0 +1,2 @@
+// TODO: Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
diff --git a/kernels/flash-attn/pybind/flash_attn.cc b/kernels/flash-attn/pybind/flash_attn.cc
index 8ba636c8..0faa796d 100644
--- a/kernels/flash-attn/pybind/flash_attn.cc
+++ b/kernels/flash-attn/pybind/flash_attn.cc
@@ -35,10 +35,18 @@ void flash_attn_mma_stages_split_q_tiling_qk(torch::Tensor Q,
torch::Tensor O,
int stages);
+void flash_attn_mma_stages_split_q_tiling_qk_swizzle(torch::Tensor Q,
+ torch::Tensor K,
+ torch::Tensor V,
+ torch::Tensor O,
+ int stages);
+
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_kv)
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q)
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_kv)
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_qkv)
TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_tiling_qk)
+ TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_tiling_qk_swizzle)
}
diff --git a/kernels/flash-attn/tools/print_swizzle_layout.py b/kernels/flash-attn/tools/print_swizzle_layout.py
new file mode 100644
index 00000000..dcb9d48e
--- /dev/null
+++ b/kernels/flash-attn/tools/print_swizzle_layout.py
@@ -0,0 +1,46 @@
+import argparse
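+
+# Example usage (illustrative; assumes the file is run directly):
+#   python3 print_swizzle_layout.py --col 64 --step 8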
+
+
+def pretty_print_line(m: str = "", sep: str = "-", width: int = 130):
+ res_len = width - len(m)
+ left_len = int(res_len / 2)
+ right_len = res_len - left_len
+ pretty_line = sep * left_len + m + sep * right_len
+ print(pretty_line)
+
+
+def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8):
+ # i: row index; j: col index.
+ return ((int(j / step) ^ int(i / 4)) % int(col_stride / step)) * step
+
+
+def print_swizzle_layout(rows: int = 16, col_stride: int = 64, step: int = 8):
+ str_len = 0
+ for i in range(rows):
+ layout = tuple(swizzle_permuted_j(i, j, col_stride, step)
+ for j in range(0, col_stride, step))
+ layout_str = (f"| row {i:<2} | {layout} |")
+ str_len = len(layout_str)
+ if (i == 0):
+ print("-" * str_len)
+ pretty_print_line(f"swizzle layout", width=str_len)
+ pretty_print_line(f"col 0~{col_stride}, step {step}", width=str_len)
+ print("-" * str_len)
+ print(layout_str)
+ if ((i + 1) % 4 == 0 and i != (rows - 1)):
+ print("-" * str_len)
+ print("-" * str_len)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--col_stride", "--col", type=int, default=64)
+ parser.add_argument("--step", type=int, default=8)
+ parser.add_argument("--rows", type=int, default=16)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ print_swizzle_layout(args.rows, args.col_stride, args.step)
+
diff --git a/kernels/hgemm/README.md b/kernels/hgemm/README.md
index 6d66037c..7ccac9f1 100755
--- a/kernels/hgemm/README.md
+++ b/kernels/hgemm/README.md
@@ -1,5 +1,5 @@
-## ⚡️⚡️Toy-HGEMM Library: May achieve the 98%~100% performance of cuBLAS 🎉🎉
+# ⚡️⚡️Toy-HGEMM: Achieve 98%~100% of cuBLAS TFLOPS 🎉🎉
![toy-hgemm-library](https://github.com/user-attachments/assets/962bda14-b494-4423-b8eb-775da9f5503d)
@@ -23,7 +23,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d
|✔️|✔️|✔️|✔️|
|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages(2/3/4/5)|
|✔️|✔️|✔️|✔️|
-|Register Double Buffers|Block Swizzle (Zigzag N)|Warp Swizzle (Zigzag N)| SMEM Swizzle (CUTLASS/CuTe)|
+|Register Double Buffers|Block Swizzle (Zigzag N)|Warp Swizzle (Zigzag N)| SMEM Swizzle (CuTe/MMA) |
|✔️|✔️|✔️|✔️|
|Collective Store (Warp Shuffle & Reg Reuse)|Row Major (NN)|Col Major (TN)|SGEMM FP32/TF32|
|✔️|✔️|✔️|✔️|
@@ -79,6 +79,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch:
void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
```
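+
+A minimal usage sketch for the new swizzled kernel (a hedged example: it assumes the
+extension has been built and loaded as the module object `hgemm`, as done in `hgemm.py`;
+shapes and the stride value are illustrative only):
+
+```python
+import torch
+# `hgemm` is the compiled torch extension, loaded as in hgemm.py
+M, N, K = 4096, 4096, 2048
+a = torch.randn((M, K), dtype=torch.half, device="cuda")
+b = torch.randn((K, N), dtype=torch.half, device="cuda")
+c = torch.zeros((M, N), dtype=torch.half, device="cuda")
+# args: a, b, c, stages, block swizzle on/off, block swizzle stride
+hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(a, b, c, 3, True, 1024)
+print((c.float() - a.float() @ b.float()).abs().max())  # sanity check vs. torch matmul
+```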
## 📖 Contents
diff --git a/kernels/hgemm/hgemm.py b/kernels/hgemm/hgemm.py
index 008161d0..4bde91a0 100644
--- a/kernels/hgemm/hgemm.py
+++ b/kernels/hgemm/hgemm.py
@@ -160,11 +160,11 @@ def run_benchmark(perf_func: callable,
else:
improve = 0
MAX_TFLOPS = TFLOPS
- print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
+ print(f"{out_info:>52}: {out_val}, time:{mean_time_ms}ms, "
f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)")
else:
if not only_show_improved or "cublas" in tag:
- print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, "
+ print(f"{out_info:>52}: {out_val}, time:{mean_time_ms}ms, "
f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}")
if show_matrix: print(out)
if args.plot_flops:
@@ -359,6 +359,9 @@ def get_mnk(sep: int = args.SEP):
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem)", c, stages=4)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem)", c, stages=3)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem)", c, stages=2)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2)
if args.enable_mma_all: # more mma kernel tests.
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+rr)", c, stages=4)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+rr)", c, stages=3)
@@ -373,6 +376,9 @@ def get_mnk(sep: int = args.SEP):
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True)
+ run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True)
if args.enable_mma_all:
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True)
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True)
diff --git a/kernels/hgemm/makefile b/kernels/hgemm/makefile
index 5518dfef..d0312cc3 100644
--- a/kernels/hgemm/makefile
+++ b/kernels/hgemm/makefile
@@ -7,11 +7,18 @@ default:
nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin $(DEFAULT_FLAGS)
nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin $(DEFAULT_FLAGS)
nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.bin $(DEFAULT_FLAGS)
+ nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.bin $(DEFAULT_FLAGS)
cute_89:
nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.bin $(DEFAULT_FLAGS_89)
cute_89_debug:
nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format"
mma_89:
nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89)
+mma_89_debug:
+ nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
+mma_89_swizzle:
+ nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.bin $(DEFAULT_FLAGS_89)
+mma_89_swizzle_debug:
+ nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG
clean:
rm -rf *.bin
diff --git a/kernels/hgemm/mma/hgemm_mma.cu b/kernels/hgemm/mma/hgemm_mma.cu
index 37fbd6e4..0e787483 100644
--- a/kernels/hgemm/mma/hgemm_mma.cu
+++ b/kernels/hgemm/mma/hgemm_mma.cu
@@ -313,8 +313,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
- constexpr int A_PAD = 0;
- constexpr int B_PAD = 16;
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8;
+ constexpr int B_PAD = 8;
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
diff --git a/kernels/hgemm/mma/hgemm_mma_stage.cu b/kernels/hgemm/mma/hgemm_mma_stage.cu
index 52feac84..d021ba29 100644
--- a/kernels/hgemm/mma/hgemm_mma_stage.cu
+++ b/kernels/hgemm/mma/hgemm_mma_stage.cu
@@ -1974,13 +1974,11 @@ void lanunch_hgemm_mma_m16n8k16_nn(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int WARP_TILE_K = 2;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+  // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+  // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -1993,8 +1991,16 @@ void lanunch_hgemm_mma_m16n8k16_nn(
LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4x2_DSMEM_KERNEL(K_STAGE, BLOCK_SWIZZLE_STRIDE);
}
-int main() {
+#ifdef HGEMM_MMA_DEBUG
+#include <string>
+#endif
+
+int main(int argc, char *argv[]) {
+#ifdef HGEMM_MMA_DEBUG
+ const int test_num = 1;
+#else
const int test_num = 64;
+#endif
int M_list[test_num];
int N_list[test_num];
int K_list[test_num];
@@ -2005,9 +2011,22 @@ int main() {
K_list[i] = (i + 1) * 256;
}
- const int outer_repeat = 10, inner_repeat = 1;
+#ifdef HGEMM_MMA_DEBUG
+ if (argc > 1) M_list[0] = std::stoi(argv[1]);
+ if (argc > 2) N_list[0] = std::stoi(argv[2]);
+ if (argc > 3) K_list[0] = std::stoi(argv[3]);
+#endif
+
+#ifdef HGEMM_MMA_DEBUG
+ int outer_repeat = 1, inner_repeat = 1, warmup = 1;
+ if (argc > 4) warmup = std::stoi(argv[4]);
+ if (argc > 5) inner_repeat = std::stoi(argv[5]);
+#else
+ int outer_repeat = 10, inner_repeat = 1, warmup = 1;
+#endif
printf("ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048\n");
+#ifndef HGEMM_MMA_DEBUG
for (int j = 0; j < 5; j++) {
int M = M_list[j], N = N_list[j], K = K_list[j];
float max_error = gemm_error_check_nn(
@@ -2016,6 +2035,7 @@ int main() {
printf("M N K = %6d %6d %6d, ", M, N, K);
printf("Max Error = %f\n", max_error);
}
+#endif
for (int j = 0; j < test_num; j++) {
int M = M_list[j], N = N_list[j], K = K_list[j];
@@ -2027,7 +2047,7 @@ int main() {
for (int k = 0; k < outer_repeat; k++) {
double this_sec = perf_gemm(
lanunch_hgemm_mma_m16n8k16_nn<2, 2048>,
- M, N, K, inner_repeat);
+ M, N, K, inner_repeat, warmup);
max_sec = max(max_sec, this_sec);
min_sec = min(min_sec, this_sec);
total_sec += this_sec;
@@ -2120,13 +2140,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -2250,13 +2268,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -2381,13 +2397,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int WARP_TILE_K = 2;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -2513,13 +2527,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int WARP_TILE_K = 2;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
@@ -2646,13 +2658,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
constexpr int WARP_TILE_K = 2;
- // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts.
- // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts.
- // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus,
- // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD.
- constexpr int A_PAD = 0; // 0,8,16
- constexpr int B_PAD = 16; // 0,8,16
+  // Bank-conflict free via pad = 8; don't guess, verify with ncu:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin
+ constexpr int A_PAD = 8; // 0,8,16
+ constexpr int B_PAD = 8; // 0,8,16
constexpr int NUM_THREADS= (
MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
diff --git a/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu b/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu
new file mode 100644
index 00000000..184571d3
--- /dev/null
+++ b/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu
@@ -0,0 +1,853 @@
+#include <stdio.h>
+#include <stdlib.h>
+#include <cstdint>
+#include <float.h>
+#include <vector>
+#include <algorithm>
+#include <cuda_runtime.h>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
+#include <cuda_fp8.h>
+#include <mma.h>
+using namespace nvcuda;
+
+#define WARP_SIZE 32
+#define DEVICE_INLINE __device__ inline
+#define HOST_DEVICE_INLINE __device__ __host__ inline
+#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
+#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+// gmem -> smem
+#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
+#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
+#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
+// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+// smem -> gmem: requires sm_90 or higher.
+#define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::)
+#define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::)
+#define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n))
+#define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+// ldmatrix
+#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+// stmatrix: requires sm_90 or higher.
+#define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
+#define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
+#define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
+#define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R))
+#define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1))
+#define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3))
+// mma m16n8k16
+#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
+
+HOST_DEVICE_INLINE
+int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
+
+// i: row index; j: col index.
+// e.g. kColStride = 64, kStep = 8 -> each address covers 8 halfs (128 bits) per memory issue.
+template<const int kColStride = 64, const int kStep = 8>
+static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) {
+ // -------------------------------------------
+ // --------------swizzle layout---------------
+ // -------------col 0~64, step 8--------------
+ // -------------------------------------------
+ // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) |
+ // -------------------------------------------
+ // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) |
+ // -------------------------------------------
+ // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) |
+ // -------------------------------------------
+ // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) |
+ // -------------------------------------------
+ // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep;
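+  // e.g. i=5, j=16: ((16>>3) ^ (5>>2)) % 8 * 8 = (2 ^ 1) % 8 * 8 = 24 (matches row 5 above).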
+ static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4.");
+ static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep.");
+ if constexpr (kStep == 8) {
+ return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3;
+ } else {
+ static_assert(kStep == 4);
+ return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2;
+ }
+}
+
+// i: row index; j: col index
+template<const int kMMAStride = 16, const int kStep = 8>
+static __device__ __forceinline__ int swizzle_permuted_A_j(int i, int j) {
+ // -------------------
+ // -col 0~16, step 8--
+ // -------------------
+ // | row 0 | (0, 8) |
+ // | row 1 | (0, 8) |
+ // | row 2 | (0, 8) |
+ // | row 3 | (0, 8) |
+ // -------------------
+ // | row 4 | (8, 0) |
+ // | row 5 | (8, 0) |
+ // | row 6 | (8, 0) |
+ // | row 7 | (8, 0) |
+ // -------------------
+ // | row 8 | (0, 8) |
+ // | row 9 | (0, 8) |
+ // | row 10 | (0, 8) |
+ // | row 11 | (0, 8) |
+ // -------------------
+ // | row 12 | (8, 0) |
+ // | row 13 | (8, 0) |
+ // | row 14 | (8, 0) |
+ // | row 15 | (8, 0) |
+ // -------------------
+  return swizzle_permuted_j<kMMAStride, kStep>(i, j);
+}
+
+// To reduce bank conflicts, the K dimension (16x2=32) is split into two
+// MMA_K slices stored apart along the stage dimension. For example, with
+// stages=3 and WARP_TILE_K=2, s_a is stored as [3*2][BM][16].
+// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block swizzle, dsmem,
+// k32 with reg double buffers
+template<const int MMA_M = 16, const int MMA_N = 8, const int MMA_K = 16, const int MMA_TILE_M = 2, const int MMA_TILE_N = 4, const int WARP_TILE_M = 4, const int WARP_TILE_N = 4, const int WARP_TILE_K = 2, const int A_PAD = 0, const int B_PAD = 0, const int K_STAGE = 2, const bool BLOCK_SWIZZLE = true, const bool WARP_SWIZZLE = true>
+__global__ void __launch_bounds__(256)
+hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel(
+ const half* __restrict__ A, const half* __restrict__ B, half* __restrict__ C,
+ int M, int N, int K) {
+ // BLOCK_SWIZZLE 0/1 control use block swizzle or not.
+ const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K * WARP_TILE_K);
+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
+  constexpr int BK = MMA_K; // 16 (x WARP_TILE_K=2 -> 32 per main-loop iteration)
+
+ extern __shared__ half smem[];
+ half* s_a = smem;
+ half* s_b = smem + K_STAGE * BM * (BK + A_PAD) * WARP_TILE_K;
+ constexpr int s_a_stage_offset = BM * (BK + A_PAD); // 128x16
+ constexpr int s_b_stage_offset = BK * (BN + B_PAD); // 16x128
+ constexpr int s_a_mma_k_store_offset = K_STAGE * BM * (BK + A_PAD);
+ constexpr int s_b_mma_k_store_offset = K_STAGE * BK * (BN + B_PAD);
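+  // Dynamic smem layout in halfs (the two MMA_K slices of each tile are stored apart):
+  // [ s_a slice0: K_STAGE*BM*(BK+A_PAD) | s_a slice1: same size |
+  //   s_b slice0: K_STAGE*BK*(BN+B_PAD) | s_b slice1: same size ]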
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_m = warp_id % 2; // 0,1
+ const int warp_n = warp_id / 2; // 0,1,2,3
+
+ int load_smem_a_m = tid / 2; // row 0~127
+ int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
+ int load_smem_b_k = tid / 16; // row 0~15
+ int load_smem_b_n = (tid % 16) * 8; // col 0,8,16,...
+ int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
+ int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ RC[i][j][0] = 0;
+ RC[i][j][1] = 0;
+ }
+ }
+
+ uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a);
+ uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b);
+
+ #pragma unroll
+ for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1
+    // gmem K offset per iteration: k * BK * WARP_TILE_K = k * 32
+ int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+
+ uint32_t load_smem_a_ptr = (
+ smem_a_base_ptr + (k * s_a_stage_offset +
+ load_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(
+ load_smem_a_m, load_smem_a_k)) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0
+ uint32_t load_smem_a_mma_k_ptr = (
+ smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
+ (k * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(load_smem_a_m, load_smem_a_k)) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1
+
+ uint32_t load_smem_b_ptr = (
+ smem_b_base_ptr + (k * s_b_stage_offset +
+ load_smem_b_k * (BN + B_PAD) +
+ load_smem_b_n) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
+
+ int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k;
+ int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n;
+ uint32_t load_smem_b_mma_k_ptr = (
+ smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
+ (k * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) +
+ load_smem_b_n) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16);
+
+ CP_ASYNC_COMMIT_GROUP();
+ }
+
+ CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2
+ __syncthreads();
+
+ uint32_t RA[2][WARP_TILE_M][4];
+ uint32_t RB[2][WARP_TILE_N][2];
+
+ int reg_store_idx = 0;
+ int reg_load_idx = 1;
+
+ {
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // smem -> reg buffers 0, first MMA_K, 0~15
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = (
+ smem_a_base_ptr +
+ (0 * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half)
+ );
+ LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
+ RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
+ lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = (
+ smem_b_base_ptr +
+ (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+ lane_smem_b_n) * sizeof(half)
+ );
+      // could we use .x4.trans to load 4 matrices for the reg double buffers at once?
+ LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+ lane_smem_b_ptr);
+ }
+ }
+
+ #pragma unroll
+ for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) {
+ reg_store_idx ^= 1; // 0->1
+ reg_load_idx ^= 1; // 1->0
+ int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2...
+ int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1...
+
+ // stage gmem -> smem
+ int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+
+ uint32_t load_smem_a_ptr = (
+ smem_a_base_ptr + (smem_sel_next * s_a_stage_offset +
+ load_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(
+ load_smem_a_m, load_smem_a_k)) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0
+ uint32_t load_smem_a_mma_k_ptr = (
+ smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
+ (smem_sel_next * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(load_smem_a_m, load_smem_a_k)) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1
+
+ uint32_t load_smem_b_ptr = (
+ smem_b_base_ptr + (smem_sel_next * s_b_stage_offset +
+ load_smem_b_k * (BN + B_PAD) +
+ load_smem_b_n) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16);
+
+ int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k;
+ int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n;
+ uint32_t load_smem_b_mma_k_ptr = (
+ smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
+ (smem_sel_next * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) +
+ load_smem_b_n) * sizeof(half)
+ );
+ CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16);
+ CP_ASYNC_COMMIT_GROUP();
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // smem -> reg buffers 1, second MMA_K, 16~31
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = (
+ smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
+ (smem_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half)
+ );
+ LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
+ RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
+ lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = (
+ smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
+ (smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+ lane_smem_b_n) * sizeof(half)
+ );
+      // could we use .x4.trans to load 4 matrices for the reg double buffers at once?
+ LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+ lane_smem_b_ptr);
+ }
+
+ // MMA compute, first MMA_K
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ // Warp swizzle: Right -> Left -> Right -> Left
+ int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
+ HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
+ RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
+ RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
+ RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
+ RC[i][j_s][0], RC[i][j_s][1]);
+ }
+ }
+
+ reg_store_idx ^= 1; // 1 -> 0
+ reg_load_idx ^= 1; // 0 -> 1
+ // MMA compute, second MMA_K
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ // Warp swizzle: Right -> Left -> Right -> Left
+ int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
+ HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
+ RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
+ RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
+ RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
+ RC[i][j_s][0], RC[i][j_s][1]);
+ }
+ }
+
+ CP_ASYNC_WAIT_GROUP(K_STAGE-2);
+ __syncthreads();
+
+ // load next k iters to reg buffers.
+ // smem -> reg buffers 0, first MMA_K, 0~15
+ // int smem_sel_reg = (k + 2) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2
+ int smem_sel_reg = (smem_sel + 1) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = (
+ smem_a_base_ptr + (smem_sel_reg * s_a_stage_offset +
+ lane_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(
+ lane_smem_a_m, lane_smem_a_k)) * sizeof(half)
+ );
+ LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
+ RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
+ lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = (
+ smem_b_base_ptr + (smem_sel_reg * s_b_stage_offset +
+ lane_smem_b_k * (BN + B_PAD) +
+ lane_smem_b_n) * sizeof(half)
+ );
+      // could we use .x4.trans to load 4 matrices for the reg double buffers at once?
+ LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+ lane_smem_b_ptr);
+ }
+ }
+
+ // make sure all memory issues ready.
+ if constexpr ((K_STAGE - 2) > 0) {
+ CP_ASYNC_WAIT_GROUP(0);
+ __syncthreads();
+ }
+
+ // processing last (K_STAGE-1) k iters.
+ {
+ #pragma unroll
+ for (int k = 0; k < (K_STAGE - 1); k++) {
+ reg_store_idx ^= 1; // 0->1
+ reg_load_idx ^= 1; // 1->0
+
+ int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE);
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // smem -> reg buffers 1, second MMA_K
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = (
+ smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) +
+ (stage_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half)
+ );
+ LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
+ RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
+ lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = (
+ smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) +
+ (stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) +
+ lane_smem_b_n) * sizeof(half)
+ );
+ LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+ lane_smem_b_ptr);
+ }
+
+ // MMA compute, first MMA_K
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ // Warp swizzle: Right -> Left -> Right -> Left
+ int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
+ HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
+ RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
+ RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
+ RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
+ RC[i][j_s][0], RC[i][j_s][1]);
+ }
+ }
+
+ reg_store_idx ^= 1; // 1 -> 0
+ reg_load_idx ^= 1; // 0 -> 1
+
+ // MMA compute, second MMA_K
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ // Warp swizzle: Right -> Left -> Right -> Left
+ int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j;
+ HMMA16816(RC[i][j_s][0], RC[i][j_s][1],
+ RA[reg_load_idx][i][0], RA[reg_load_idx][i][1],
+ RA[reg_load_idx][i][2], RA[reg_load_idx][i][3],
+ RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1],
+ RC[i][j_s][0], RC[i][j_s][1]);
+ }
+ }
+
+ // load next k iters to reg buffers.
+ // smem -> reg buffers 0, first MMA_K, 0~15
+ // int stage_sel_reg = ((NUM_K_TILES - K_STAGE + k) % K_STAGE);
+ int stage_sel_reg = (stage_sel + 1) % K_STAGE;
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ uint32_t lane_smem_a_ptr = (
+ smem_a_base_ptr + (stage_sel_reg * s_a_stage_offset +
+ lane_smem_a_m * (BK + A_PAD) +
+ swizzle_permuted_A_j(
+ lane_smem_a_m, lane_smem_a_k)) * sizeof(half)
+ );
+ LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1],
+ RA[reg_store_idx][i][2], RA[reg_store_idx][i][3],
+ lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15, 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = (
+ smem_b_base_ptr + (stage_sel_reg * s_b_stage_offset +
+ lane_smem_b_k * (BN + B_PAD) +
+ lane_smem_b_n) * sizeof(half)
+ );
+ LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1],
+ lane_smem_b_ptr);
+ }
+ }
+ }
+
+ // collective store with reg reuse & warp shuffle
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+    // Reuse the RA[2][4][4] registers here; this may add 0.3~0.5 TFLOPS.
+    // Avoid putting the 'if' inside the N loop, it can break the '#pragma unroll' hint.
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+      // To use LDST128BITS here: __shfl_sync gathers each group-of-4's accumulators
+      // into its leader lane, so only 8 x 128-bit memory issues are needed afterwards.
+ RA[0][j][0] = RC[i][j][0];
+ RA[1][j][0] = RC[i][j][1];
+ RA[0][j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1);
+ RA[0][j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2);
+ RA[0][j][3] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 3);
+ RA[1][j][1] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 1);
+ RA[1][j][2] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 2);
+ RA[1][j][3] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 3);
+ }
+
+ if (lane_id % 4 == 0) {
+ int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n;
+ int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
+ int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
+ LDST128BITS(C[store_gmem_c_addr_0]) = LDST128BITS(RA[0][j][0]);
+ LDST128BITS(C[store_gmem_c_addr_1]) = LDST128BITS(RA[1][j][0]);
+ }
+ }
+ }
+}
+
+// build cpp binary
+#ifndef NO_MMA_HGEMM_BIN
+
+#include "utils.h"
+
+// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers
+#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(stages, stride) \
+{ \
+ const int smem_max_size = ( \
+ (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \
+ (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \
+ cudaFuncSetAttribute( \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
+ 98304); \
+ const int N_SWIZZLE = (N + (stride) - 1) / (stride); \
+ dim3 block(NUM_THREADS); \
+ dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \
+ div_ceil(M, BM), \
+ N_SWIZZLE); \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true><<< \
+ grid, block, smem_max_size>>>( \
+ a, b, c, \
+ M, N, K \
+ ); \
+}
+
+// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers
+template<const int K_STAGE = 2, const int BLOCK_SWIZZLE_STRIDE = 2048>
+void lanunch_hgemm_mma_m16n8k16_swizzle_nn(
+ const half* a, const half* b, half* c, int M, int N, int K) {
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ constexpr int MMA_TILE_M = 2;
+ constexpr int MMA_TILE_N = 4;
+ constexpr int WARP_TILE_M = 4;
+ constexpr int WARP_TILE_N = 4;
+ constexpr int WARP_TILE_K = 2;
+  // A tile uses smem swizzle (A_PAD = 0); B tile keeps pad = 8. Don't guess, verify with ncu:
+  // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage_swizzle.89.debug.bin
+  // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage_swizzle.89.debug.bin
+  constexpr int A_PAD = 0; // apply smem swizzle
+  constexpr int B_PAD = 8; // 0,8,16
+ constexpr int NUM_THREADS= (
+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;
+ constexpr int BK = MMA_K;
+  // s2: 2*128*(32)*2=16KB, 2*32*(128+8)*2=17KB,   ~33KB
+  // s3: 3*128*(32)*2=24KB, 3*32*(128+8)*2=25.5KB, ~50KB
+  // s4: 4*128*(32)*2=32KB, 4*32*(128+8)*2=34KB,   ~66KB
+  // s5: 5*128*(32)*2=40KB, 5*32*(128+8)*2=42.5KB, ~83KB
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(K_STAGE, BLOCK_SWIZZLE_STRIDE);
+}
+
+#ifdef HGEMM_MMA_DEBUG
+#include <string>
+#endif
+
+int main(int argc, char *argv[]) {
+#ifdef HGEMM_MMA_DEBUG
+ const int test_num = 1;
+#else
+ const int test_num = 64;
+#endif
+ int M_list[test_num];
+ int N_list[test_num];
+ int K_list[test_num];
+
+ for (int i = 0; i < test_num; i++) {
+ M_list[i] = (i + 1) * 256;
+ N_list[i] = (i + 1) * 256;
+ K_list[i] = (i + 1) * 256;
+ }
+
+#ifdef HGEMM_MMA_DEBUG
+ if (argc > 1) M_list[0] = std::stoi(argv[1]);
+ if (argc > 2) N_list[0] = std::stoi(argv[2]);
+ if (argc > 3) K_list[0] = std::stoi(argv[3]);
+#endif
+
+#ifdef HGEMM_MMA_DEBUG
+ int outer_repeat = 1, inner_repeat = 1, warmup = 1;
+ if (argc > 4) warmup = std::stoi(argv[4]);
+ if (argc > 5) inner_repeat = std::stoi(argv[5]);
+#else
+ int outer_repeat = 10, inner_repeat = 1, warmup = 1;
+#endif
+
+ printf("ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048 + A SMEM SWIZZLE\n");
+#ifndef HGEMM_MMA_DEBUG
+ for (int j = 0; j < 5; j++) {
+ int M = M_list[j], N = N_list[j], K = K_list[j];
+ float max_error = gemm_error_check_nn(
+ lanunch_hgemm_mma_m16n8k16_swizzle_nn<2, 2048>,
+ M, N, K);
+ printf("M N K = %6d %6d %6d, ", M, N, K);
+ printf("Max Error = %f\n", max_error);
+ }
+#endif
+
+ for (int j = 0; j < test_num; j++) {
+ int M = M_list[j], N = N_list[j], K = K_list[j];
+
+ double max_sec = 0.0;
+ double min_sec = DBL_MAX;
+ double total_sec = 0.0;
+
+ for (int k = 0; k < outer_repeat; k++) {
+ double this_sec = perf_gemm(
+ lanunch_hgemm_mma_m16n8k16_swizzle_nn<2, 2048>,
+ M, N, K, inner_repeat, warmup);
+ max_sec = max(max_sec, this_sec);
+ min_sec = min(min_sec, this_sec);
+ total_sec += this_sec;
+ }
+
+ // 1 TFLOPS = 10^12 FLOPS
+ // ref: https://imgtec.eetrend.com/blog/2021/100062210.html.
+ double avg_sec = total_sec / outer_repeat;
+ double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+
+ printf("M N K = %6d %6d %6d, ", M, N, K);
+ printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec);
+ printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops);
+ }
+
+ return 0;
+}
+
+#else
+
+// --------------------- PyTorch bindings for custom kernel -----------------------
+#include <torch/types.h>
+#include <torch/extension.h>
+#define STRINGFY(str) #str
+#define TORCH_BINDING_COMMON_EXTENSION(func) \
+ m.def(STRINGFY(func), &func, STRINGFY(func));
+
+#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
+if(((T).options().dtype() != (th_type))) { \
+ std::cout << "Tensor Info:" << (T).options() << std::endl; \
+ throw std::runtime_error("values must be "#th_type); \
+}
+
+#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
+if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
+ throw std::runtime_error("Tensor size mismatch!"); \
+}
+
+// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers
+#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(stages, stride) \
+{ \
+ const int smem_max_size = ( \
+ (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \
+ (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \
+ cudaFuncSetAttribute( \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
+ 98304); \
+ const int N_SWIZZLE = (N + (stride) - 1) / (stride); \
+ dim3 block(NUM_THREADS); \
+ dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \
+ div_ceil(M, BM), \
+ N_SWIZZLE); \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true><<< \
+ grid, block, smem_max_size>>>( \
+    reinterpret_cast<half*>(a.data_ptr()),                                 \
+    reinterpret_cast<half*>(b.data_ptr()),                                 \
+    reinterpret_cast<half*>(c.data_ptr()),                                 \
+ M, N, K \
+ ); \
+}
+
+// no block swizzle, but still applies the A smem swizzle
+#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(stages) \
+{ \
+ const int smem_max_size = ( \
+ (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \
+ (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \
+ cudaFuncSetAttribute( \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), false>, \
+ cudaFuncAttributeMaxDynamicSharedMemorySize, \
+ 98304); \
+ dim3 block(NUM_THREADS); \
+ dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \
+ hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \
+ WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), false><<< \
+ grid, block, smem_max_size>>>( \
+    reinterpret_cast<half*>(a.data_ptr()),                                 \
+    reinterpret_cast<half*>(b.data_ptr()),                                 \
+    reinterpret_cast<half*>(c.data_ptr()),                                 \
+ M, N, K \
+ ); \
+}
+
+// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers
+void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(
+ torch::Tensor a, torch::Tensor b, torch::Tensor c,
+ int stages, bool swizzle, int swizzle_stride) {
+ CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
+ CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
+ const int M = a.size(0);
+ const int K = a.size(1);
+ const int N = b.size(1);
+ CHECK_TORCH_TENSOR_SHAPE(a, M, K)
+ CHECK_TORCH_TENSOR_SHAPE(b, K, N)
+ CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ constexpr int MMA_TILE_M = 2;
+ constexpr int MMA_TILE_N = 4;
+ constexpr int WARP_TILE_M = 4;
+ constexpr int WARP_TILE_N = 4;
+ constexpr int WARP_TILE_K = 2;
+  // A tile uses smem swizzle (A_PAD = 0); B tile keeps pad = 8. Don't guess, verify with ncu:
+  // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage_swizzle.89.debug.bin
+  // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage_swizzle.89.debug.bin
+  constexpr int A_PAD = 0; // apply smem swizzle
+  constexpr int B_PAD = 8; // 0,8,16
+ constexpr int NUM_THREADS= (
+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M;
+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N;
+ constexpr int BK = MMA_K;
+  // s2: 2*128*(32)*2=16KB, 2*32*(128+8)*2=17KB,   ~33KB
+  // s3: 3*128*(32)*2=24KB, 3*32*(128+8)*2=25.5KB, ~50KB
+  // s4: 4*128*(32)*2=32KB, 4*32*(128+8)*2=34KB,   ~66KB
+  // s5: 5*128*(32)*2=40KB, 5*32*(128+8)*2=42.5KB, ~83KB
+ if (swizzle) {
+ // assert(swizzle_stride % 256 == 0);
+ switch (stages)
+ {
+    case 2: // ~33KB
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(2, swizzle_stride);
+ break;
+    case 3: // ~50KB
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(3, swizzle_stride);
+ break;
+    case 4: // ~66KB
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(4, swizzle_stride);
+ break;
+    case 5: // ~83KB
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(5, swizzle_stride);
+ break;
+ default:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(2, swizzle_stride);
+ break;
+ }
+ } else {
+ switch (stages)
+ {
+ case 2:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(2);
+ break;
+ case 3:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(3);
+ break;
+ case 4:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(4);
+ break;
+ case 5:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(5);
+ break;
+ default:
+ LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(2);
+ break;
+ }
+ }
+}
+
+#endif
diff --git a/kernels/hgemm/pybind/hgemm.cc b/kernels/hgemm/pybind/hgemm.cc
index 36551d0b..0c14d3fd 100644
--- a/kernels/hgemm/pybind/hgemm.cc
+++ b/kernels/hgemm/pybind/hgemm.cc
@@ -49,6 +49,8 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch:
void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
// from hgemm_mma_stage_tn_cute.cu
void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
+// from hgemm_mma_stage_swizzle.cu
+void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
@@ -93,6 +95,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem)
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4)
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr)
+ // smem swizzle
+ TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle)
// TN: A row major MxK, B col major NxK, C row major MxN
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn)
// TN: cute hgemm with smem & block swizzle
diff --git a/kernels/hgemm/tools/utils.py b/kernels/hgemm/tools/utils.py
index 4f0456e9..7f45ccb2 100644
--- a/kernels/hgemm/tools/utils.py
+++ b/kernels/hgemm/tools/utils.py
@@ -24,6 +24,7 @@ def get_build_sources():
build_sources.append('wmma/hgemm_wmma_stage.cu')
build_sources.append('mma/hgemm_mma.cu')
build_sources.append('mma/hgemm_mma_stage.cu')
+ build_sources.append('mma/hgemm_mma_stage_swizzle.cu')
build_sources.append('mma/hgemm_mma_stage_tn.cu')
build_sources.append('cutlass/hgemm_mma_stage_tn_cute.cu')
build_sources.append('pybind/hgemm.cc')
diff --git a/kernels/hgemm/utils/utils.h b/kernels/hgemm/utils/utils.h
index 86d0d617..23a501dd 100644
--- a/kernels/hgemm/utils/utils.h
+++ b/kernels/hgemm/utils/utils.h
@@ -6,7 +6,7 @@
template <typename T>
float perf_gemm(
void (*gpu_hgemm) (const T *, const T *, T *, int, int, int),
- int M, int N, int K, int repeat) {
+ int M, int N, int K, int repeat, int warmup = 1) {
size_t size_a = M * K * sizeof(T);
size_t size_b = K * N * sizeof(T);
@@ -19,7 +19,7 @@ float perf_gemm(
cudaMalloc(&d_c, size_c);
// warmup
- for (int i = 0; i < 10; ++i){
+ for (int i = 0; i < warmup; ++i){
gpu_hgemm(d_a, d_b, d_c, M, N, K);
}
cudaDeviceSynchronize();
@@ -52,7 +52,7 @@ float perf_gemm(
template <typename T>
float perf_gemm_swizzle(
void (*gpu_hgemm) (const T *, const T *, T *, int, int, int, int),
- int M, int N, int K, int swizzle_stride, int repeat) {
+ int M, int N, int K, int swizzle_stride, int repeat, int warmup = 1) {
size_t size_a = M * K * sizeof(T);
size_t size_b = K * N * sizeof(T);
@@ -65,7 +65,7 @@ float perf_gemm_swizzle(
cudaMalloc(&d_c, size_c);
// warmup
- for (int i = 0; i < 10; ++i){
+ for (int i = 0; i < warmup; ++i){
gpu_hgemm(d_a, d_b, d_c, M, N, K, swizzle_stride);
}
cudaDeviceSynchronize();
diff --git a/kernels/swizzle/README.md b/kernels/swizzle/README.md
new file mode 100644
index 00000000..7faaef39
--- /dev/null
+++ b/kernels/swizzle/README.md
@@ -0,0 +1,125 @@
+# 📖 Learn how to apply SMEM Swizzle for bank-conflict-free access
+
+## 📚 build bin
+
+```bash
+make
+```
+
+## 📚 ncu profile
+
+Verify that smem swizzle achieves 0 bank conflicts for LDSM:
+
+```bash
+ncu --metrics l1tex__data_bank_reads ./mat_trans_swizzle.bin
+ncu --metrics l1tex__data_bank_writes ./mat_trans_swizzle.bin
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./mat_trans_swizzle.bin
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st ./mat_trans_swizzle.bin
+
+ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1
+ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1
+```
+
+Log (smem swizzle achieves 0 bank conflicts for LDSM):
+
+```bash
+ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1
+[1542675] hgemm_mma_swizzle.bin@127.0.0.1
+ void hgemm_mma_m16n8k16_naive_kernel<16, 8, 16>(__half *, __half *, __half *, int, int, int) (128, 64, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ ------------------------------------------------------------------ ----------- ------------
+ Metric Name Metric Unit Metric Value
+ ------------------------------------------------------------------ ----------- ------------
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 22795.13
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 24576
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 18432
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 2097152
+ ------------------------------------------------------------------ ----------- ------------
+
+ void hgemm_mma_m16n8k16_naive_smem_swizzle_kernel<16, 8, 16>(__half *, __half *, __half *, int, int, int) (128, 64, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ ------------------------------------------------------------------ ----------- ------------
+ Metric Name Metric Unit Metric Value
+ ------------------------------------------------------------------ ----------- ------------
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0
+ ------------------------------------------------------------------ ----------- ------------
+
+ void hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel<16, 8, 16, 2, 4, 4, 4, 0, 0>(__half *, __half *, __half *, int, int, int) (8, 8, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ ------------------------------------------------------------------ ----------- ------------
+ Metric Name Metric Unit Metric Value
+ ------------------------------------------------------------------ ----------- ------------
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 25644.52
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 36864
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 2359296
+ ------------------------------------------------------------------ ----------- ------------
+
+ void hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel<16, 8, 16, 2, 4, 4, 4, 0, 8>(__half *, __half *, __half *, int, int, int) (8, 8, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9
+ Section: Command line profiler metrics
+ ------------------------------------------------------------------ ----------- ------------
+ Metric Name Metric Unit Metric Value
+ ------------------------------------------------------------------ ----------- ------------
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0
+ sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0
+ ------------------------------------------------------------------ ----------- ------------
+```
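+
+Why the swizzle removes these LDSM conflicts can be sanity-checked on the host. The sketch
+below reuses the same XOR formula as `swizzle_permuted_A_j` for the A tile (BK = 16 halfs
+per row, A_PAD = 0) together with a simplified bank model (32 x 4-byte banks seen as 8
+groups of 16 bytes, the granularity of one `ldmatrix` address); the helper names are only
+for illustration:
+
+```python
+STEP = 8        # 8 halfs = 16 bytes per ldmatrix address
+ROW_HALFS = 16  # A tile: BK = 16 halfs per smem row, A_PAD = 0
+
+def swizzle_A_j(i: int, j: int, col_stride: int = 16, step: int = STEP) -> int:
+    # same formula as swizzle_permuted_A_j in the CUDA kernels
+    return (((j // step) ^ (i // 4)) % (col_stride // step)) * step
+
+def bank_group(row: int, col: int) -> int:
+    # 16-byte chunk index within a 128-byte span of the 32 shared memory banks
+    byte_addr = (row * ROW_HALFS + col) * 2  # half = 2 bytes
+    return (byte_addr // 16) % 8
+
+# lanes 0~7 of one ldmatrix phase read rows 0~7 at logical col 0
+naive    = [bank_group(r, 0) for r in range(8)]
+swizzled = [bank_group(r, swizzle_A_j(r, 0)) for r in range(8)]
+print("naive   :", naive)     # [0, 2, 4, 6, 0, 2, 4, 6] -> 2-way conflicts
+print("swizzled:", swizzled)  # [0, 2, 4, 6, 1, 3, 5, 7] -> all 8 groups, conflict free
+assert len(set(swizzled)) == 8
+```
+
+Without the swizzle the 8 addresses of one phase land in only 4 of the 8 bank groups (each
+hit twice); with it they cover all 8, which matches the zero-conflict numbers above.
+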
+## 📚 print swizzle layout
+```bash
+python3 print_swizzle_layout.py --col 64
+-------------------------------------------
+--------------swizzle layout---------------
+-------------col 0~64, step 8--------------
+-------------------------------------------
+| row 0 | (0, 8, 16, 24, 32, 40, 48, 56) |
+| row 1 | (0, 8, 16, 24, 32, 40, 48, 56) |
+| row 2 | (0, 8, 16, 24, 32, 40, 48, 56) |
+| row 3 | (0, 8, 16, 24, 32, 40, 48, 56) |
+-------------------------------------------
+| row 4 | (8, 0, 24, 16, 40, 32, 56, 48) |
+| row 5 | (8, 0, 24, 16, 40, 32, 56, 48) |
+| row 6 | (8, 0, 24, 16, 40, 32, 56, 48) |
+| row 7 | (8, 0, 24, 16, 40, 32, 56, 48) |
+-------------------------------------------
+| row 8 | (16, 24, 0, 8, 48, 56, 32, 40) |
+| row 9 | (16, 24, 0, 8, 48, 56, 32, 40) |
+| row 10 | (16, 24, 0, 8, 48, 56, 32, 40) |
+| row 11 | (16, 24, 0, 8, 48, 56, 32, 40) |
+-------------------------------------------
+| row 12 | (24, 16, 8, 0, 56, 48, 40, 32) |
+| row 13 | (24, 16, 8, 0, 56, 48, 40, 32) |
+| row 14 | (24, 16, 8, 0, 56, 48, 40, 32) |
+| row 15 | (24, 16, 8, 0, 56, 48, 40, 32) |
+-------------------------------------------
+
+python3 print_swizzle_layout.py --col 16
+-------------------
+--swizzle layout---
+-col 0~16, step 8--
+-------------------
+| row 0 | (0, 8) |
+| row 1 | (0, 8) |
+| row 2 | (0, 8) |
+| row 3 | (0, 8) |
+-------------------
+| row 4 | (8, 0) |
+| row 5 | (8, 0) |
+| row 6 | (8, 0) |
+| row 7 | (8, 0) |
+-------------------
+| row 8 | (0, 8) |
+| row 9 | (0, 8) |
+| row 10 | (0, 8) |
+| row 11 | (0, 8) |
+-------------------
+| row 12 | (8, 0) |
+| row 13 | (8, 0) |
+| row 14 | (8, 0) |
+| row 15 | (8, 0) |
+-------------------
+```
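+
+A minimal Python sketch of the permuted-column mapping that produces the tables above (mirroring `swizzle_permuted_j` in `print_swizzle_layout.py`; the asserts are only illustrative):
+
+```python
+def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8) -> int:
+    # XOR the 8-col chunk index of j with the 4-row group index of i,
+    # then wrap by the number of chunks per row (col_stride / step).
+    return ((j // step) ^ (i // 4)) % (col_stride // step) * step
+
+# row 4, col 0 is permuted to col 8 when col_stride=64,
+# while row 8, col 0 stays at col 0 when col_stride=16.
+assert swizzle_permuted_j(4, 0, col_stride=64) == 8
+assert swizzle_permuted_j(8, 0, col_stride=16) == 0
+```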
diff --git a/kernels/swizzle/hgemm_mma_swizzle.cu b/kernels/swizzle/hgemm_mma_swizzle.cu
index 37fbd6e4..cbb4f068 100644
--- a/kernels/swizzle/hgemm_mma_swizzle.cu
+++ b/kernels/swizzle/hgemm_mma_swizzle.cu
@@ -1,15 +1,15 @@
#include
#include
+#include
#include
#include
#include
+#include
#include
#include
#include
#include
#include
-#include
-#include
using namespace nvcuda;
#define WARP_SIZE 32
@@ -249,63 +249,270 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel(
}
}
+// i: row index; j: col index
+__device__ __host__ __forceinline__ int swizzle_A_j(int i, int j) {
+ // >>> sw(0,0),sw(0,8),sw(1,0),sw(1,8),sw(2,0),sw(2,8),sw(3,0),sw(3,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(4,0),sw(4,8),sw(5,0),sw(5,8),sw(6,0),sw(6,8),sw(7,0),sw(7,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ // >>> sw(8,0),sw(8,8),sw(9,0),sw(9,8),sw(10,0),sw(10,8),sw(11,0),sw(11,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(12,0),sw(12,8),sw(13,0),sw(13,8),sw(14,0),sw(14,8),sw(15,0),sw(15,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ return ((int(j / 8) ^ int(i / 4)) % 2) * 8;
+}
+
+
+// TODO: hgemm_mma_m16n8k16_naive_smem_swizzle_kernel
+// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
+template <const int MMA_M, const int MMA_N, const int MMA_K>
+__global__ void hgemm_mma_m16n8k16_naive_smem_swizzle_kernel(
+ half* A, half* B, half* C, int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M; // 16
+ constexpr int BN = MMA_N; // 8
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[MMA_M][MMA_K]; // 16x16
+ __shared__ half s_b[MMA_K][MMA_N]; // 16x8
+ __shared__ half s_c[MMA_M][MMA_N]; // 16x8
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+
+ // s_a[16][16]: each row has 16 halfs, each thread loads 8 -> 2 threads per row; 16 rows need 2x16=32 threads.
+ const int load_smem_a_m = tid / 2; // row 0~15
+ const int load_smem_a_k = (tid % 2) * 8; // col 0,8
+ // s_b[16][8]: each row has 8 halfs, one thread loads a whole row; 16 rows need 16 threads, so only half the warp loads.
+ const int load_smem_b_k = tid; // row 0~31, but only use 0~15
+ const int load_smem_b_n = 0; // col 0
+ const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
+ const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[2] = {0, 0};
-// --------------------- PyTorch bindings for custom kernel -----------------------
-#define STRINGFY(str) #str
-#define TORCH_BINDING_COMMON_EXTENSION(func) \
- m.def(STRINGFY(func), &func, STRINGFY(func));
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem_a -> smem_a
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ // LDST128BITS(A[load_gmem_a_addr]));
+ LDST128BITS(s_a[load_smem_a_m][swizzle_A_j(
+ load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr]));
+
+ // gmem_b -> smem_b
+ if (lane_id < MMA_K) {
+ int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ }
+ __syncthreads();
-#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \
-if(((T).options().dtype() != (th_type))) { \
- std::cout << "Tensor Info:" << (T).options() << std::endl; \
- throw std::runtime_error("values must be "#th_type); \
+ uint32_t RA[4];
+ uint32_t RB[2];
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]
+ // uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ // &s_a[lane_id % 16][(lane_id / 16) * 8]);
+ uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_id % 16][swizzle_A_j(lane_id % 16, (lane_id / 16) * 8)]);
+ LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);
+ uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_id % 16][0]);
+ LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);
+
+ HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);
+
+ __syncthreads();
+ }
+
+ // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]);
+ LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]);
+
+ __syncthreads();
+
+ // store s_c[16][8]
+ if (lane_id < MMA_M) {
+ // store 128 bits per memory issue.
+ int store_gmem_c_m = by * BM + lane_id;
+ int store_gmem_c_n = bx * BN;
+ int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n;
+ LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0]));
+ }
}
-#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \
-if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \
- throw std::runtime_error("Tensor size mismatch!"); \
+// 128x128, mma2x4, warp4x4(64,32,16)
+template <const int MMA_M, const int MMA_N, const int MMA_K,
+          const int MMA_TILE_M, const int MMA_TILE_N,
+          const int WARP_TILE_M, const int WARP_TILE_N,
+          const int A_PAD, const int B_PAD>
+__global__ void __launch_bounds__(256)
+hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel(
+ half* A, half* B, half* C, int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128
+ constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB
+ __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+ const int warp_m = warp_id % 2; // 0,1
+ const int warp_n = warp_id / 2; // 0,1,2,3
+
+ // First compute the indices into shared memory.
+ // Mapping from tid to the smem tile s_a[BM][BK], BM=128, BK=16; A is row-major and loaded row by row.
+ // Each s_a row has 16 halfs, each thread loads 8 -> 2 threads per row; 128 rows need exactly 128x2=256 threads.
+ int load_smem_a_m = tid / 2; // row 0~127
+ int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8
+ // Mapping from tid to the smem tile s_b[BK][BN], BK=16, BN=128; B is row-major and loaded row by row.
+ // Each s_b row has 128 halfs, each thread loads 8 -> 16 threads per row; 16 rows need 16x16=256 threads.
+ int load_smem_b_k = tid / 16; // row 0~15
+ int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120
+ // Then compute the corresponding global memory indices.
+ // Global row in A for the elements loaded into s_a; each block produces a BM*BN tile of C.
+ int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c
+ int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[WARP_TILE_M][WARP_TILE_N][2];
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ RC[i][j][0] = 0;
+ RC[i][j][1] = 0;
+ }
+ }
+
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem -> smem
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ // LDST128BITS(A[load_gmem_a_addr]));
+ LDST128BITS(s_a[load_smem_a_m][swizzle_A_j(
+ load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr]));
+ __syncthreads();
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ uint32_t RA[WARP_TILE_M][4];
+ uint32_t RB[WARP_TILE_N][2];
+
+ // smem -> reg
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15
+ int lane_smem_a_k = (lane_id / 16) * 8; // 0,8
+ // uint32_t lane_smem_a_ptr = __cvta_generic_to_shared(
+ // &s_a[lane_smem_a_m][lane_smem_a_k]);
+ uint32_t lane_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_smem_a_m][swizzle_A_j(lane_smem_a_m, lane_smem_a_k)]);
+ LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr);
+ }
+
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ int lane_smem_b_k = lane_id % 16; // 0~15
+ int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8
+ uint32_t lane_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_smem_b_k][lane_smem_b_n]);
+ LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr);
+ }
+
+ // MMA compute
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ HMMA16816(RC[i][j][0], RC[i][j][1],
+ RA[i][0], RA[i][1], RA[i][2], RA[i][3],
+ RB[j][0], RB[j][1],
+ RC[i][j][0], RC[i][j][1]);
+ }
+ }
+ __syncthreads();
+ }
+
+ // reg -> gmem, MMA_MxMMA_N=16x8
+ #pragma unroll
+ for (int i = 0; i < WARP_TILE_M; ++i) {
+ #pragma unroll
+ for (int j = 0; j < WARP_TILE_N; ++j) {
+ int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M;
+ int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N;
+ // mapping lane smem index -> global index.
+ // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4;
+ int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2;
+ int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
+ int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
+ // TODO: how to use LDST128BITS here ? reverse the loop order ?
+ LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]);
+ LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]);
+ }
+ }
}
-// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major.
-void hgemm_mma_m16n8k16_naive(
- torch::Tensor a, torch::Tensor b, torch::Tensor c) {
- CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
- CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
- CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
- const int M = a.size(0);
- const int K = a.size(1);
- const int N = b.size(1);
- CHECK_TORCH_TENSOR_SHAPE(a, M, K)
- CHECK_TORCH_TENSOR_SHAPE(b, K, N)
- CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+// launcher
+void launch_hgemm_mma_m16n8k16_naive(
+ half* a, half* b, half* c, int M, int N, int K) {
constexpr int MMA_M = 16;
constexpr int MMA_N = 8;
- constexpr int MMA_K = 16;
-
+ constexpr int MMA_K = 16;
dim3 block(WARP_SIZE);
dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M));
hgemm_mma_m16n8k16_naive_kernel<
    MMA_M, MMA_N, MMA_K><<<grid, block>>>(
- reinterpret_cast<half*>(a.data_ptr()),
- reinterpret_cast<half*>(b.data_ptr()),
- reinterpret_cast<half*>(c.data_ptr()),
- M, N, K
+ a, b, c, M, N, K
);
}
-// 128x128, mma2x4, warp4x4(64,32,16)
-void hgemm_mma_m16n8k16_mma2x4_warp4x4(
- torch::Tensor a, torch::Tensor b, torch::Tensor c) {
- CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf)
- CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf)
- CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf)
- const int M = a.size(0);
- const int K = a.size(1);
- const int N = b.size(1);
- CHECK_TORCH_TENSOR_SHAPE(a, M, K)
- CHECK_TORCH_TENSOR_SHAPE(b, K, N)
- CHECK_TORCH_TENSOR_SHAPE(c, M, N)
+void launch_hgemm_mma_m16n8k16_naive_smem_swizzle(
+ half* a, half* b, half* c, int M, int N, int K) {
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ dim3 block(WARP_SIZE);
+ dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M));
+
+ hgemm_mma_m16n8k16_naive_smem_swizzle_kernel<
+ MMA_M, MMA_N, MMA_K><<<grid, block>>>(
+ a, b, c, M, N, K
+ );
+}
+
+void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4(
+ half* a, half* b, half* c, int M, int N, int K) {
constexpr int MMA_M = 16;
constexpr int MMA_N = 8;
constexpr int MMA_K = 16;
@@ -313,21 +520,137 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4(
constexpr int MMA_TILE_N = 4;
constexpr int WARP_TILE_M = 4;
constexpr int WARP_TILE_N = 4;
+ // bank-conflict free via pad = 8; no guessing, trust the profiler:
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin
+ // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin
+ // constexpr int A_PAD = 8;
+ // constexpr int B_PAD = 8;
constexpr int A_PAD = 0;
- constexpr int B_PAD = 16;
+ constexpr int B_PAD = 0;
constexpr int NUM_THREADS= (
- MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
-
+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
dim3 block(NUM_THREADS);
dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N),
div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M));
-
+
hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel<
MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,
- WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<<grid, block>>>(
- reinterpret_cast<half*>(a.data_ptr()),
- reinterpret_cast<half*>(b.data_ptr()),
- reinterpret_cast<half*>(c.data_ptr()),
- M, N, K
+ WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<<
+ grid, block>>>(
+ a, b, c, M, N, K
);
}
+
+void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle(
+ half* a, half* b, half* c, int M, int N, int K) {
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ constexpr int MMA_TILE_M = 2;
+ constexpr int MMA_TILE_N = 4;
+ constexpr int WARP_TILE_M = 4;
+ constexpr int WARP_TILE_N = 4;
+ constexpr int A_PAD = 0;
+ constexpr int B_PAD = 8;
+ constexpr int NUM_THREADS= (
+ MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256
+ dim3 block(NUM_THREADS);
+ dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N),
+ div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M));
+
+ hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel<
+ MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N,
+ WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<<
+ grid, block>>>(
+ a, b, c, M, N, K
+ );
+}
+
+template <typename T>
+float perf_gemm(
+ void (*gpu_hgemm) (T *, T *, T *, int, int, int),
+ int M, int N, int K, int warmup, int repeat) {
+
+ size_t size_a = M * K * sizeof(T);
+ size_t size_b = K * N * sizeof(T);
+ size_t size_c = M * N * sizeof(T);
+
+ T *d_a, *d_b;
+ T *d_c;
+ cudaMalloc(&d_a, size_a);
+ cudaMalloc(&d_b, size_b);
+ cudaMalloc(&d_c, size_c);
+
+ // warmup
+ for (int i = 0; i < warmup; ++i){
+ gpu_hgemm(d_a, d_b, d_c, M, N, K);
+ }
+ cudaDeviceSynchronize();
+
+ cudaEvent_t start, end;
+ cudaEventCreate(&start);
+ cudaEventCreate(&end);
+ cudaEventRecord(start);
+ for (int i = 0; i < repeat; i++) {
+ gpu_hgemm(d_a, d_b, d_c, M, N, K);
+ }
+ cudaEventRecord(end);
+ cudaDeviceSynchronize();
+ cudaEventSynchronize(end);
+
+ float msec, sec;
+ cudaEventElapsedTime(&msec, start, end);
+ sec = msec / 1000.0 / repeat;
+
+ cudaFree(d_a);
+ cudaFree(d_b);
+ cudaFree(d_c);
+ cudaEventDestroy(start);
+ cudaEventDestroy(end);
+
+ return sec;
+}
+
+int main(int argc, char *argv[]) {
+ int M = 1024;
+ int N = 1024;
+ int K = 1024;
+ int W = 1;
+ int R = 10;
+ if (argc > 1) M = std::stoi(argv[1]);
+ if (argc > 2) N = std::stoi(argv[2]);
+ if (argc > 3) K = std::stoi(argv[3]);
+ if (argc > 4) W = std::stoi(argv[4]);
+ if (argc > 5) R = std::stoi(argv[5]);
+ double avg_sec, avg_Tflops;
+
+ printf("\nALGO = HGEMM MMA NAIVE\n");
+ avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_naive,
+ M, N, K, W, R);
+ avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+ printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
+ printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
+
+ printf("\nALGO = HGEMM MMA NAIVE + SMEM SWIZZLE\n");
+ avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_naive_smem_swizzle,
+ M, N, K, W, R);
+ avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+ printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
+ printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
+
+ printf("\nALGO = HGEMM mma2x4_warp4x4\n");
+ avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4,
+ M, N, K, W, R);
+ avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+ printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
+ printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
+
+ printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n");
+ avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle,
+ M, N, K, W, R);
+ avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec;
+ printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R);
+ printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops);
+
+ return 0;
+}
diff --git a/kernels/swizzle/makefile b/kernels/swizzle/makefile
new file mode 100644
index 00000000..f08b1c59
--- /dev/null
+++ b/kernels/swizzle/makefile
@@ -0,0 +1,17 @@
+INCLUDE_DIRS=-I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include
+ARCHS=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_89,code=sm_89
+ARCHS_89=-gencode arch=compute_89,code=sm_89
+DEFAULT_FLAGS=-O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+DEFAULT_FLAGS_89=-O2 $(ARCHS_89) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas
+default:
+ nvcc hgemm_mma_swizzle.cu -o hgemm_mma_swizzle.bin $(DEFAULT_FLAGS)
+ nvcc mat_trans_swizzle.cu -o mat_trans_swizzle.bin $(DEFAULT_FLAGS)
+ nvcc mma_simple_swizzle.cu -o mma_simple_swizzle.bin $(DEFAULT_FLAGS)
+hgemm_89:
+ nvcc hgemm_mma_swizzle.cu -o hgemm_mma_swizzle.89.bin $(DEFAULT_FLAGS_89)
+mma_89:
+ nvcc mma_simple_swizzle.cu -o mma_simple_swizzle.89.bin $(DEFAULT_FLAGS_89)
+mat_89:
+ nvcc mat_trans_swizzle.cu -o mat_trans_swizzle.89.bin $(DEFAULT_FLAGS_89)
+clean:
+ rm -rf *.bin
diff --git a/kernels/swizzle/mat_trans_swizzle.cu b/kernels/swizzle/mat_trans_swizzle.cu
new file mode 100644
index 00000000..c526771d
--- /dev/null
+++ b/kernels/swizzle/mat_trans_swizzle.cu
@@ -0,0 +1,108 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+
+// reference: https://zhuanlan.zhihu.com/p/4746910252
+// The matrix to transpose is stored in dev_A with size M*N; the transposed result is written to dev_B.
+__global__ void mat_trans_smem_naive_kernel(int* dev_A, int M, int N, int* dev_B) {
+ int row = blockIdx.y * blockDim.y + threadIdx.y;
+ int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+ // Each block handles a 32*32 tile.
+ __shared__ int s_data[32][32];
+
+ if (row < M && col < N) {
+ // Load from global memory and write transposed into shared memory.
+ s_data[threadIdx.x][threadIdx.y] = dev_A[row * N + col];
+ __syncthreads();
+ int n_col = blockIdx.y * blockDim.y + threadIdx.x;
+ int n_row = blockIdx.x * blockDim.x + threadIdx.y;
+ if (n_col < M && n_row < N) {
+ // Write rows of the transposed tile from shared memory back to global memory.
+ dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x];
+ }
+ }
+}
+
+// reference: https://zhuanlan.zhihu.com/p/4746910252
+__global__ void mat_trans_smem_padding_kernel(int* dev_A, int M, int N, int* dev_B) {
+ int row = blockIdx.y * blockDim.y + threadIdx.y;
+ int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+ // Each block handles a 32*32 tile; pad the trailing dim to avoid bank conflicts.
+ __shared__ int s_data[32][33];
+
+ if (row < M && col < N) {
+ s_data[threadIdx.x][threadIdx.y] = dev_A[row * N + col];
+ __syncthreads();
+ int n_col = blockIdx.y * blockDim.y + threadIdx.x;
+ int n_row = blockIdx.x * blockDim.x + threadIdx.y;
+ if (n_col < M && n_row < N) {
+ dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x];
+ }
+ }
+}
+
+// reference: https://zhuanlan.zhihu.com/p/4746910252
+__global__ void mat_trans_smem_swizzle_kernel(int* dev_A, int M, int N, int* dev_B) {
+ int row = blockIdx.y * blockDim.y + threadIdx.y;
+ int col = blockIdx.x * blockDim.x + threadIdx.x;
+
+ __shared__ int s_data[32][32];
+
+ if (row < M && col < N) {
+ // Load from global memory into the logical smem coordinate (row=x, col=y),
+ // which maps to the physical location (row=x, col=x^y).
+ s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col];
+ __syncthreads();
+ int n_col = blockIdx.y * blockDim.y + threadIdx.x;
+ int n_row = blockIdx.x * blockDim.x + threadIdx.y;
+ if (n_row < N && n_col < M) {
+ // Read from the logical smem coordinate (row=y, col=x),
+ // which maps to the physical location (row=y, col=x^y).
+ dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y];
+ }
+ }
+}
+
+int main(int argc, char *argv[]) {
+ int M = 1024;
+ int N = 1024;
+ if (argc > 1) M = std::stoi(argv[1]);
+ if (argc > 2) N = std::stoi(argv[2]);
+ size_t size_a = M * N * sizeof(int);
+ size_t size_b = M * N * sizeof(int);
+
+ int* dev_A;
+ int* dev_B;
+ cudaMalloc(&dev_A, size_a);
+ cudaMalloc(&dev_B, size_b);
+ cudaDeviceSynchronize();
+
+ dim3 block(32, 32);
+ dim3 grid(N/32, M/32);
+
+ mat_trans_smem_naive_kernel<<<grid, block>>>(dev_A, M, N, dev_B);
+ cudaDeviceSynchronize();
+
+ mat_trans_smem_padding_kernel<<<grid, block>>>(dev_A, M, N, dev_B);
+ cudaDeviceSynchronize();
+
+ mat_trans_smem_swizzle_kernel<<<grid, block>>>(dev_A, M, N, dev_B);
+ cudaDeviceSynchronize();
+
+ printf("Done.\n");
+ cudaFree(dev_A);
+ cudaFree(dev_B);
+
+ return 0;
+}
diff --git a/kernels/swizzle/matrix_trans_swizzle.cu b/kernels/swizzle/matrix_trans_swizzle.cu
deleted file mode 100644
index 105dbf23..00000000
--- a/kernels/swizzle/matrix_trans_swizzle.cu
+++ /dev/null
@@ -1,35 +0,0 @@
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-#include
-using namespace nvcuda;
-
-// reference: https://zhuanlan.zhihu.com/p/4746910252
-__global__ void matrix_trans_swizzling(int* dev_A, int M, int N, int* dev_B) {
- int row = blockIdx.y * blockDim.y + threadIdx.y;
- int col = blockIdx.x * blockDim.x + threadIdx.x;
-
- __shared__ int s_data[32][32];
-
- if (row < M && col < N) {
- // Load from global memory into the logical smem coordinate (row=x, col=y),
- // which maps to the physical location (row=x, col=x^y).
- s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col];
- __syncthreads();
- int n_col = blockIdx.y * blockDim.y + threadIdx.x;
- int n_row = blockIdx.x * blockDim.x + threadIdx.y;
- if (n_row < N && n_col < M) {
- // Read from the logical smem coordinate (row=y, col=x),
- // which maps to the physical location (row=y, col=x^y).
- dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y];
- }
- }
-}
diff --git a/kernels/swizzle/mma_simple_swizzle.cu b/kernels/swizzle/mma_simple_swizzle.cu
new file mode 100644
index 00000000..01bb8939
--- /dev/null
+++ b/kernels/swizzle/mma_simple_swizzle.cu
@@ -0,0 +1,202 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+using namespace nvcuda;
+
+#define WARP_SIZE 32
+#define DEVICE_INLINE __device__ inline
+#define HOST_DEVICE_INLINE __device__ __host__ inline
+#define INT4(value) (reinterpret_cast<int4*>(&(value))[0])
+#define FLOAT4(value) (reinterpret_cast<float4*>(&(value))[0])
+#define HALF2(value) (reinterpret_cast<half2*>(&(value))[0])
+#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0])
+#define LDST32BITS(value) (reinterpret_cast<half2*>(&(value))[0])
+#define LDST64BITS(value) (reinterpret_cast<float2*>(&(value))[0])
+#define LDST128BITS(value) (reinterpret_cast<float4*>(&(value))[0])
+#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::)
+#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::)
+#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n))
+// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes.
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
+
+HOST_DEVICE_INLINE
+int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
+
+// i: row index; j: col index
+__device__ __host__ __forceinline__ int swizzle_j(int i, int j) {
+ // >>> sw(0,0),sw(0,8),sw(1,0),sw(1,8),sw(2,0),sw(2,8),sw(3,0),sw(3,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(4,0),sw(4,8),sw(5,0),sw(5,8),sw(6,0),sw(6,8),sw(7,0),sw(7,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ // >>> sw(8,0),sw(8,8),sw(9,0),sw(9,8),sw(10,0),sw(10,8),sw(11,0),sw(11,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(12,0),sw(12,8),sw(13,0),sw(13,8),sw(14,0),sw(14,8),sw(15,0),sw(15,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ return ((int(j / 8) ^ int(i / 4)) % 2) * 8;
+}
+
+
+template <const int MMA_M, const int MMA_N, const int MMA_K>
+__global__ void mma_simple_swizzle_kernel(
+ half* A, half* B, half* C, int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M; // 16
+ constexpr int BN = MMA_N; // 8
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[MMA_M][MMA_K]; // 16x16
+ __shared__ half s_b[MMA_K][MMA_N]; // 16x8
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+
+ // s_a[16][16]: each row has 16 halfs, each thread loads 8 -> 2 threads per row; 16 rows need 2x16=32 threads.
+ const int load_smem_a_m = tid / 2; // row 0~15
+ const int load_smem_a_k = (tid % 2) * 8; // col 0,8
+ // s_b[16][8]: each row has 8 halfs, one thread loads a whole row; 16 rows need 16 threads, so only half the warp loads.
+ const int load_smem_b_k = tid; // row 0~31, but only use 0~15
+ const int load_smem_b_n = 0; // col 0
+ const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
+ const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n
+ if (load_gmem_a_m >= M || load_gmem_b_n >= N) return;
+
+ uint32_t RC[2] = {0, 0};
+
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem_a -> smem_a
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ // LDST128BITS(A[load_gmem_a_addr]));
+ LDST128BITS(s_a[load_smem_a_m][swizzle_j(
+ load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr]));
+
+ // gmem_b -> smem_b
+ if (lane_id < MMA_K) {
+ int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ }
+ __syncthreads();
+ if (tid == 0) {
+ printf("\n");
+ for (int i = 0; i < MMA_M; i++) {
+ for (int j = 0; j < MMA_K; j++) {
+ printf("A[%2d][%2d]=%4d, ", i, j, __half2int_rz(s_a[i][j]));
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ printf("\n");
+ for (int i = 0; i < MMA_K; i++) {
+ for (int j = 0; j < MMA_N; j++) {
+ printf("B[%2d][%2d]=%4d, ", i, j, __half2int_rz(s_b[i][j]));
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+
+ uint32_t RA[4];
+ uint32_t RB[2];
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]
+ // uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ // &s_a[lane_id % 16][(lane_id / 16) * 8]);
+ uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_id % 16][swizzle_j(lane_id % 16, (lane_id / 16) * 8)]);
+ LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);
+ uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_id % 16][0]);
+ LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);
+
+ HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);
+
+ __syncthreads();
+ }
+
+ // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ int store_lane_gmem_c_m = by * BM + lane_id / 4;
+ int store_lane_gmem_c_n = bx * BN + (lane_id % 4) * 2;
+ int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
+ int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
+ LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[0]);
+ LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[1]);
+}
+
+int main(int argc, char *argv[]) {
+ int M = 16;
+ int N = 8;
+ int K = 16;
+ if (argc > 1) M = std::stoi(argv[1]);
+ if (argc > 2) N = std::stoi(argv[2]);
+ if (argc > 3) K = std::stoi(argv[3]);
+
+ size_t size_a = M * K * sizeof(half);
+ size_t size_b = K * N * sizeof(half);
+ size_t size_c = M * N * sizeof(half);
+
+ half *h_a, *h_b, *h_c;
+ half *d_a, *d_b, *d_c;
+ h_a = (half *)malloc(size_a);
+ h_b = (half *)malloc(size_b);
+ h_c = (half *)malloc(size_c);
+
+ cudaMalloc(&d_a, size_a);
+ cudaMalloc(&d_b, size_b);
+ cudaMalloc(&d_c, size_c);
+
+ for (int i = 0; i < M * K; i++)
+ h_a[i] = __float2half((float)i); // 0~255 16x16=256
+ for (int i = 0; i < K * N; i++)
+ h_b[i] = __float2half((float)i); // 0~127 16x8=128
+
+ cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice);
+
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ dim3 block(WARP_SIZE);
+ dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M));
+
+ mma_simple_swizzle_kernel<
+ MMA_M, MMA_N, MMA_K><<<grid, block>>>(
+ d_a, d_b, d_c, M, N, K
+ );
+ cudaDeviceSynchronize(); // wait for the kernel and its device-side printf to complete
+ cudaFree(d_a);
+ cudaFree(d_b);
+ cudaFree(d_c);
+ free(h_a);
+ free(h_b);
+ free(h_c);
+
+ return 0;
+}
diff --git a/kernels/swizzle/print_swizzle_layout.py b/kernels/swizzle/print_swizzle_layout.py
new file mode 100644
index 00000000..dcb9d48e
--- /dev/null
+++ b/kernels/swizzle/print_swizzle_layout.py
@@ -0,0 +1,46 @@
+import argparse
+
+
+def pretty_print_line(m: str = "", sep: str = "-", width: int = 130):
+ res_len = width - len(m)
+ left_len = int(res_len / 2)
+ right_len = res_len - left_len
+ pretty_line = sep * left_len + m + sep * right_len
+ print(pretty_line)
+
+
+def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8):
+ # i: row index; j: col index.
+ return ((int(j / step) ^ int(i / 4)) % int(col_stride / step)) * step
+
+
+def print_swizzle_layout(rows: int = 16, col_stride: int = 64, step: int = 8):
+ str_len = 0
+ for i in range(rows):
+ layout = tuple(swizzle_permuted_j(i, j, col_stride, step)
+ for j in range(0, col_stride, step))
+ layout_str = (f"| row {i:<2} | {layout} |")
+ str_len = len(layout_str)
+ if (i == 0):
+ print("-" * str_len)
+ pretty_print_line(f"swizzle layout", width=str_len)
+ pretty_print_line(f"col 0~{col_stride}, step {step}", width=str_len)
+ print("-" * str_len)
+ print(layout_str)
+ if ((i + 1) % 4 == 0 and i != (rows - 1)):
+ print("-" * str_len)
+ print("-" * str_len)
+
+
+def get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--col_stride", "--col", type=int, default=64)
+ parser.add_argument("--step", type=int, default=8)
+ parser.add_argument("--rows", type=int, default=16)
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = get_args()
+ print_swizzle_layout(args.rows, args.col_stride, args.step)
+