From bdd361a2e36c66c2112f75ab84607fc48e804b58 Mon Sep 17 00:00:00 2001 From: DefTruth <31974251+DefTruth@users.noreply.github.com> Date: Wed, 25 Dec 2024 13:46:35 +0800 Subject: [PATCH] =?UTF-8?q?[FA2]=20fa2/hgemm=20manually=20smem=20swizzle?= =?UTF-8?q?=F0=9F=8E=89=20(#185)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update flash_attn_mma.py * Create makefile * Create README.md * Update and rename matrix_trans_swizzle.cu to mat_trans_swizzle.cu * Update hgemm_mma_swizzle.cu * Update mat_trans_swizzle.cu * Update and rename flash_attn_mma_swizzle_qkv.cu to flash_attn_mma_share_kv_swizzle.cu * Create flash_attn_mma_share_qkv_swizzle.cu * Create flash_attn_mma_split_q_swizzle.cu * Create flash_attn_mma_split_kv_swizzle.cu * Create flash_attn_mma_tiling_qk_swizzle.cu * Create flash_attn_mma_tiling_qkv_swizzle.cu * Update flash_attn_mma_share_qkv_swizzle.cu * Update flash_attn_mma_split_kv_swizzle.cu * Update flash_attn_mma_split_q_swizzle.cu * Update flash_attn_mma_tiling_qk_swizzle.cu * Update flash_attn_mma_tiling_qkv_swizzle.cu * Update README.md * Update hgemm_mma_swizzle.cu * Update makefile * Update README.md * Update README.md * Update mat_trans_swizzle.cu * Update makefile * Update hgemm_mma_swizzle.cu * Update hgemm_mma_swizzle.cu * Update README.md * Update hgemm_mma_stage.cu * Update hgemm_mma.cu * Update makefile * Update utils.h * Create mma_simple_swizzle.cu * Update makefile * Update mma_simple_swizzle.cu * Update hgemm_mma_swizzle.cu * Update makefile * Update utils.py * Update makefile * Create hgemm_mma_stage_swizzle.cu * Update hgemm.py * Update hgemm.cc * Update mat_trans_swizzle.cu * Update flash_attn_mma_tiling_qk_swizzle.cu * Update flash_attn.cc * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma.py * Update flash_attn_mma_tiling_qk_swizzle.cu * Update flash_attn_mma_tiling_qk_swizzle.cu * Update flash_attn_mma_share_kv_swizzle.cu * Update README.md * Update README.md * Create print_swizzle_layout.py * Update flash_attn_mma_tiling_qk_swizzle.cu * Update flash_attn_mma_share_kv_swizzle.cu * Update README.md * Update hgemm_mma_stage_swizzle.cu * Update README.md * Update README.md * Update README.md * Update mma_simple_swizzle.cu * Create print_swizzle_layout.py * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md * Update README.md --- README.md | 64 +- kernels/flash-attn/README.md | 2 +- kernels/flash-attn/flash_attn_mma.py | 15 +- .../mma/flash_attn_mma_share_kv_swizzle.cu | 934 ++++++++++++++ ...cu => flash_attn_mma_share_qkv_swizzle.cu} | 4 +- .../mma/flash_attn_mma_split_kv_swizzle.cu | 2 + .../mma/flash_attn_mma_split_q_swizzle.cu | 2 + .../mma/flash_attn_mma_tiling_qk_swizzle.cu | 1085 +++++++++++++++++ .../mma/flash_attn_mma_tiling_qkv_swizzle.cu | 2 + kernels/flash-attn/pybind/flash_attn.cc | 8 + .../flash-attn/tools/print_swizzle_layout.py | 46 + kernels/hgemm/README.md | 5 +- kernels/hgemm/hgemm.py | 10 +- kernels/hgemm/makefile | 7 + kernels/hgemm/mma/hgemm_mma.cu | 7 +- kernels/hgemm/mma/hgemm_mma_stage.cu | 100 +- kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu | 853 +++++++++++++ kernels/hgemm/pybind/hgemm.cc | 4 + kernels/hgemm/tools/utils.py | 1 + kernels/hgemm/utils/utils.h | 8 +- kernels/swizzle/README.md | 125 ++ kernels/swizzle/hgemm_mma_swizzle.cu | 427 ++++++- kernels/swizzle/makefile | 17 + 
kernels/swizzle/mat_trans_swizzle.cu | 108 ++ kernels/swizzle/matrix_trans_swizzle.cu | 35 - kernels/swizzle/mma_simple_swizzle.cu | 202 +++ kernels/swizzle/print_swizzle_layout.py | 46 + 27 files changed, 3940 insertions(+), 179 deletions(-) create mode 100644 kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu rename kernels/flash-attn/mma/{flash_attn_mma_swizzle_qkv.cu => flash_attn_mma_share_qkv_swizzle.cu} (98%) create mode 100644 kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu create mode 100644 kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu create mode 100644 kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu create mode 100644 kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu create mode 100644 kernels/flash-attn/tools/print_swizzle_layout.py create mode 100644 kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu create mode 100644 kernels/swizzle/README.md create mode 100644 kernels/swizzle/makefile create mode 100644 kernels/swizzle/mat_trans_swizzle.cu delete mode 100644 kernels/swizzle/matrix_trans_swizzle.cu create mode 100644 kernels/swizzle/mma_simple_swizzle.cu create mode 100644 kernels/swizzle/print_swizzle_layout.py diff --git a/README.md b/README.md index ed849bcf..09f48c0e 100644 --- a/README.md +++ b/README.md @@ -35,9 +35,9 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d |✔️|✔️|✔️|✔️| |Copy Async|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages (2/3/4)| |✔️|✔️|✔️|✔️| -|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe)| +|Reg Double Buffers|Block Swizzle|Warp Swizzle|SMEM Swizzle (CuTe/MMA)| |✔️|✔️|✔️|✔️| -|Collective Store (Warp Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32| +|Collective Store (Shfl)|Row Major (NN)|Col Major (TN)| SGEMM FP32/TF32| |✔️|✔️|✔️|✔️| @@ -48,7 +48,7 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh |Tensor Cores|Loop over Seqlen/Headdim |Tile Block (Br, Bc)|MMA (m16n8k16)| |:---:|:---:|:---:|:---:| |✔️|✔️|✔️|✔️| -|Pack LDST (128 bits)|SMEM Padding|Copy Async|Tile MMA (More Threads)| +|Pack LDST (128 bits)|SMEM **Swizzle**/Padding |Copy Async|Tile MMA (More Threads)| |✔️|✔️|✔️|✔️| |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Shfl)|**Split KV/Q**| |✔️|✔️|✔️|✔️| @@ -160,7 +160,6 @@ The kernels listed here will guide you through a step-by-step progression, rangi |📖 CUDA Kernel| 📖 Elem DType| 📖 Acc DType| 📖 Docs | 📖 Level | |:---|:---|:---|:---|:---| -| ✔️ [nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️| | ✔️ [elementwise_f32](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️| | ✔️ [elementwise_f32x4](./kernels/elementwise/elementwise.cu)|f32|/|[link](./kernels/elementwise/)|⭐️| | ✔️ [elementwise_f16](./kernels/elementwise/elementwise.cu)|f16|/|[link](./kernels/elementwise/)|⭐️| @@ -205,27 +204,27 @@ The kernels listed here will guide you through a step-by-step progression, rangi | ✔️ [mat_trans_f32_diagonal2d](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️| | ✔️ [mat_trans_f32x4_col2row{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️| | ✔️ [mat_trans_f32x4_row2col{2d}](./kernels/mat-transpose/mat_transpose.cu)|f32|/|[link](./kernels/mat-transpose/)|⭐️⭐️| -| ✔️ [warp_reduce_[all]](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ 
[reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️| -| ✔️ [reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [warp_reduce_{all}](./kernels/reduce/block_all_reduce.cu)|all|all|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f32_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f32x4_f32](./kernels/reduce/block_all_reduce.cu)|f32|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16x2_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16x2_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16x8_pack_f16](./kernels/reduce/block_all_reduce.cu)|f16|f16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_f16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|f16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_bf16_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_bf16_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_bf16x2_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ 
[block_all_reduce_bf16x2_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_bf16x8_pack_bf16](./kernels/reduce/block_all_reduce.cu)|bf16|bf16|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_bf16x8_pack_f32](./kernels/reduce/block_all_reduce.cu)|bf16|f32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_fp8_e4m3_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️| +| ✔️ [block_all_reduce_fp8_e5m2_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️| +| ✔️ [block_all_reduce_fp8_e4m3x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e4m3|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️| +| ✔️ [block_all_reduce_fp8_e5m2x16_pack_f16](./kernels/reduce/block_all_reduce.cu)|fp8_e5m2|f16|[link](./kernels/reduce/)|⭐️⭐️⭐️| +| ✔️ [block_all_reduce_i8_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️| +| ✔️ [block_all_reduce_i8x16_pack_i32](./kernels/reduce/block_all_reduce.cu)|i8|i32|[link](./kernels/reduce/)|⭐️⭐️| | ✔️ [dot_product_f32](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️| | ✔️ [dot_product_f32x4](./kernels/dot-product/dot_product.cu)|f32|f32|[link](./kernels/dot-product/)|⭐️⭐️| | ✔️ [dot_product_f16_f32](./kernels/dot-product/dot_product.cu)|f16|f32|[link](./kernels/dot-product/)|⭐️⭐️| @@ -262,7 +261,8 @@ The kernels listed here will guide you through a step-by-step progression, rangi | ✔️ [rms_norm_f16x8_pack_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️| | ✔️ [rms_norm_f16_f32](./kernels/rms-norm/rms_norm.cu)|f16|f32|[link](./kernels/rms-norm/)|⭐️⭐️| | ✔️ [nms_f32](./kernels/nms/nms.cu)|f32|/|[link](./kernels/nms)|⭐️⭐️| -| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️| +| ✔️ [notes v1(deprecated)](./kernels/notes-v1.cu)|f32|f32|/|⭐️⭐️| +| ✔️ [How to profile with nsys/ncu(timeline/ptx/sass)](./kernels/nvidia-nsight/)|/|/|[link](./kernels/nvidia-nsight/)|⭐️⭐️| ### 📚 Hard ⭐⭐⭐️ ([©️back👆🏻](#cuda-kernel)) @@ -284,7 +284,7 @@ The kernels listed here will guide you through a step-by-step progression, rangi | ✔️ [sgemm_t_8x8_sliced_k16...dbuf](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| | ✔️ [sgemm_t_8x8_sliced_k16...async](./kernels/sgemm/sgemm_async.cu)|f32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| | ✔️ [sgemm_wmma_m16n16k8...stages*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| -| ✔️ [sgemm_wmma_m16n16k8...swizzle*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| +| ✔️ [sgemm_wmma_m16n16k8...swizzle{+block}*](./kernels/sgemm/sgemm_wmma_tf32_stage.cu)|tf32|f32|[link](./kernels/sgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_naive_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️| | ✔️ [hgemm_sliced_k_f16](./kernels/hgemm/naive/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_t_8x8_sliced_k_f16x4](./kernels/hgemm/hgemm.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| @@ -299,12 +299,13 @@ The kernels listed here will guide you through a step-by-step progression, rangi | ✔️ [hgemm_wmma_m16n16k16...dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m32n8k16....dbuf*](./kernels/hgemm/wmma/hgemm_wmma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_wmma_m16n16k16...stages*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| -| ✔️ 
[hgemm_wmma_m16n16k16...swizzle*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_wmma_m16n16k16...swizzle{+block}*](./kernels/hgemm/wmma/hgemm_wmma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_mma_m16n8k16...naive*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_mma_m16n8k16...mma2x4*](./kernels/hgemm/mma/hgemm_mma.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_mma_m16n8k16...stages*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| -| ✔️ [hgemm_mma_m16n8k16...swizzle*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| -| ✔️ [hgemm_mma_stages{swizzle}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...swizzle{+block}*](./kernels/hgemm/mma/hgemm_mma_stage.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_m16n8k16...swizzle{+smem}*](./kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| +| ✔️ [hgemm_mma_stages_swizzle{+smem}...cute*](./kernels/hgemm/cutlass/hgemm_mma_stage_tn_cute.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️⭐️| | ✔️ [hgemm_mma_cublas*](./kernels/hgemm/cublas/hgemm_cublas.cu)|f16|f16|[link](./kernels/hgemm/)|⭐️⭐️| ### 📚 Hard+ ⭐️⭐️⭐️⭐️ & Hard++ ⭐️⭐️⭐️⭐️⭐️ ([©️back👆🏻](#cuda-kernel)) @@ -318,11 +319,14 @@ The kernels listed here will guide you through a step-by-step progression, rangi | ✔️ [flash_attn_mma_stages...shared_kv*](./kernels/flash-attn/mma/flash_attn_mma_share_kv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ✔️ [flash_attn_mma_stages...shared_qkv*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ✔️ [flash_attn_mma_stages...tiling_qk*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| +| ✔️ [flash_attn_mma...tiling_qk_swizzle{+smem}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu)|f16|f16|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ? [flash_attn_mma_stages_split_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️| | ? [flash_attn_mma_stages_split_q{f32}*](./kernels/flash-attn/mma/flash_attn_mma_split_q_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️| | ? [flash_attn_mma_stages...shared_kv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_kv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ? [flash_attn_mma_stages...shared_qkv{f32}*](./kernels/flash-attn/mma/flash_attn_mma_share_qkv_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| | ? [flash_attn_mma_stages...tiling_qk{f32}*](./kernels/flash-attn/mma/flash_attn_mma_tiling_qk_acc_f32.cu)|f16|f32|[link](./kernels/flash-attn)|⭐️⭐️⭐️⭐️⭐️| +| ✔️ [How to implement MMA smem swizzle*](./kernels/swizzle/mma_simple_swizzle.cu)|f16|f16|[link](./kernels/swizzle)|⭐️⭐️⭐️⭐️| + ## 📖 博客目录
diff --git a/kernels/flash-attn/README.md b/kernels/flash-attn/README.md index 20c3ee42..171849ff 100644 --- a/kernels/flash-attn/README.md +++ b/kernels/flash-attn/README.md @@ -5,7 +5,7 @@ |Tensor Cores|Loop over Seqlen/HeadDim |Tile Block (Br, Bc)|MMA (m16n8k16)| |:---:|:---:|:---:|:---:| |✔️|✔️|✔️|✔️| -|Pack LDST (pack 128 bits)|SMEM Padding|Copy Async (cp.async.cg/ca)|Tile MMA (More Threads) +|Pack LDST (pack 128 bits)|SMEM **Swizzle**/Padding |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads) |✔️|✔️|✔️|✔️| |Tile Warp (More Values)|Multi Stages (1/2)|Collective Store (Warp Shfl & Reg Reuse)|**Split KV/Q**| |✔️|✔️|✔️|✔️| diff --git a/kernels/flash-attn/flash_attn_mma.py b/kernels/flash-attn/flash_attn_mma.py index ebf3ef71..369cd9ec 100644 --- a/kernels/flash-attn/flash_attn_mma.py +++ b/kernels/flash-attn/flash_attn_mma.py @@ -83,6 +83,7 @@ def get_args(): './mma/flash_attn_mma_share_kv.cu', './mma/flash_attn_mma_share_qkv.cu', './mma/flash_attn_mma_tiling_qk.cu', + './mma/flash_attn_mma_tiling_qk_swizzle.cu', './pybind/flash_attn.cc' ], extra_cuda_cflags=[ @@ -218,11 +219,11 @@ def run_benchmark(perf_func: callable, else: improve = 0 MAX_TFLOPS = TFLOPS - print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, " + print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, " f"TFLOPS:{TFLOPS:<6.2f}(+{improve:.2f}%)") else: - if not only_show_improved or "flash" in tag: - print(f"{out_info:>32}: {out_val}, time:{mean_time:<.6f}ms, " + if not only_show_improved or "flash" in tag or "sdpa" in tag: + print(f"{out_info:>38}: {out_val}, time:{mean_time:<.6f}ms, " f"TFLOPS:{TFLOPS:<6.2f}") if show_matrix: print(out) @@ -296,7 +297,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor, diff = torch.abs(out_flash_or_sdpa.float() - out_mma.float()) all_close = str(torch.allclose(out_flash_or_sdpa.float(), out_mma.float(), atol=1e-2)) pretty_print_line( - f"{true_tag} vs {tag:<18}, all close: {all_close:<6}, " + f"{true_tag} vs {tag:<22}, all close: {all_close:<6}, " f"max diff: {diff.max().item():.6f}, min diff: {diff.min().item():.6f}, " f"mean diff: {diff.mean().item():.6f}" ) @@ -340,6 +341,8 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor, out_mma_share_kv2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_shared_kv, q, k, v, "mma(split-q+share-kv+stage2)", o, stages=2) out_mma_tiling_qk1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage1)", o, stages=1) out_mma_tiling_qk2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk, q, k, v, "mma(split-q+tiling-qk+stage2)", o, stages=2) + out_mma_tiling_qk_sw1, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage1)", o, stages=1) + out_mma_tiling_qk_sw2, _ = run_benchmark(lib.flash_attn_mma_stages_split_q_tiling_qk_swizzle, q, k, v, "mma(split-q+tiling-qk+swizzle+stage2)", o, stages=2) if D <= 256: out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)") if args.run_torch_sdpa: @@ -360,9 +363,13 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor, check_all_close(out_flash, out_mma_share_qkv2, "out_mma_share_qkv2", args.check_all) check_all_close(out_flash, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all) check_all_close(out_flash, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all) + check_all_close(out_flash, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all) + check_all_close(out_flash, 
out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all) pretty_print_line() elif args.run_torch_sdpa: pretty_print_line() check_all_close(out_sdpa, out_mma_tiling_qk1, "out_mma_tiling_qk1", args.check_all, False) check_all_close(out_sdpa, out_mma_tiling_qk2, "out_mma_tiling_qk2", args.check_all, False) + check_all_close(out_sdpa, out_mma_tiling_qk_sw1, "out_mma_tiling_qk_sw1", args.check_all, False) + check_all_close(out_sdpa, out_mma_tiling_qk_sw2, "out_mma_tiling_qk_sw2", args.check_all, False) pretty_print_line() diff --git a/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu new file mode 100644 index 00000000..10730cb9 --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_share_kv_swizzle.cu @@ -0,0 +1,934 @@ +#include "utils.h" + +// Write FlashAttention-2 from scratch using Tensor Cores with MMA PTX instruction. +// The input is Q,K,V, 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. +// The output is O, a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim]. + +// The FlashAttention-2 algorithm is described in the following paper: +// https://arxiv.org/pdf/2307.08691 + +// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d] +// each block processes Q_tile with shape [Br,d] and full K,V with shape [N,d] + +// Split Q across MMA(Warps) and keep access KV for all MMA(Warps), +// in order to reduce the comm between warps via smem and warp shuffle. + +// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps +// | 64x64 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x8) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x8) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x8) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + +// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps +// | 128x128 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x16) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x16) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x16) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x16) | +// | warp_QP 4 | MMA 4 ... MMA 4 (x16) | +// | warp_QP 5 | MMA 5 ... MMA 5 (x16) | +// | warp_QP 6 | MMA 6 ... MMA 6 (x16) | +// | warp_QP 7 | MMA 7 ... MMA 7 (x16) | + +// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps +// | 128x64 | warp_KV 0 | +// | warp_QP 0 | MMA 0 ... MMA 0 (x8) | +// | warp_QP 1 | MMA 1 ... MMA 1 (x8) | +// | warp_QP 2 | MMA 2 ... MMA 2 (x8) | +// | warp_QP 3 | MMA 3 ... MMA 3 (x8) | +// | warp_QP 4 | MMA 4 ... MMA 4 (x8) | +// | warp_QP 5 | MMA 5 ... MMA 5 (x8) | +// | warp_QP 6 | MMA 6 ... MMA 6 (x8) | +// | warp_QP 7 | MMA 7 ... MMA 7 (x8) | + +// Manually apply SMEM swizzling instead of padding in +// Split-Q kernels to reduce bank conflicts. + +// i: row index; j: col index. +// e.g kColStride = 64, kStep = 8 -> load 8 half as 128 bits memory issue. 
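// As a quick sanity check, the permutation implemented below can be reproduced
// on the host with a tiny sketch (illustration only, in the same spirit as the
// print_swizzle_layout.py helpers this patch adds; swizzle_j_ref/col_stride/step
// are just names for the sketch, with the 64/8 defaults from the example above):
//
//   #include <cstdio>
//   static int swizzle_j_ref(int i, int j, int col_stride = 64, int step = 8) {
//     // same mapping as below: ((j/step) ^ (i/4)) % (col_stride/step) * step
//     return (((j / step) ^ (i / 4)) % (col_stride / step)) * step;
//   }
//   int main() {
//     for (int i = 0; i < 16; ++i) {        // rows of one 16x64 half tile
//       for (int j = 0; j < 64; j += 8) {   // one 8-half (128-bit) chunk per step
//         printf("%2d ", swizzle_j_ref(i, j));
//       }
//       printf("\n");  // rows 0~3: 0 8 16 ..., rows 4~7: 8 0 24 16 ..., etc.
//     }
//     return 0;
//   }
//
// Every group of 4 consecutive rows orders its 128-bit chunks differently, so
// warp-wide smem accesses spread across banks without spending extra shared
// memory on kPad columns, which is the point of swizzling instead of padding.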
+template +static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) { + // ------------------------------------------- + // --------------swizzle layout--------------- + // -------------col 0~64, step 8-------------- + // ------------------------------------------- + // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) | + // ------------------------------------------- + // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) | + // ------------------------------------------- + // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) | + // ------------------------------------------- + // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) | + // ------------------------------------------- + // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep; + static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4."); + static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep."); + if constexpr (kStep == 8) { + return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3; + } else { + static_assert(kStep == 4); + return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2; + } +} + +template +static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) { + return swizzle_permuted_j(i, j); +} + +template +static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) { + return swizzle_permuted_j(i, j); +} + +template +static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) { + return swizzle_permuted_j(i, j); +} + +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N + const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... + const int kStage, + const int kPad + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_split_q_shared_kv_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen, + int QKV_head) { + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. 
+ static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 + static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T + static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V + static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T + // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc. + // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. + static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( + kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V + static_assert(kStage < 3 && kStage > 0); + static_assert(kPad >= 0 && kPad % 8 == 0); // 0,8,16 + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + static_assert(Br >= Bc); // for shared memory reuse. + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d] + const float scale = 1.0f / sqrt((float) kHeadDim); + + // grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head), (x,y,z) + const int QKV_batch_id = blockIdx.y / QKV_head; // Batch size + const int QKV_head_id = blockIdx.y % QKV_head; // Head num + const int Q_tile_id = blockIdx.x; // Q tile_id, range [0, Tr] + const int O_tile_id = Q_tile_id; // O tile_id, same as Q. + const int tid = threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_QP = warp_id; // 0,1,2,3 or 0~7 + const int warp_KV = 0; // 0 + // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps + // | 128x128 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x16) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x16) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x16) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x16) | + // | warp_QP 4 | MMA 4 ... MMA 4 (x16) | + // | warp_QP 5 | MMA 5 ... MMA 5 (x16) | + // | warp_QP 6 | MMA 6 ... MMA 6 (x16) | + // | warp_QP 7 | MMA 7 ... MMA 7 (x16) | + const int Q_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] + const int K_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] + const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] + const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] + + // Mapping Q gmem -> tid -> smem, Q[Br,d]=[64,64 or 128], 128 threads. + int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 + int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kHeadDim / (kNumThreads / Br)); // (tid % 2) * 32, 0,32,... + // Mapping K gmem -> tid -> smem, K[Bc,d]=[64 or 128,64], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... + // Mapping V gmem -> tid -> smem, V[Bc,d]=[64,64 or 128], 128 threads. 
+ int load_smem_V_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_V_d = (tid % (kNumThreads / Bc)) * (kHeadDim / (kNumThreads / Bc)); // (tid % 2) * 32, 0,32,... + // global Q row of current head for tile [Br,d] per block. + int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br; + if (load_gmem_Q_Br >= QKV_seqlen) return; + // KV tile gmem load index starts from 0 and increments with + // each iteration as we loop over seqlen. + int load_gmem_K_Bc_offset = 0; + int load_gmem_V_Bc_offset = 0; + + // Shared memory for Q,K,V, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. + extern __shared__ half smem[]; + constexpr int Q_tile_size = Br * (kHeadDim + kPad); // 64*64=4096, ~8192 bytes=8M + constexpr int KV_tile_size = Bc * (kHeadDim + kPad); // K[Bc,d] + half* Q_tile_smem = smem; // 8M/16M + half* K_tile_smem = Q_tile_smem + Q_tile_size; // 8M/16M + half* V_tile_smem = K_tile_smem; // KV shared the same smem + // NOTE: KV may shared same smem to reduce smem usage for kStage 1 + // stage 1, w shared KV smem, Br=Bc=64, d=64: 8M+(8M) =16M, +Pad(2M) = 18M + // stage 1, w shared KV smem, Br=Bc=128, d=64: 16M+16M =32M, +Pad(4M) = 36M + // stage 1, w shared KV smem, Br=Bc=64, d=128: 16M+16M =32M, +Pad(2M) = 36M + // stage 1, w shared KV smem, Br=Bc=64, d=256: 32M+32M =64M, +Pad(2M) = 66M + // stage 1, w shared KV smem, Br=64,Bc=32, d=256: 32M+16M =48M, +Pad(2M) = 50M + // stage 1, w shared KV smem, Br=128,Bc=16,d=256: 64M+16M =80M, +Pad(2M) = 82M + + uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); + uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); + uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem); + + // --------------------- Registers/SMEM for thread block ------------------------- + // block m_old, l_old, store in lane, use float to keep precision. + float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2] + float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_block_row_max_old, -INFINITY); + fill_2D_regs(lane_block_row_sum_old, 0.0f); + + // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- + // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. + // Allocate R_Q[(kHeadDim/kMmaAtomK)<=8][1][4], e.g R_Q[4][1][4] 16 regs. + // By the way, we have to reduce R_Z to 0 regs and reuse R_Q for collective store. + // Then we can load Q from smem only once and reuse it for + // processes. This will reduce large io-access for Q smem while N is large. + // FIXME(DefTruth): why can not get good performance for headdim >= 64 ? + // Will enable it untill I have figure out the performance issues. + constexpr bool kCanPrefetchQs2r = ((kHeadDim / kMmaAtomK) <= 8) && (kHeadDim < 64); + constexpr bool kCanPrefetchKVg2s = (kStage == 2); // whether prefetch KV g2s. + constexpr int kPrefetchKg2sSmemId = 0; // smem id for K g2s, 0. + constexpr int kPrefetchVg2sSmemId = kCanPrefetchKVg2s ? 1 : 0; // smem id for V g2s, 1. + constexpr int kNumPrefetchQs2r = (kCanPrefetchQs2r) ? (kHeadDim / kMmaAtomK) : 1; + uint32_t R_Q[kNumPrefetchQs2r][kWarpTileSeqLenQ][4]; // [4/8/1][1][4] + uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] + uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2] + // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] + // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. 
+ uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2] + // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. + uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. + uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + fill_3D_regs(R_S, 0); + fill_3D_regs(R_D, 0); + fill_3D_regs(R_O, 0); + + // load Q from gmem -> smem, only load once. + { + int load_gmem_Q_d = load_smem_Q_d; + int load_gmem_Q_addr = (Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = (smem_Q_base_ptr + ( + load_smem_Q_Br * (kHeadDim + kPad) + load_smem_Q_d) * sizeof(half)); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // : for K^T[d,seqlen] with K^T_tile[d,Bc] + // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] + #pragma unroll 1 + for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { + // TODO: process last tile_K_seqlen ? pad to multiple of 8. + + // Load K tile from gmem -> smem, always use smem part 0, send g2s + // memory issues before Prefetch Q s2r. + if constexpr (kCanPrefetchKVg2s) { + if (tile_K_seqlen == 0) { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + + // Now, we have to wait curr K tile ready for Q@K^T MMA. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + // : Load V tile async from gmem -> smem 1, before Q@K^T + { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } else { + load_gmem_K_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Now, we have to wait curr K tile ready for Q@K^T MMA. 
+ CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // : Load Q tile from smem -> regs, before Q@K^T. + if constexpr (kCanPrefetchQs2r) { + // Wait Q ready and let K copy async, then prefetch Q from smem -> regs. + // NOTE: we only need to load Q once from smem -> regs, and then reuse it. + if (tile_K_seqlen == 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + lane_smem_Q_ptr); // now, R_Q[1/2/4/8][1][4] + } + } + __syncthreads(); // wait all warps ready. + } // end if tile_K_seqlen == 0 + } // end if kCanPrefetchQs2r + + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] + // Matmul with NT layout, Q row major, K^T col major. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d] + // + fill_3D_regs(R_S, 0); + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // smem -> reg, load m16k16 smem Q, offset d according tile_K_d. + // ldmatrix.x4 for Q_tile_smem. + if constexpr (!kCanPrefetchQs2r) { + // load Q from smem -> regs in each loop w/o prefetch Q s2r. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = tile_K_d * kMmaAtomK + (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (lane_smem_Q_Br * (kHeadDim + kPad) + + lane_smem_Q_d) * sizeof(half) + ); + LDMATRIX_X4(R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + lane_smem_Q_ptr); // now, R_Q[1][1][4] + } + } + + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. + // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d]. + // K[Bc,d] with row major means K^T[d,Bc] in col major. 
+ int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7 + int lane_smem_K_d = tile_K_d * kMmaAtomK + ((lane_id / 8) % 2) * 8; // 0,8 + uint32_t lane_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + lane_smem_K_Bc * (kHeadDim + kPad) + + lane_smem_K_d) * sizeof(half) + ); + LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + } // end for kWarpTileSeqLenK + + if constexpr (kCanPrefetchQs2r) { + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[tile_K_d][i][0], R_Q[tile_K_d][i][1], + R_Q[tile_K_d][i][2], R_Q[tile_K_d][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } else { + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[0][i][0], R_Q[0][i][1], R_Q[0][i][2], R_Q[0][i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + } + } // end loop over d, S=Q@K^T + __syncthreads(); + + // : If kCanPrefetchKVg2s is not enable, + // we will load V g2s here, before rowmax and rowsum. + if constexpr (!kCanPrefetchKVg2s) { + load_gmem_V_Bc_offset = tile_K_seqlen * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_V_Bc = load_gmem_V_Bc_offset + load_smem_V_Bc; + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + load_smem_V_Bc * (kHeadDim + kPad) + + load_smem_V_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + + // : load next K tile from gmem -> smem 0, before P@V. + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + load_gmem_K_Bc_offset = (tile_K_seqlen + 1) * Bc; // e.g (0~3)*64=(0,64,128,192,...) + int load_gmem_K_Bc = load_gmem_K_Bc_offset + load_smem_K_Bc; + int load_gmem_K_d = load_smem_K_d; + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (kPrefetchKg2sSmemId * KV_tile_size + + load_smem_K_Bc * (kHeadDim + kPad) + + load_smem_K_d) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } + + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + + // Online safe softmax, warp/block reduce max/sum, row wise + float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2] + float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_row_max_new, -INFINITY); + fill_2D_regs(lane_row_sum_new, 0.0f); + + // Row max for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc. 
+ #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for C. (m16n8k16) + // Row\Col 0 1 2 3 4 5 6 7 + // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1} + // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1} + // 2 ... + // ... + // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1} + // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3} + // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3} + // 10 ... + // ... + // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3} + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // This should be the row max after S = (Q @ K^T) / sqrt(d) + float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale; + float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale; + lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0); + lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce max, warp_size = 4 + // Each thread contains the maximum of 2 rows of Br, + // and only the values of T0, T4, ..., T28 are used. + lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); + lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Use latest global row max without update. + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; + float block_row_max_new_0 = lane_row_max_new[i][0]; + // Br 1, row_id, 8~15, 24~31, 40~47, 56~63; + float block_row_max_new_1 = lane_row_max_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // Apply m_new = max(m_old, m_new) here. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z; + t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0)); + t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0)); + t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1)); + t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1)); + lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y); + lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y); + // Update R_S for P[Br,Bc] = Exp(S-m), point wise. + HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0); + HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce sum, warp_size = 4 + lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]); + lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]); + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Compute P[Br,Bc] @ V[Bc,d] = [Br,d] = [64, 64/128], partion Attention. 
+ // Here, we have to wait V ready before compute O = P @ V + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + CP_ASYNC_WAIT_GROUP(1); // we have send V & K g2s, wait V and let K async. + } else { + CP_ASYNC_WAIT_GROUP(0); // we have only send V g2s. + } + } else { + CP_ASYNC_WAIT_GROUP(0); + } + __syncthreads(); + + // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention. + // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major. + // Make sure to clear the states in R_O before MMA for P@V for each step. + + // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these + // registers for P(A) matrix directly ? How to do that ? + // according to the A matrix layout for MMA m16n8k16 instruction. + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for A matrix with .f16. + // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5} + // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5} + // 2 (dashed arrow pointing right) + // ... + // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5} + // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7} + // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7} + // 10 (dashed arrow pointing right) + // ... + // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + + // + fill_3D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans. + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N + int lane_smem_V_Bc = tile_V_Bc * kMmaAtomK + lane_id % 16; // 0~15; Bc, matmul K + int lane_smem_V_d = warp_smem_V_d; // 0 + uint32_t lane_smem_V_ptr = ( + smem_V_base_ptr + (kPrefetchVg2sSmemId * KV_tile_size + + lane_smem_V_Bc * (kHeadDim + kPad) + + lane_smem_V_d) * sizeof(half) + ); + LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V + } + + // For R_S[1][8][2], mapping the layout below of P matrix. + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs. + // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs. + // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs. + // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs. + int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6 + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... 
+ HMMA16816(R_O[i][j][0], R_O[i][j][1], + R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1], + R_V[j][0], R_V[j][1], + R_O[i][j][0], R_O[i][j][1]); + } + } + } // end for V Bc. + __syncthreads(); + + // Rescale O -> Update row sum Exp -> then, Update row max. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1 + // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper) + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63 + float block_row_max_new_0 = lane_row_max_new[i][0]; + float block_row_max_new_1 = lane_row_max_new[i][1]; + float block_row_sum_new_0 = lane_row_sum_new[i][0]; + float block_row_sum_new_1 = lane_row_sum_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // NOTE: max(-inf, val) = val. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + // Avoid inf value while using m_old for rescaling O. + block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 : + block_row_max_new_0); + block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 : + block_row_max_new_1); + + // rescale factor for O and l, exp(m_old - m) + float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0); + float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1); + // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old. + // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3} + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + // Note that the formula in the FA2 paper is incorrect; here, + // the inverse of the exp function should not be taken, as it + // would result in an error during rescaling, namely, you have + // use exp(m_old - m_new), not 1/(m_old - m_new). + // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V + t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x); + t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y); + t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x); + t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y); + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } // end for kWarpTileHeadDimV. + + // Now, we can update m, l after O has been scaled. + // 1. First, update block row sum Exp for each lane which + // need both m_new and m_old. + float block_row_sum_old_0 = lane_block_row_sum_old[i][0]; + float block_row_sum_old_1 = lane_block_row_sum_old[i][1]; + // Update l = exp(m_old - m_new) * l_old + row_sum(P). + lane_block_row_sum_old[i][0] = (__fmaf_rn( + rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0)); + lane_block_row_sum_old[i][1] = (__fmaf_rn( + rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1)); + // 2. Then, update block row max for each lane. 
+ lane_block_row_max_old[i][0] = block_row_max_new_0; + lane_block_row_max_old[i][1] = block_row_max_new_1; + } + + if constexpr (kCanPrefetchKVg2s) { + if ((tile_K_seqlen + 1) < Tc) { + // now, we have to wait next K tile ready in smem. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + } + + } // end loop over N + __syncthreads(); + + // Finaly, we still have to rescale O once more. + // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + // NOTE: Here, we choose to reuse R_O as final output + // in order to reduce regs usage. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); + float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]); + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x; + t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y; + t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x; + t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } + } + + // Store O(D): Write O[Br,d] from regs -> gmem, collective store + // with reg reuse & warp shuffle. need R_Z[2][4]. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8 + + if constexpr (kCanPrefetchQs2r && kNumPrefetchQs2r > 1) { + // reuse R_Q[4/8][1][4] for collective store. + R_Q[0][0][0] = R_D[i][j][0]; R_Q[1][0][0] = R_D[i][j][1]; // warp_size 4 + R_Q[0][0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Q[0][0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Q[0][0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Q[1][0][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Q[1][0][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Q[1][0][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Q[0][0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Q[1][0][0]); + } + } else { + // we have to use new R_Z regs for collective store. 
+ uint32_t R_Z[2][4]; + R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4 + R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. [Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]); + } + } // end if kCanPrefetchQs2r + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ +} + +// Launch kernel for flash_attn_mma_stages_split_q +template +void launch_flash_attn_mma_stages_split_q_shared_kv( + torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + // Now: fixed tile BrxBc=128x64 for d < 128, 128x16 for d >= 128 + // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size. + constexpr int kMmaAtomM = 16; + constexpr int kMmaAtomN = 8; + constexpr int kMmaAtomK = 16; + constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 8 : 8; + constexpr int kMmaTileSeqLenK = 1; + constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 8 : 8; + constexpr int kMmaTileHeadDimV = 1; + constexpr int kWarpTileSeqLenQ = 1; + constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 2; + constexpr int kWarpTileSeqLenP = 1; + constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // 8,16,32,.... + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + constexpr int kPad = 8; + + // static int kMaxSramPerBlock; + // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); + // Calculate SRAM size needed per block, Q,K/V smem size, KV shared the same smem. + constexpr int KV_tile_size = (Bc * (kHeadDim + kPad)); + const int smem_max_size = ((Br * (kHeadDim + kPad)) + + (kStage * KV_tile_size)) * sizeof(half); + + const int QKV_batch = Q.size(0); + const int QKV_head = Q.size(1); + const int QKV_seqlen = Q.size(2); // QKV_seqlen + assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc) + + // TODO: How to apply block swizzle to improve L2 Cache hit rate? + // NOTE: reorder (B,H,Tr) -> (Tr,B*H) seems can improve L2 Cache hit rate. + // This might be because SM schedules blocks starting from the x-dimension. + // Placing Tr at the forefront ensures that identical KV pairs are placed + // in consecutive scheduling queues, thereby improving L2 Cache hit rates. 
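  // Concretely, the grid below already encodes that (Tr, B*H) ordering:
  // blockIdx.x is the Q tile index (the fastest-varying dimension per the
  // x-dimension scheduling note above), while blockIdx.y packs batch * num_heads
  // and is decoded in-kernel as
  //   QKV_batch_id = blockIdx.y / QKV_head;
  //   QKV_head_id  = blockIdx.y % QKV_head;
  //   Q_tile_id    = blockIdx.x;
  // so consecutively scheduled blocks share one (batch, head) and stream the
  // same K/V tiles, which is presumably where the L2 reuse comes from.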
+ // Tr(=N/Br), batch_size x num_heads + dim3 grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head); + dim3 block(kNumThreads); // 4/8 warps per block + + cudaFuncSetAttribute( + flash_attn_mma_stages_split_q_shared_kv_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + >, + cudaFuncAttributeMaxDynamicSharedMemorySize, + // kMaxSramPerBlock + 98304 + ); + + flash_attn_mma_stages_split_q_shared_kv_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPad + ><<>>( + reinterpret_cast(Q.data_ptr()), + reinterpret_cast(K.data_ptr()), + reinterpret_cast(V.data_ptr()), + reinterpret_cast(O.data_ptr()), + QKV_seqlen, + QKV_head + ); +} + +void flash_attn_mma_stages_split_q_shared_kv(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages) { + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + const int d = Q.size(3); // B, H, N, d + + if (stages > 1) { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_shared_kv<32, 2>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_shared_kv<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_shared_kv<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_shared_kv<128, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } else { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_shared_kv<32, 1>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_shared_kv<64, 1>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_shared_kv<96, 1>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_shared_kv<128, 1>(Q, K, V, O); + break; + case 256: + launch_flash_attn_mma_stages_split_q_shared_kv<256, 1>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } +} diff --git a/kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu b/kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu similarity index 98% rename from kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu rename to kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu index 16d179ba..af846429 100644 --- a/kernels/flash-attn/mma/flash_attn_mma_swizzle_qkv.cu +++ b/kernels/flash-attn/mma/flash_attn_mma_share_qkv_swizzle.cu @@ -1,2 +1,2 @@ -// TODO: Manually apply SMEM swizzling instead of padding in -// Split-Q kernels to reduce bank conflicts. +// TODO: Manually apply SMEM swizzling instead of padding in +// Split-Q kernels to reduce bank conflicts. diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu new file mode 100644 index 00000000..af846429 --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_split_kv_swizzle.cu @@ -0,0 +1,2 @@ +// TODO: Manually apply SMEM swizzling instead of padding in +// Split-Q kernels to reduce bank conflicts. 
diff --git a/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu
new file mode 100644
index 00000000..af846429
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_split_q_swizzle.cu
@@ -0,0 +1,2 @@
+// TODO: Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu
new file mode 100644
index 00000000..d73fc339
--- /dev/null
+++ b/kernels/flash-attn/mma/flash_attn_mma_tiling_qk_swizzle.cu
@@ -0,0 +1,1085 @@
+#include "utils.h"
+
+// FlashAttention-2 written from scratch using Tensor Cores via MMA PTX instructions.
+// The inputs Q,K,V are 4D tensors with shape [batch_size, num_heads, seq_len, head_dim].
+// The output O is a 4D tensor with shape [batch_size, num_heads, seq_len, head_dim].
+
+// The FlashAttention-2 algorithm is described in the following paper:
+// https://arxiv.org/pdf/2307.08691
+
+// Q,K,V,O: [batch_size, num_heads, seq_len, head_dim], [B,H,N,d]
+// Each block processes one Q_tile with shape [Br,d] and the full K,V with shape [N,d].
+
+// Split Q across MMAs (warps) and keep KV shared by all MMAs (warps),
+// in order to reduce the communication between warps via smem and warp shuffle.
+
+// MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps
+// | 64x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps
+// | 128x128 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x16) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x16) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x16) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x16) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x16) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x16) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x16) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x16) |
+
+// MMA = m16n8k16, Br=16x8=128, Bc=8x8=64, layout: 8 warps
+// | 128x64 | warp_KV 0 |
+// | warp_QP 0 | MMA 0 ... MMA 0 (x8) |
+// | warp_QP 1 | MMA 1 ... MMA 1 (x8) |
+// | warp_QP 2 | MMA 2 ... MMA 2 (x8) |
+// | warp_QP 3 | MMA 3 ... MMA 3 (x8) |
+// | warp_QP 4 | MMA 4 ... MMA 4 (x8) |
+// | warp_QP 5 | MMA 5 ... MMA 5 (x8) |
+// | warp_QP 6 | MMA 6 ... MMA 6 (x8) |
+// | warp_QP 7 | MMA 7 ... MMA 7 (x8) |
+
+// Fine-grained tiling at the MMA level for Q and K gives a constant SRAM usage of
+// Br x kMmaAtomK for Q and Bc x kMmaAtomK for K (independent of d). For V, the SRAM
+// complexity is O(kMmaAtomK * d), leading to an overall SRAM complexity of
+// O(kMmaAtomK * d). Consequently, this approach allows us to extend D (head dimension)
+// up to 1024. Performance optimizations are ongoing. Stay tuned for updates ~
+
+// Manually apply SMEM swizzling instead of padding in
+// Split-Q kernels to reduce bank conflicts.
+
+// i: row index; j: col index.
+// e.g. kColStride = 64, kStep = 8 -> load 8 halves as one 128-bit memory issue.
+template +static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) { + // ------------------------------------------- + // --------------swizzle layout--------------- + // -------------col 0~64, step 8-------------- + // ------------------------------------------- + // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) | + // ------------------------------------------- + // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) | + // ------------------------------------------- + // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) | + // ------------------------------------------- + // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) | + // ------------------------------------------- + // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep; + static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4."); + static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep."); + if constexpr (kStep == 8) { + return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3; + } else { + static_assert(kStep == 4); + return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2; + } +} + +// i: row index; j: col index +// e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue. +template +static __device__ __forceinline__ int swizzle_permuted_Q_j(int i, int j) { + // ------------------- + // --swizzle layout--- + // -col 0~16, step 8-- + // ------------------- + // | row 0 | (0, 8) | + // | row 1 | (0, 8) | + // | row 2 | (0, 8) | + // | row 3 | (0, 8) | + // ------------------- + // | row 4 | (8, 0) | + // | row 5 | (8, 0) | + // | row 6 | (8, 0) | + // | row 7 | (8, 0) | + // ------------------- + // | row 8 | (0, 8) | + // | row 9 | (0, 8) | + // | row 10 | (0, 8) | + // | row 11 | (0, 8) | + // ------------------- + // | row 12 | (8, 0) | + // | row 13 | (8, 0) | + // | row 14 | (8, 0) | + // | row 15 | (8, 0) | + // ------------------- + return swizzle_permuted_j(i, j); +} + +// i: row index; j: col index +// e.g kColStride = kMmaAtomK = 16, kStep = 8 -> load 8 half as 128 bits memory issue. +template +static __device__ __forceinline__ int swizzle_permuted_K_j(int i, int j) { + // ------------------- + // --swizzle layout--- + // -col 0~16, step 8-- + // ------------------- + // | row 0 | (0, 8) | + // | row 1 | (0, 8) | + // | row 2 | (0, 8) | + // | row 3 | (0, 8) | + // ------------------- + // | row 4 | (8, 0) | + // | row 5 | (8, 0) | + // | row 6 | (8, 0) | + // | row 7 | (8, 0) | + // ------------------- + // | row 8 | (0, 8) | + // | row 9 | (0, 8) | + // | row 10 | (0, 8) | + // | row 11 | (0, 8) | + // ------------------- + // | row 12 | (8, 0) | + // | row 13 | (8, 0) | + // | row 14 | (8, 0) | + // | row 15 | (8, 0) | + // ------------------- + return swizzle_permuted_j(i, j); +} + +// i: row index; j: col index. +// e.g kColStride = kHeadDim = 64, kStep = 8 -> load 8 half as 128 bits memory issue. 
+template +static __device__ __forceinline__ int swizzle_permuted_V_j(int i, int j) { + // ------------------------------------------- + // --------------swizzle layout--------------- + // -------------col 0~64, step 8-------------- + // ------------------------------------------- + // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) | + // ------------------------------------------- + // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) | + // ------------------------------------------- + // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) | + // ------------------------------------------- + // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) | + // ------------------------------------------- + return swizzle_permuted_j(i, j); +} + +template< + const int kHeadDim, // Headdim, 32,64,128 + const int kMmaAtomM, // MMA Atom M, 16 + const int kMmaAtomN, // MMA Atom N, 8 + const int kMmaAtomK, // MMA Atom K, 16 + const int kMmaTileSeqLenQ, // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenK, // 1, more MMA(warp), N=8*1 =8, Q@K^T=[Br(M), d(K)]@[d(K), Bc(N)] + const int kMmaTileSeqLenP, // 4, more MMA(warp), M=16*4=64, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kMmaTileHeadDimV, // 1, more MMA(warp), N=8*1 =8, P@V =[Br(M),Bc(K)]@[Bc(K), d(N) ] + const int kWarpTileSeqLenQ, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileSeqLenK, // 8, more values, N, Bc=8*8 =64, matmul N + const int kWarpTileSeqLenP, // 1, more values, M, Br=64*1=64, matmul M + const int kWarpTileHeadDimV, // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|... + const int kStage, + const int kPadQ, + const int kPadK, + const int kPadV + > +__global__ void __launch_bounds__( + WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK) +flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel(half* Q, + half* K, + half* V, + half* O, + int QKV_seqlen, + int QKV_head) { + // Matmul Layout: Q[Br,d]@K^T[d,Bc] NT, P[Br,Bc]@V[Bc,d] NN. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. + static_assert(kMmaAtomM == 16 && kMmaAtomN == 8 && kMmaAtomK == 16); // m16n8k16 + static_assert(kMmaTileSeqLenQ <= 8 && kMmaTileSeqLenK == 1); // Q@K^T + static_assert(kMmaTileSeqLenP <= 8 && kMmaTileHeadDimV == 1); // P@V + static_assert(kWarpTileSeqLenQ == 1 && kWarpTileSeqLenK <= 16); // Q@K^T + // kWarpTileHeadDimV: d=8*(1|2|3|4|...) = 8|...|32|64|96|128|..., etc. + // e.g, kWarpTileHeadDimV = 8 -> d = 8*8 = 64; 16 -> d = 8*16 = 128. 
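+  // [Editor's note] Worked example of the tile arithmetic above, following the
+  // launcher defaults at the bottom of this file (illustrative only, no new behavior):
+  //   kHeadDim=64  -> kMmaTileSeqLenQ=4, kWarpTileSeqLenK=8:
+  //                   Br=16*4*1=64,  Bc=8*1*8=64,  kWarpTileHeadDimV=64/8=8,
+  //                   kNumThreads=32*4*1=128 (4 warps).
+  //   kHeadDim=128 -> kMmaTileSeqLenQ=8, kWarpTileSeqLenK=16:
+  //                   Br=16*8*1=128, Bc=8*1*16=128, kWarpTileHeadDimV=128/8=16,
+  //                   kNumThreads=32*8*1=256 (8 warps).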
+ static_assert(kWarpTileSeqLenP == 1 && kWarpTileHeadDimV == ( + kHeadDim / (kMmaAtomN * kMmaTileHeadDimV))); // P@V + static_assert(kStage < 3 && kStage > 0); + static_assert(kPadQ >= 0 && kPadQ % 8 == 0); // 0,8,16 + static_assert(kPadK >= 0 && kPadK % 8 == 0); // 0,8,16 + static_assert(kPadV >= 0 && kPadV % 8 == 0); // 0,8,16 + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + static_assert(Br >= Bc); // for shared memory reuse. + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + // Now, N must be mutliples of Bc(32/64) for KV tiling across seqlen. + const int Tc = div_ceil(QKV_seqlen, Bc); // Tc K_tile[Bc,d] + const float scale = 1.0f / sqrt((float) kHeadDim); + + // grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head), (x,y,z) + const int QKV_batch_id = blockIdx.y / QKV_head; // Batch size + const int QKV_head_id = blockIdx.y % QKV_head; // Head num + const int Q_tile_id = blockIdx.x; // Q tile_id, range [0, Tr] + const int O_tile_id = Q_tile_id; // O tile_id, same as Q. + const int tid = threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_QP = warp_id; // 0,1,2,3 or 0~7 + const int warp_KV = 0; // 0 + // MMA Layout [Br,Bc]=[64,64], MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + // MMA Layout [Br,Bc]=[128,128], MMA = m16n8k16, Br=16x8=128, Bc=8x16=128, layout: 8 warps + // | 128x128 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x16) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x16) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x16) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x16) | + // | warp_QP 4 | MMA 4 ... MMA 4 (x16) | + // | warp_QP 5 | MMA 5 ... MMA 5 (x16) | + // | warp_QP 6 | MMA 6 ... MMA 6 (x16) | + // | warp_QP 7 | MMA 7 ... MMA 7 (x16) | + const int Q_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // Q [seqlen,d] + const int K_gmem_offset = ((QKV_batch_id * QKV_head * QKV_seqlen * kHeadDim) + + (QKV_head_id * QKV_seqlen * kHeadDim)); // K [seqlen,d] + const int V_gmem_offset = Q_gmem_offset; // V [seqlen,d] + const int O_gmem_offset = Q_gmem_offset; // O [seqlen,d] + + // Mapping Q gmem -> tid -> smem, Q[Br,kMmaAtomK]=[64/128,16], 128/256 threads. + int load_smem_Q_Br = (tid / (kNumThreads / Br)); // Br 64, tid / 2, row 0~64 + int load_smem_Q_d = (tid % (kNumThreads / Br)) * (kMmaAtomK / (kNumThreads / Br)); // (tid % 2) * 8, 0,8,... + // Mapping K gmem -> tid -> smem, K[Bc,kMmaAtomK]=[64/128,16], 128 threads. + int load_smem_K_Bc = (tid / (kNumThreads / Bc)); // Bc 64, tid / 2, row 0~64 + int load_smem_K_d = (tid % (kNumThreads / Bc)) * (kMmaAtomK / (kNumThreads / Bc)); // (tid % 2) * 8, 0,8,... + // TODO: Mapping V gmem -> tid -> smem, V[kMmaAtomK,kMmaAtomN]=[16,64/128], 128 threads. + // Mapping V gmem -> tid -> smem, V[kMmaAtomK,d]=[16,64/128], 128 threads. + int load_smem_V_Bc = (tid / (kNumThreads / kMmaAtomK)); // kMmaAtomK 16, tid / 8, row 0~15 + int load_smem_V_d = (tid % (kNumThreads / kMmaAtomK)) * (kHeadDim / (kNumThreads / kMmaAtomK)); // (tid % 8) * 8, 0,8,56... + // global Q row of current head for tile [Br,d] per block. 
+ int load_gmem_Q_Br = Q_tile_id * Br + load_smem_Q_Br; + if (load_gmem_Q_Br >= QKV_seqlen) return; + constexpr bool kIsVCanLoadIn128b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 8 == 0; + constexpr bool kIsVCanLoadIn64b = (kHeadDim / (kNumThreads / kMmaAtomK)) % 4 == 0; + static_assert(kIsVCanLoadIn128b || kIsVCanLoadIn64b, "V can't load in 128b or 64b."); // 32,64,128,192,256,... + + // Shared memory for Q,K,V, we don not need additional smem for O + // collective store which perform via registers reuse and warp shuffle. + extern __shared__ half smem[]; + // Split Q + Shared KV SMEM + Fine grain tiling, only need O(1) SRAM complexity. + constexpr int Q_tile_size = Br * (kMmaAtomK + kPadQ); // Q[Br,16], 64*16*2=2048 bytes, 2M + constexpr int K_tile_size = Bc * (kMmaAtomK + kPadK); // K[Bc,16], 2M + constexpr int V_tile_size = kMmaAtomK * (kHeadDim + kPadV); // V[16,d], 2M + // TODO: optimize QKV kStage smem store layout as in HGEMM. + half* Q_tile_smem = smem; // 8M/16M + half* K_tile_smem = Q_tile_smem + kStage * Q_tile_size; // 8M/16M + half* V_tile_smem = Q_tile_smem; // V may reuse all Q+K smem after Q@K^T. + // stage 1, Q/K smem = 64*16*2/1024=2M, V smem =16*d(64|128|...)*2/1024=2M/4M/.. + // stage 1, total smem = max(QK_smem, V_smem) = 4M if d <= 64 else V_smem. + // stage 1, V shared QK smem, Br=Bc=64, d=64: 2M+(2M) =4M, +Pad(2M) = 6M + // stage 1, V shared QK smem, Br=Bc=128, d=64: 4M+4M =8M, +Pad(2M) = 10M + // stage 2, V shared QK smem, Br=Bc=64, d=64: 4M+(4M) =8M, +Pad(2M) = 10M + // stage 2, V shared QK smem, Br=Bc=128, d=64: 8M+8M =16M, +Pad(2M) = 18M + uint32_t smem_Q_base_ptr = __cvta_generic_to_shared(Q_tile_smem); + uint32_t smem_K_base_ptr = __cvta_generic_to_shared(K_tile_smem); + uint32_t smem_V_base_ptr = __cvta_generic_to_shared(V_tile_smem); + + // --------------------- Registers/SMEM for thread block ------------------------- + // block m_old, l_old, store in lane, use float to keep precision. + float lane_block_row_max_old[kWarpTileSeqLenQ][2]; // [1][2] + float lane_block_row_sum_old[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_block_row_max_old, -INFINITY); + fill_2D_regs(lane_block_row_sum_old, 0.0f); + + // ---------------------- Registers for S=Q@K^T/O=P@V ---------------------------- + // registers for QKV, S=Q[Br,d]@K[Bc,d]=[Br,Bc] and O=P[Br,Bc]@V[Bc,d]=[Br,d]. + uint32_t R_Q[kWarpTileSeqLenQ][ 4]; // [1][4] + uint32_t R_K[kWarpTileSeqLenK][ 2]; // [8][2] + uint32_t R_V[kWarpTileHeadDimV][2]; // [8][2] + // NOTE: For R_V[kWarpTileHeadDimV][2], kWarpTileHeadDimV will increase with d. + // so, for large d, R_V will need more registers and cause performance down. + // We have to find a way to apply MMA level tiling for V(R_V) for large d. + // registers for current tile_K_seqlen within, [64,64] = S_tile[Br,Bc] + // = Q_tile[Br,d] * K[Bc,d], each thread hold 2x32 bits regs. + uint32_t R_S[kWarpTileSeqLenQ][kWarpTileSeqLenK][ 2]; // [1][8][2] + // registers for tile_K_seqlen O=PV[Br,d]=P@V, [2][2/4][2], 8 or 16 regs. + uint32_t R_O[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + // registers final Output [D]=final rescale(R_O), [2][2/4][2], 8 or 16 regs. + uint32_t R_D[kWarpTileSeqLenP][kWarpTileHeadDimV][2]; // [1][8][2] + fill_3D_regs(R_S, 0); + fill_3D_regs(R_D, 0); + fill_3D_regs(R_O, 0); + + // : for K^T[d,seqlen] with K^T_tile[d,Bc] + // tile_K_seqlen: compute S_tile[Br,Bc] = Q@K^T = Q_tile[Br,d] * K^T[d,Bc] + #pragma unroll 1 + for (int tile_K_seqlen = 0; tile_K_seqlen < Tc; ++tile_K_seqlen) { + // TODO: process last tile_K_seqlen ? 
pad to multiple of 8. + + // Q/K g2s + if constexpr (kStage > 1) { + #pragma unroll + for (int stage = 0; stage < (kStage - 1); ++stage) { + // Q g2s + int load_gmem_Q_d = (stage * kMmaAtomK) + load_smem_Q_d; // 0,8 + int load_gmem_Q_addr = ( + Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = ( + smem_Q_base_ptr + (stage * Q_tile_size + + load_smem_Q_Br * (kMmaAtomK + kPadQ) + + swizzle_permuted_Q_j( + load_smem_Q_Br, load_smem_Q_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + + // K g2s + int load_gmem_K_Bc = (tile_K_seqlen * Bc) + load_smem_K_Bc; // < seqlen + int load_gmem_K_d = (stage * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (stage * K_tile_size + + load_smem_K_Bc * (kMmaAtomK + kPadK) + + swizzle_permuted_K_j( + load_smem_K_Bc, load_smem_K_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } // end for stage + + CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2 + __syncthreads(); + } // end if kStage > 1 + + // : tile_K_d, kMmaAtomK = 16, K_tile_d[kMmaAtomK,Bc] + // Matmul with NT layout, Q row major, K^T col major. + // NOTE: K[Bc,d] with row major means K^T[d,Bc] in col major. + // S_tile[Br,Bc]=Q_tile[Br,d]@K[Bc,d] + // + fill_3D_regs(R_S, 0); + #pragma unroll + for (int tile_K_d = 0; tile_K_d < (kHeadDim / kMmaAtomK); ++tile_K_d) { + // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + int smem_sel = (tile_K_d) % kStage; + // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + int smem_sel_next = (tile_K_d + (kStage - 1)) % kStage; + + // stages for Q, K + if constexpr (kStage > 1) { + if ((tile_K_d + 1) < (kHeadDim / kMmaAtomK)) { + // next Q tile g2s + int load_gmem_Q_d = ((tile_K_d + 1) * kMmaAtomK) + load_smem_Q_d; + int load_gmem_Q_addr = ( + Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + uint32_t load_smem_Q_ptr = ( + smem_Q_base_ptr + (smem_sel_next * Q_tile_size + + load_smem_Q_Br * (kMmaAtomK + kPadQ) + + swizzle_permuted_Q_j( + load_smem_Q_Br, load_smem_Q_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + + // next K tile g2s + int load_gmem_K_Bc = tile_K_seqlen * Bc + load_smem_K_Bc; // < seqlen + int load_gmem_K_d = ((tile_K_d + 1) * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel_next * K_tile_size + + load_smem_K_Bc * (kMmaAtomK + kPadK) + + swizzle_permuted_K_j( + load_smem_K_Bc, load_smem_K_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + } + } else { + // sync load curr Q, K g2s + // curr Q tile g2s + int load_gmem_Q_d = (tile_K_d * kMmaAtomK) + load_smem_Q_d; + int load_gmem_Q_addr = ( + Q_gmem_offset + load_gmem_Q_Br * kHeadDim + load_gmem_Q_d); + 
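+      // [Editor's note] Illustrative arithmetic for the swizzled column used just
+      // below (no behavioral change). With the launcher defaults, kNumThreads/Br == 2,
+      // so e.g. tid=10 gets load_smem_Q_Br = 10/2 = 5 and load_smem_Q_d = (10%2)*8 = 0;
+      // swizzle_permuted_Q_j(5, 0) with kColStride=16, kStep=8 gives
+      // (((0>>3) ^ (5>>2)) % 2) << 3 = 8, i.e. rows 4~7 park their first 8 halves at
+      // column 8 instead of 0, which is what reduces ldmatrix bank conflicts
+      // without padding.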
uint32_t load_smem_Q_ptr = ( + smem_Q_base_ptr + (smem_sel * Q_tile_size + + load_smem_Q_Br * (kMmaAtomK + kPadQ) + + swizzle_permuted_Q_j( + load_smem_Q_Br, load_smem_Q_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Br)); i += 8) { + CP_ASYNC_CG(load_smem_Q_ptr + i * 2, &Q[load_gmem_Q_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + + // curr K tile g2s + int load_gmem_K_Bc = (tile_K_seqlen * Bc) + load_smem_K_Bc; // < seqlen + int load_gmem_K_d = (tile_K_d * kMmaAtomK) + load_smem_K_d; // K [Bc,16] from [seqlen,d] + int load_gmem_K_addr = ( + K_gmem_offset + load_gmem_K_Bc * kHeadDim + load_gmem_K_d); + uint32_t load_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + load_smem_K_Bc * (kMmaAtomK + kPadK) + + swizzle_permuted_K_j( + load_smem_K_Bc, load_smem_K_d)) * sizeof(half) + ); + #pragma unroll + for (int i = 0; i < (kMmaAtomK / (kNumThreads / Bc)); i += 8) { + CP_ASYNC_CG(load_smem_K_ptr + i * 2, &K[load_gmem_K_addr + i], 16); + } + CP_ASYNC_COMMIT_GROUP(); + // Wait curr Q, K tile ready. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } // end if kStage > 1 + + // Q s2r + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { // Q[Br,d]=[M,K] + int warp_smem_Q_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenQ) + i * kMmaAtomM; + int lane_smem_Q_Br = warp_smem_Q_Br + lane_id % 16; // 0~15 + int lane_smem_Q_d = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_Q_ptr = ( + smem_Q_base_ptr + (smem_sel * Q_tile_size + + lane_smem_Q_Br * (kMmaAtomK + kPadQ) + + swizzle_permuted_Q_j( + lane_smem_Q_Br, lane_smem_Q_d)) * sizeof(half) + ); + LDMATRIX_X4(R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + lane_smem_Q_ptr); // now, R_Q[1][4] + } + + // smem -> reg, load k16n8 from smem K, offset d according tile_K_d. + // ldmatrix.x2 for K_tile_smem, [Bc,kMmaAtomK] from [Bc,d]=[K,N] + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // load k16n8 via ldmatrix.x2 from K_tile_smem[Bc,d]. + // K[Bc,d] with row major means K^T[d,Bc] in col major. + int warp_smem_K_Bc = warp_KV * (kMmaAtomN * kWarpTileSeqLenK) + j * kMmaAtomN; + int lane_smem_K_Bc = warp_smem_K_Bc + lane_id % 8; // 0~7 + int lane_smem_K_d = ((lane_id / 8) % 2) * 8; // 0,8 + uint32_t lane_smem_K_ptr = ( + smem_K_base_ptr + (smem_sel * K_tile_size + + lane_smem_K_Bc * (kMmaAtomK + kPadK) + + swizzle_permuted_K_j( + lane_smem_K_Bc, lane_smem_K_d)) * sizeof(half) + ); + LDMATRIX_X2(R_K[j][0], R_K[j][1], lane_smem_K_ptr); // R_K + } // end for kWarpTileSeqLenK + if constexpr (kStage < 2) { + // Wait Q, K s2r ready if kStage < 2 in order to avoid + // the next Q, K tile g2s overwrite. + __syncthreads(); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + HMMA16816(R_S[i][j][0], R_S[i][j][1], + R_Q[i][0], R_Q[i][1], R_Q[i][2], R_Q[i][3], + R_K[j][0], R_K[j][1], + R_S[i][j][0], R_S[i][j][1]); + } + } + + if constexpr (kStage > 1) { + // Wait next Q, K tile g2s ready. + CP_ASYNC_WAIT_GROUP(kStage - 2); + __syncthreads(); + } + + } // end loop over d, S=Q@K^T + __syncthreads(); + + // V g2s stages. 
(reuse Q+K smem) load [16,d] from [Bc,d] + if constexpr (kStage > 1) { + #pragma unroll + for (int stage = 0; stage < (kStage - 1); ++stage) { + // V g2s + int load_gmem_V_Bc = ( + (tile_K_seqlen * Bc) + (stage * kMmaAtomK) + load_smem_V_Bc); // 0~15 + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (stage * V_tile_size + + load_smem_V_Bc * (kHeadDim + kPadV) + + load_smem_V_d) * sizeof(half) + ); + // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld. + if constexpr (kIsVCanLoadIn128b) { + // 64,128,192,256,... + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + } else { + // 32,96,160,224 + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) { + CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8); + } + } + CP_ASYNC_COMMIT_GROUP(); + } // end for stage + } + + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... MMA 3 (x8) | + + // Online safe softmax, warp/block reduce max/sum, row wise + float lane_row_max_new[kWarpTileSeqLenQ][2]; // [1][2] + float lane_row_sum_new[kWarpTileSeqLenQ][2]; // [1][2] + fill_2D_regs(lane_row_max_new, -INFINITY); + fill_2D_regs(lane_row_sum_new, 0.0f); + + // Row max for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Thread level reduce max across kWarpTileSeqLenK dim, namely Bc. + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for C. (m16n8k16) + // Row\Col 0 1 2 3 4 5 6 7 + // 0 T0: {c0, c1} T1: {c0, c1} T2: {c0, c1} T3: {c0, c1} + // 1 T4: {c0, c1} T5: {c0, c1} T6: {c0, c1} T7: {c0, c1} + // 2 ... + // ... + // 7 T28: {c0, c1} T29: {c0, c1} T30: {c0, c1} T31: {c0, c1} + // 8 T0: {c2, c3} T1: {c2, c3} T2: {c2, c3} T3: {c2, c3} + // 9 T4: {c2, c3} T5: {c2, c3} T6: {c2, c3} T7: {c2, c3} + // 10 ... + // ... + // 15 T28: {c2, c3} T29: {c2, c3} T30: {c2, c3} T31: {c2, c3} + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // This should be the row max after S = (Q @ K^T) / sqrt(d) + float tmp_max_0 = max(t_reg_S_0.x, t_reg_S_0.y) * scale; + float tmp_max_1 = max(t_reg_S_1.x, t_reg_S_1.y) * scale; + lane_row_max_new[i][0] = max(lane_row_max_new[i][0], tmp_max_0); + lane_row_max_new[i][1] = max(lane_row_max_new[i][1], tmp_max_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce max, warp_size = 4 + // Each thread contains the maximum of 2 rows of Br, + // and only the values of T0, T4, ..., T28 are used. + lane_row_max_new[i][0] = warp_reduce_max(lane_row_max_new[i][0]); + lane_row_max_new[i][1] = warp_reduce_max(lane_row_max_new[i][1]); + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // Exp sum and mul scale_factor for [Br,Bc] tile, Thread -> Warp -> Block. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenQ; ++i) { + // Use latest global row max without update. 
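+    // [Editor's note] A small numeric walk-through of the online-softmax update
+    // below, with made-up values (illustrative only). Suppose one row has scaled
+    // scores S*scale = [2, 5] in this tile, with running m_old = 4, l_old = 1.5:
+    //   m_new = max(m_old, 5)      = 5
+    //   P     = exp([2,5] - m_new) = [0.0498, 1.0]
+    //   l     = exp(m_old - m_new) * l_old + sum(P) = 0.3679*1.5 + 1.0498 = 1.60 (approx.)
+    // The same exp(m_old - m_new) factor later rescales the accumulated O.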
+ // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; + float block_row_max_new_0 = lane_row_max_new[i][0]; + // Br 1, row_id, 8~15, 24~31, 40~47, 56~63; + float block_row_max_new_1 = lane_row_max_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // Apply m_new = max(m_old, m_new) here. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + + #pragma unroll + for (int j = 0; j < kWarpTileSeqLenK; ++j) { + float2 t_reg_S_0 = __half22float2(HALF2(R_S[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_S_1 = __half22float2(HALF2(R_S[i][j][1])); // 8~15 {c2, c3} + // P = Exp(S - m_new), fmaf(x, y, z) = x * y + z; + t_reg_S_0.x = __expf(__fmaf_rn(t_reg_S_0.x, scale, - block_row_max_new_0)); + t_reg_S_0.y = __expf(__fmaf_rn(t_reg_S_0.y, scale, - block_row_max_new_0)); + t_reg_S_1.x = __expf(__fmaf_rn(t_reg_S_1.x, scale, - block_row_max_new_1)); + t_reg_S_1.y = __expf(__fmaf_rn(t_reg_S_1.y, scale, - block_row_max_new_1)); + lane_row_sum_new[i][0] += (t_reg_S_0.x + t_reg_S_0.y); + lane_row_sum_new[i][1] += (t_reg_S_1.x + t_reg_S_1.y); + // Update R_S for P[Br,Bc] = Exp(S-m), point wise. + HALF2(R_S[i][j][0]) = __float22half2_rn(t_reg_S_0); + HALF2(R_S[i][j][1]) = __float22half2_rn(t_reg_S_1); + } // end for kWarpTileSeqLenK + + // Warp level reduce sum, warp_size = 4 + lane_row_sum_new[i][0] = warp_reduce_sum(lane_row_sum_new[i][0]); + lane_row_sum_new[i][1] = warp_reduce_sum(lane_row_sum_new[i][1]); + } // end for kWarpTileSeqLenQ + __syncthreads(); + + // : P[Br,Bc]@V[Bc,d]=[Br,d]=[64,64/128], partion Attention. + // Matmul with NN layout: P[Br,Bc] row major, V[Bc,d] row major. + // Make sure to clear the states in R_O before MMA for P@V for each step. + + // NOTE: Values for P[Br,Bc] already in R_S registers, can we use these + // registers for P(A) matrix directly ? How to do that ? + // according to the A matrix layout for MMA m16n8k16 instruction. + // reference: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // The layout of the fragments held by different threads for A matrix with .f16. + // R\C 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 + // 0 T0: {a0, a1} T1: {a0, a1} T2: {a0, a1} T3: {a0, a1} T0: {a4, a5} T1: {a4, a5} T2: {a4, a5} T3: {a4, a5} + // 1 T4: {a0, a1} T5: {a0, a1} T6: {a0, a1} T7: {a0, a1} T4: {a4, a5} T5: {a4, a5} T6: {a4, a5} T7: {a4, a5} + // 2 (dashed arrow pointing right) + // ... + // 7 T28: {a0, a1} T29: {a0, a1} T30: {a0, a1} T31: {a0, a1} T28: {a4, a5} T29: {a4, a5} T30: {a4, a5} T31: {a4, a5} + // 8 T0: {a2, a3} T1: {a2, a3} T2: {a2, a3} T3: {a2, a3} T0: {a6, a7} T1: {a6, a7} T2: {a6, a7} T3: {a6, a7} + // 9 T4: {a2, a3} T5: {a2, a3} T6: {a2, a3} T7: {a2, a3} T4: {a6, a7} T5: {a6, a7} T6: {a6, a7} T7: {a6, a7} + // 10 (dashed arrow pointing right) + // ... + // 15 T28: {a2, a3} T29: {a2, a3} T30: {a2, a3} T31: {a2, a3} T28: {a6, a7} T29: {a6, a7} T30: {a6, a7} T31: {a6, a7} + + // Wait V g2s stages ready. 
+ if constexpr (kStage > 1) { + CP_ASYNC_WAIT_GROUP(kStage - 2); // s2->0, s3->1, s4->2 + __syncthreads(); + } + + // P@V=[Br,Bc]@[Bc,d] + fill_3D_regs(R_O, 0); + #pragma unroll + for (int tile_V_Bc = 0; tile_V_Bc < (Bc / kMmaAtomK); ++tile_V_Bc) { + // s2 tn 0->0, 1->1, 2->0; s3 tn 0->0, 1->1, 2->2, 3->0; + int smem_sel = (tile_V_Bc) % kStage; + // s2 tn 0->1, 1->0, 2->1; s3 tn 0->2, 1->0, 2->1, 3->2; + int smem_sel_next = (tile_V_Bc + (kStage - 1)) % kStage; + + // stages for V + if constexpr (kStage > 1) { + if ((tile_V_Bc + 1) < (Bc / kMmaAtomK)) { + // next V tile g2s + int load_gmem_V_Bc = ( + (tile_K_seqlen * Bc) + (tile_V_Bc + 1) * kMmaAtomK + load_smem_V_Bc); // 0~15 + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (smem_sel_next * V_tile_size + + load_smem_V_Bc * (kHeadDim + kPadV) + + load_smem_V_d) * sizeof(half) + ); + // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld. + if constexpr (kIsVCanLoadIn128b) { + // 64,128,192,256,... + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + } else { + // 32,96,160,224 + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) { + CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8); + } + } + CP_ASYNC_COMMIT_GROUP(); + } + } else { + // sync load curr V g2s + int load_gmem_V_Bc = ( + (tile_K_seqlen * Bc) + (tile_V_Bc * kMmaAtomK) + load_smem_V_Bc); // 0~15 + int load_gmem_V_d = load_smem_V_d; + int load_gmem_V_addr = ( + V_gmem_offset + load_gmem_V_Bc * kHeadDim + load_gmem_V_d); + uint32_t load_smem_V_ptr = ( + smem_V_base_ptr + (smem_sel * V_tile_size + + load_smem_V_Bc * (kHeadDim + kPadV) + + load_smem_V_d) * sizeof(half) + ); + // headdim must be multiple of 32, (kHeadDim/8)%8==0 for 128 bits ld. + if constexpr (kIsVCanLoadIn128b) { + // 64,128,192,256,... + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 8) { + CP_ASYNC_CG(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 16); + } + } else { + // 32,96,160,224 + #pragma unroll + for (int i = 0; i < (kHeadDim / (kNumThreads / kMmaAtomK)); i += 4) { + CP_ASYNC_CA(load_smem_V_ptr + i * 2, &V[load_gmem_V_addr + i], 8); + } + } + CP_ASYNC_COMMIT_GROUP(); + // Wait curr V tile ready. + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // Load k16n8 V from smem -> regs, R_KV, ldmatrix.x2.trans. + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { + int warp_smem_V_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; // d, matmaul N + int lane_smem_V_Bc = lane_id % 16; // 0~15; Bc, matmul K + int lane_smem_V_d = warp_smem_V_d; // 0 + uint32_t lane_smem_V_ptr = ( + smem_V_base_ptr + (smem_sel * V_tile_size + + lane_smem_V_Bc * (kHeadDim + kPadV) + + lane_smem_V_d) * sizeof(half) + ); + LDMATRIX_X2_T(R_V[j][0], R_V[j][1], lane_smem_V_ptr); // R_V + } + if constexpr (kStage < 2) { + // Wait V s2r ready if kStage < 2 in order to avoid + // the next V tile g2s overwrite. + __syncthreads(); + } + + // For R_S[1][8][2], mapping the layout below of P matrix. + // MMA = m16n8k16, Br=16x4=64, Bc=8x8=64, layout: 4 warps + // | 64x64 | warp_KV 0 | + // | warp_QP 0 | MMA 0 ... MMA 0 (x8) | + // | warp_QP 1 | MMA 1 ... MMA 1 (x8) | + // | warp_QP 2 | MMA 2 ... MMA 2 (x8) | + // | warp_QP 3 | MMA 3 ... 
MMA 3 (x8) | + // tile_V_Bc = 0, all curr MMAs(0~4) need slice P[:, 0:16], 0, 1; stored in all MMAs. + // tile_V_Bc = 1, all curr MMAs(0~4) need slice P[:, 16:32], 2, 3; stored in all MMAs. + // tile_V_Bc = 2, all curr MMAs(0~4) need slice P[:, 32:48], 4, 5; stored in all MMAs. + // tile_V_Bc = 3, all curr MMAs(0~4) need slice P[:, 48:64], 6, 7; stored in all MMAs. + int w = tile_V_Bc * 2; // MMA(Warp) selected, 0, 2, 4, 6 + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + HMMA16816(R_O[i][j][0], R_O[i][j][1], + R_S[i][w][0], R_S[i][w][1], R_S[i][w + 1][0], R_S[i][w + 1][1], + R_V[j][0], R_V[j][1], + R_O[i][j][0], R_O[i][j][1]); + } + } + + if constexpr (kStage > 1) { + // Wait next V tile g2s ready. + CP_ASYNC_WAIT_GROUP(kStage - 2); + __syncthreads(); + } + + } // end for V Bc. + __syncthreads(); + + // Rescale O -> Update row sum Exp -> then, Update row max. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // kWarpTileSeqLenQ=kWarpTileSeqLenP=1 + // m = max(m_old, m_new), l = exp(m_old - m) * l_old + l_new (FA2 paper) + // Br 0, row_id, 0~7, 16~23, 32~39, 48~55; Br 1, row_id, 8~15, 24~31, 40~47, 56~63 + float block_row_max_new_0 = lane_row_max_new[i][0]; + float block_row_max_new_1 = lane_row_max_new[i][1]; + float block_row_sum_new_0 = lane_row_sum_new[i][0]; + float block_row_sum_new_1 = lane_row_sum_new[i][1]; + + float block_row_max_old_0 = lane_block_row_max_old[i][0]; + float block_row_max_old_1 = lane_block_row_max_old[i][1]; + // NOTE: max(-inf, val) = val. + block_row_max_new_0 = max(block_row_max_old_0, block_row_max_new_0); + block_row_max_new_1 = max(block_row_max_old_1, block_row_max_new_1); + // Avoid inf value while using m_old for rescaling O. + block_row_max_old_0 = (tile_K_seqlen > 0 ? block_row_max_old_0 : + block_row_max_new_0); + block_row_max_old_1 = (tile_K_seqlen > 0 ? block_row_max_old_1 : + block_row_max_new_1); + + // rescale factor for O and l, exp(m_old - m) + float rescale_o_factor_0 = __expf(block_row_max_old_0 - block_row_max_new_0); + float rescale_o_factor_1 = __expf(block_row_max_old_1 - block_row_max_new_1); + // 0. Rescale O: Online rescaling O each tile_K_seqlen step, need m_new, m_old. + // m = max(m_old, m_new), O_new[Br,d] = exp(m_old - m) * O_old + P@V + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_O_0 = __half22float2(HALF2(R_O[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_O_1 = __half22float2(HALF2(R_O[i][j][1])); // 8~15 {c2, c3} + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + // Note that the formula in the FA2 paper is incorrect; here, + // the inverse of the exp function should not be taken, as it + // would result in an error during rescaling, namely, you have + // use exp(m_old - m_new), not 1/(m_old - m_new). + // O_new[Br,d] = exp(m_old - m_new) * O_old + P@V + t_reg_D_0.x = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.x, t_reg_O_0.x); + t_reg_D_0.y = __fmaf_rn(rescale_o_factor_0, t_reg_D_0.y, t_reg_O_0.y); + t_reg_D_1.x = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.x, t_reg_O_1.x); + t_reg_D_1.y = __fmaf_rn(rescale_o_factor_1, t_reg_D_1.y, t_reg_O_1.y); + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } // end for kWarpTileHeadDimV. + + // Now, we can update m, l after O has been scaled. + // 1. 
First, update block row sum Exp for each lane which + // need both m_new and m_old. + float block_row_sum_old_0 = lane_block_row_sum_old[i][0]; + float block_row_sum_old_1 = lane_block_row_sum_old[i][1]; + // Update l = exp(m_old - m_new) * l_old + row_sum(P). + lane_block_row_sum_old[i][0] = (__fmaf_rn( + rescale_o_factor_0, block_row_sum_old_0, block_row_sum_new_0)); + lane_block_row_sum_old[i][1] = (__fmaf_rn( + rescale_o_factor_1, block_row_sum_old_1, block_row_sum_new_1)); + // 2. Then, update block row max for each lane. + lane_block_row_max_old[i][0] = block_row_max_new_0; + lane_block_row_max_old[i][1] = block_row_max_new_1; + } + } // end loop over N + __syncthreads(); + + // Finaly, we still have to rescale O once more. + // O_output(D) = ( 1/l_final ) * O_final (FA2 paper) + // NOTE: Here, we choose to reuse R_O as final output + // in order to reduce regs usage. + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + float rescale_factor_0 = __frcp_rn(lane_block_row_sum_old[i][0]); + float rescale_factor_1 = __frcp_rn(lane_block_row_sum_old[i][1]); + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8, 16, 32, ... + float2 t_reg_D_0 = __half22float2(HALF2(R_D[i][j][0])); // 0~7 {c0, c1} + float2 t_reg_D_1 = __half22float2(HALF2(R_D[i][j][1])); // 8~15 {c2, c3} + t_reg_D_0.x = rescale_factor_0 * t_reg_D_0.x; + t_reg_D_0.y = rescale_factor_0 * t_reg_D_0.y; + t_reg_D_1.x = rescale_factor_1 * t_reg_D_1.x; + t_reg_D_1.y = rescale_factor_1 * t_reg_D_1.y; + HALF2(R_D[i][j][0]) = __float22half2_rn(t_reg_D_0); + HALF2(R_D[i][j][1]) = __float22half2_rn(t_reg_D_1); + } + } + + // Store O(D): Write O[Br,d] from regs -> gmem, collective store + // with reg reuse & warp shuffle. need R_Z[2][4]. + // TODO: reuse Q smem for collective store: regs -> smem -> gmem + #pragma unroll + for (int i = 0; i < kWarpTileSeqLenP; ++i) { // 1 + #pragma unroll + for (int j = 0; j < kWarpTileHeadDimV; ++j) { // 8 + // we have to use new R_Z regs for collective store. + uint32_t R_Z[2][4]; + R_Z[0][0] = R_D[i][j][0]; R_Z[1][0] = R_D[i][j][1]; // warp_size 4 + R_Z[0][1] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 1, 4); + R_Z[0][2] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 2, 4); + R_Z[0][3] = __shfl_sync((0xffffffff), R_D[i][j][0], lane_id + 3, 4); + R_Z[1][1] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 1, 4); + R_Z[1][2] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 2, 4); + R_Z[1][3] = __shfl_sync((0xffffffff), R_D[i][j][1], lane_id + 3, 4); + // st.global.v4 128 bits. 
[Br,d] + if (lane_id % 4 == 0) { + // (0/1)*32 + (0/1)*16=(0,16,32,48), + 0~7 -> 0~56 + int store_warp_regs_O_Br = warp_QP * (kMmaAtomM * kWarpTileSeqLenP ) + i * kMmaAtomM; + int store_lane_gmem_O_Br = O_tile_id * Br + store_warp_regs_O_Br + lane_id / 4; // 0~7 + // (0~3)*16 + (0/1)*8=(0,8,16,24,...,48,56) + int store_warp_regs_O_d = warp_KV * (kMmaAtomN * kWarpTileHeadDimV) + j * kMmaAtomN; + int store_lane_gmem_O_d = store_warp_regs_O_d; // (0~3)*16+(0/8) + int store_gmem_O_addr_0 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 0) * kHeadDim + store_lane_gmem_O_d); + int store_gmem_O_addr_1 = ( + O_gmem_offset + (store_lane_gmem_O_Br + 8) * kHeadDim + store_lane_gmem_O_d); + LDST128BITS(O[store_gmem_O_addr_0]) = LDST128BITS(R_Z[0][0]); + LDST128BITS(O[store_gmem_O_addr_1]) = LDST128BITS(R_Z[1][0]); + } + } // end for kWarpTileHeadDimV + } // end for kWarpTileSeqLenQ +} + +// Launch kernel for flash_attn_mma_stages_split_q_tiling_qk +template +void launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle( + torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O) { + // Now: fixed tile BrxBc=128x128 for d>= 128, 64x64 for d<128. + // TODO: dynamic tile size for Br, Bc according to kHeadDim and shared memory size. + constexpr int kMmaAtomM = 16; + constexpr int kMmaAtomN = 8; + constexpr int kMmaAtomK = 16; + constexpr int kMmaTileSeqLenQ = (kHeadDim < 128) ? 4 : 8; + constexpr int kMmaTileSeqLenK = 1; + constexpr int kMmaTileSeqLenP = (kHeadDim < 128) ? 4 : 8; + constexpr int kMmaTileHeadDimV = 1; + constexpr int kWarpTileSeqLenQ = 1; + constexpr int kWarpTileSeqLenK = (kHeadDim < 128) ? 8 : 16; + constexpr int kWarpTileSeqLenP = 1; + constexpr int kWarpTileHeadDimV = (kHeadDim / (kMmaAtomN * kMmaTileHeadDimV)); // (d=64)8,(d=128)16,32,.... + constexpr int Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ; // 16*4*1=64 + constexpr int Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK; // 8*1*8=64 + constexpr int kNumThreads = WARP_SIZE * kMmaTileSeqLenQ * kMmaTileSeqLenK; // 32*4*1=128, num threads + constexpr int kPadQ = 0; + constexpr int kPadK = 0; + constexpr int kPadV = 8; + + // static int kMaxSramPerBlock; + // cudaDeviceGetAttribute(&kMaxSramPerBlock, cudaDevAttrMaxSharedMemoryPerBlock, 0); + // Calculate SRAM size needed per block, Q,K,V smem size, V shared the QK smem. + constexpr int QK_smem_size = (kStage * (Br * (kMmaAtomK + kPadQ)) + + kStage * (Bc * (kMmaAtomK + kPadK))); + // Now, for V_smem_size, s=2, d=4M, 16 regs; d=128, 8M, 32 regs; + // d=256, 16M, 64 regs; d=512, 32M, 128 regs; d=1024, 64M, 256 regs; + // TODO: sub-tiling for d while perform P@V, kMmaAtomK * (kMmaAtomN) + constexpr int V_smem_size = (kStage * (kMmaAtomK * (kHeadDim + kPadV))); + // try to let V reuse all Q+K smem after Q@K^T, reduce smem usage. + const int smem_max_size = max(QK_smem_size, V_smem_size) * sizeof(half); + + const int QKV_batch = Q.size(0); + const int QKV_head = Q.size(1); + const int QKV_seqlen = Q.size(2); // QKV_seqlen + assert(QKV_seqlen % max(Br, Bc) == 0); // multiple of max(Br, Bc) + + // TODO: How to apply block swizzle to improve L2 Cache hit rate? + // NOTE: reorder (B,H,Tr) -> (Tr,B*H) seems can improve L2 Cache hit rate. + // This might be because SM schedules blocks starting from the x-dimension. + // Placing Tr at the forefront ensures that identical KV pairs are placed + // in consecutive scheduling queues, thereby improving L2 Cache hit rates. 
+ // Tr(=N/Br), batch_size x num_heads + dim3 grid(div_ceil(QKV_seqlen, Br), QKV_batch * QKV_head); + dim3 block(kNumThreads); // 4/8 warps per block + // when N >= 6016, stage 1 will have precision gap, why? + + cudaFuncSetAttribute( + flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPadQ, + kPadK, + kPadV + >, + cudaFuncAttributeMaxDynamicSharedMemorySize, + // kMaxSramPerBlock + 98304 + ); + + flash_attn_mma_stages_split_q_tiling_qk_swizzle_kernel< + kHeadDim, + kMmaAtomM, + kMmaAtomN, + kMmaAtomK, + kMmaTileSeqLenQ, + kMmaTileSeqLenK, + kMmaTileSeqLenP, + kMmaTileHeadDimV, + kWarpTileSeqLenQ, + kWarpTileSeqLenK, + kWarpTileSeqLenP, + kWarpTileHeadDimV, + kStage, + kPadQ, + kPadK, + kPadV + ><<>>( + reinterpret_cast(Q.data_ptr()), + reinterpret_cast(K.data_ptr()), + reinterpret_cast(V.data_ptr()), + reinterpret_cast(O.data_ptr()), + QKV_seqlen, + QKV_head + ); +} + +void flash_attn_mma_stages_split_q_tiling_qk_swizzle(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages) { + CHECK_TORCH_TENSOR_DTYPE(Q, torch::kHalf) // Q [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(K, torch::kHalf) // K [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(V, torch::kHalf) // V [B,H,N,D] + CHECK_TORCH_TENSOR_DTYPE(O, torch::kHalf) // O [B,H,N,D] + const int d = Q.size(3); // B, H, N, d + + if (stages > 1) { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<32, 2>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<64, 2>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<96, 2>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<128, 2>(Q, K, V, O); + break; + case 256: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<256, 2>(Q, K, V, O); + break; + case 512: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<512, 2>(Q, K, V, O); + break; + case 1024: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<1024, 2>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } else { + switch (d) + { + case 32: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<32, 1>(Q, K, V, O); + break; + case 64: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<64, 1>(Q, K, V, O); + break; + case 96: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<96, 1>(Q, K, V, O); + break; + case 128: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<128, 1>(Q, K, V, O); + break; + case 256: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<256, 1>(Q, K, V, O); + break; + case 512: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<512, 1>(Q, K, V, O); + break; + case 1024: + launch_flash_attn_mma_stages_split_q_tiling_qk_swizzle<1024, 1>(Q, K, V, O); + break; + default: + throw std::runtime_error("headdim not support!"); + break; + } + } +} diff --git a/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu b/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu new file mode 100644 index 00000000..af846429 --- /dev/null +++ b/kernels/flash-attn/mma/flash_attn_mma_tiling_qkv_swizzle.cu @@ -0,0 +1,2 @@ +// TODO: Manually apply SMEM swizzling instead of padding in +// Split-Q kernels to reduce bank conflicts. 
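Editor's note: the sketch below is illustrative and not part of this patch. Assuming the
dispatcher above is compiled and linked together with libtorch, it shows how
flash_attn_mma_stages_split_q_tiling_qk_swizzle might be driven from C++ for a quick
smoke test; the shapes, stage count, and naive fp16 reference are arbitrary choices.

    // Minimal host-side driver sketch (assumes linkage with the .cu above and libtorch).
    #include <torch/torch.h>
    #include <cmath>
    #include <iostream>

    // Declared in flash_attn_mma_tiling_qk_swizzle.cu above.
    void flash_attn_mma_stages_split_q_tiling_qk_swizzle(
        torch::Tensor Q, torch::Tensor K, torch::Tensor V, torch::Tensor O, int stages);

    int main() {
      const int B = 1, H = 8, N = 1024, d = 64;  // [B,H,N,d]; N must be a multiple of max(Br,Bc)
      auto opts = torch::TensorOptions().dtype(torch::kHalf).device(torch::kCUDA);
      auto Q = torch::randn({B, H, N, d}, opts);
      auto K = torch::randn({B, H, N, d}, opts);
      auto V = torch::randn({B, H, N, d}, opts);
      auto O = torch::empty({B, H, N, d}, opts);
      flash_attn_mma_stages_split_q_tiling_qk_swizzle(Q, K, V, O, /*stages=*/2);
      // Loose fp16 sanity check against a naive attention reference.
      auto S = torch::matmul(Q, K.transpose(-2, -1)) / std::sqrt(static_cast<double>(d));
      auto O_ref = torch::matmul(torch::softmax(S, -1), V);
      std::cout << "max |O - O_ref| = "
                << (O - O_ref).abs().max().item<float>() << std::endl;
      return 0;
    }

Passing stages=1 instead exercises the single-stage path of the dispatcher.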
diff --git a/kernels/flash-attn/pybind/flash_attn.cc b/kernels/flash-attn/pybind/flash_attn.cc index 8ba636c8..0faa796d 100644 --- a/kernels/flash-attn/pybind/flash_attn.cc +++ b/kernels/flash-attn/pybind/flash_attn.cc @@ -35,10 +35,18 @@ void flash_attn_mma_stages_split_q_tiling_qk(torch::Tensor Q, torch::Tensor O, int stages); +void flash_attn_mma_stages_split_q_tiling_qk_swizzle(torch::Tensor Q, + torch::Tensor K, + torch::Tensor V, + torch::Tensor O, + int stages); + + PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_kv) TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q) TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_kv) TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_shared_qkv) TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_tiling_qk) + TORCH_BINDING_COMMON_EXTENSION(flash_attn_mma_stages_split_q_tiling_qk_swizzle) } diff --git a/kernels/flash-attn/tools/print_swizzle_layout.py b/kernels/flash-attn/tools/print_swizzle_layout.py new file mode 100644 index 00000000..dcb9d48e --- /dev/null +++ b/kernels/flash-attn/tools/print_swizzle_layout.py @@ -0,0 +1,46 @@ +import argparse + + +def pretty_print_line(m: str = "", sep: str = "-", width: int = 130): + res_len = width - len(m) + left_len = int(res_len / 2) + right_len = res_len - left_len + pretty_line = sep * left_len + m + sep * right_len + print(pretty_line) + + +def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8): + # i: row index; j: col index. + return ((int(j / step) ^ int(i / 4)) % int(col_stride / step)) * step + + +def print_swizzle_layout(rows: int = 16, col_stride: int = 64, step: int = 8): + str_len = 0 + for i in range(rows): + layout = tuple(swizzle_permuted_j(i, j, col_stride, step) + for j in range(0, col_stride, step)) + layout_str = (f"| row {i:<2} | {layout} |") + str_len = len(layout_str) + if (i == 0): + print("-" * str_len) + pretty_print_line(f"swizzle layout", width=str_len) + pretty_print_line(f"col 0~{col_stride}, step {step}", width=str_len) + print("-" * str_len) + print(layout_str) + if ((i + 1) % 4 == 0 and i != (rows - 1)): + print("-" * str_len) + print("-" * str_len) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--col_stride", "--col", type=int, default=64) + parser.add_argument("--step", type=int, default=8) + parser.add_argument("--rows", type=int, default=16) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print_swizzle_layout(args.rows, args.col_stride, args.step) + diff --git a/kernels/hgemm/README.md b/kernels/hgemm/README.md index 6d66037c..7ccac9f1 100755 --- a/kernels/hgemm/README.md +++ b/kernels/hgemm/README.md @@ -1,5 +1,5 @@ -## ⚡️⚡️Toy-HGEMM Library: May achieve the 98%~100% performance of cuBLAS 🎉🎉 +# ⚡️⚡️Toy-HGEMM: Achieve the 98%~100% TFLOPS of cuBLAS 🎉🎉 ![toy-hgemm-library](https://github.com/user-attachments/assets/962bda14-b494-4423-b8eb-775da9f5503d) @@ -23,7 +23,7 @@ Currently, on NVIDIA L20, RTX 4090 and RTX 3080 Laptop, compared with cuBLAS's d |✔️|✔️|✔️|✔️| |Copy Async (cp.async.cg/ca)|Tile MMA (More Threads)|Tile Warp (More Values)|Multi Stages(2/3/4/5)| |✔️|✔️|✔️|✔️| -|Register Double Buffers|Block Swizzle (Zigzag N)|Warp Swizzle (Zigzag N)| SMEM Swizzle (CUTLASS/CuTe)| +|Register Double Buffers|Block Swizzle (Zigzag N)|Warp Swizzle (Zigzag N)| SMEM Swizzle (CuTe/MMA) | |✔️|✔️|✔️|✔️| |Collective Store (Warp Shuffle & Reg Reuse)|Row Major (NN)|Col Major (TN)|SGEMM FP32/TF32| 
|✔️|✔️|✔️|✔️| @@ -79,6 +79,7 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4(torch::Tensor a, torch: void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); ``` ## 📖 Contents diff --git a/kernels/hgemm/hgemm.py b/kernels/hgemm/hgemm.py index 008161d0..4bde91a0 100644 --- a/kernels/hgemm/hgemm.py +++ b/kernels/hgemm/hgemm.py @@ -160,11 +160,11 @@ def run_benchmark(perf_func: callable, else: improve = 0 MAX_TFLOPS = TFLOPS - print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, " + print(f"{out_info:>52}: {out_val}, time:{mean_time_ms}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}(+{improve:.2f}%)") else: if not only_show_improved or "cublas" in tag: - print(f"{out_info:>50}: {out_val}, time:{mean_time_ms}ms, " + print(f"{out_info:>52}: {out_val}, time:{mean_time_ms}ms, " f"swizzle: {swizzle_stride:<4}, TFLOPS: {TFLOPS:<6.2f}") if show_matrix: print(out) if args.plot_flops: @@ -359,6 +359,9 @@ def get_mnk(sep: int = args.SEP): run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem)", c, stages=4) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem)", c, stages=3) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem)", c, stages=2) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2) if args.enable_mma_all: # more mma kernel tests. 
run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+rr)", c, stages=4) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+rr)", c, stages=3) @@ -373,6 +376,9 @@ def get_mnk(sep: int = args.SEP): run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage4+dsmem+swizzle)", c, stages=4, swizzle=True) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage3+dsmem+swizzle)", c, stages=3, swizzle=True) + run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle, a, b, "(mma2x4+warp4x4x2+stage2+dsmem+swizzle)", c, stages=2, swizzle=True) if args.enable_mma_all: run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage3+swizzle)", c, stages=3, swizzle=True) run_benchmark(hgemm.hgemm_mma_m16n8k16_mma2x4_warp4x4_stages, a, b, "(mma2x4+warp4x4+stage2+swizzle)", c, stages=2, swizzle=True) diff --git a/kernels/hgemm/makefile b/kernels/hgemm/makefile index 5518dfef..d0312cc3 100644 --- a/kernels/hgemm/makefile +++ b/kernels/hgemm/makefile @@ -7,11 +7,18 @@ default: nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.bin $(DEFAULT_FLAGS) nvcc cublas/hgemm_cublas.cu -o hgemm_cublas.bin $(DEFAULT_FLAGS) nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.bin $(DEFAULT_FLAGS) + nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.bin $(DEFAULT_FLAGS) cute_89: nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.bin $(DEFAULT_FLAGS_89) cute_89_debug: nvcc cutlass/hgemm_mma_stage_tn_cute.cu -o hgemm_cute.89.debug.bin $(DEFAULT_FLAGS_89) -DCUTE_HGEMM_DEBUG -Xcompiler "-Wno-format" mma_89: nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.bin $(DEFAULT_FLAGS_89) +mma_89_debug: + nvcc mma/hgemm_mma_stage.cu -o hgemm_mma_stage.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG +mma_89_swizzle: + nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.bin $(DEFAULT_FLAGS_89) +mma_89_swizzle_debug: + nvcc mma/hgemm_mma_stage_swizzle.cu -o hgemm_mma_stage_swizzle.89.debug.bin $(DEFAULT_FLAGS_89) -DHGEMM_MMA_DEBUG clean: rm -rf *.bin diff --git a/kernels/hgemm/mma/hgemm_mma.cu b/kernels/hgemm/mma/hgemm_mma.cu index 37fbd6e4..0e787483 100644 --- a/kernels/hgemm/mma/hgemm_mma.cu +++ b/kernels/hgemm/mma/hgemm_mma.cu @@ -313,8 +313,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4( constexpr int MMA_TILE_N = 4; constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; - constexpr int A_PAD = 0; - constexpr int B_PAD = 16; + // bank conflicts free via pad = 8, reject fantasy, trust the profile. 
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; + constexpr int B_PAD = 8; constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 diff --git a/kernels/hgemm/mma/hgemm_mma_stage.cu b/kernels/hgemm/mma/hgemm_mma_stage.cu index 52feac84..d021ba29 100644 --- a/kernels/hgemm/mma/hgemm_mma_stage.cu +++ b/kernels/hgemm/mma/hgemm_mma_stage.cu @@ -1974,13 +1974,11 @@ void lanunch_hgemm_mma_m16n8k16_nn( constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; constexpr int WARP_TILE_K = 2; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.debug.89.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.debug.89.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; @@ -1993,8 +1991,16 @@ void lanunch_hgemm_mma_m16n8k16_nn( LAUNCH_16816_STAGE_SWIZZLE_MMA2x4_WARP4x4x2_DSMEM_KERNEL(K_STAGE, BLOCK_SWIZZLE_STRIDE); } -int main() { +#ifdef HGEMM_MMA_DEBUG +#include +#endif + +int main(int argc, char *argv[]) { +#ifdef HGEMM_MMA_DEBUG + const int test_num = 1; +#else const int test_num = 64; +#endif int M_list[test_num]; int N_list[test_num]; int K_list[test_num]; @@ -2005,9 +2011,22 @@ int main() { K_list[i] = (i + 1) * 256; } - const int outer_repeat = 10, inner_repeat = 1; +#ifdef HGEMM_MMA_DEBUG + if (argc > 1) M_list[0] = std::stoi(argv[1]); + if (argc > 2) N_list[0] = std::stoi(argv[2]); + if (argc > 3) K_list[0] = std::stoi(argv[3]); +#endif + +#ifdef HGEMM_MMA_DEBUG + int outer_repeat = 1, inner_repeat = 1, warmup = 1; + if (argc > 4) warmup = std::stoi(argv[4]); + if (argc > 5) inner_repeat = std::stoi(argv[5]); +#else + int outer_repeat = 10, inner_repeat = 1, warmup = 1; +#endif printf("ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048\n"); +#ifndef HGEMM_MMA_DEBUG for (int j = 0; j < 5; j++) { int M = M_list[j], N = N_list[j], K = K_list[j]; float max_error = gemm_error_check_nn( @@ -2016,6 +2035,7 @@ int main() { printf("M N K = %6d %6d %6d, ", M, N, K); printf("Max Error = %f\n", max_error); } +#endif for (int j = 0; j < test_num; j++) { int M = M_list[j], N = N_list[j], K = K_list[j]; @@ -2027,7 +2047,7 @@ int main() { for (int k = 0; k < outer_repeat; k++) { double this_sec = perf_gemm( lanunch_hgemm_mma_m16n8k16_nn<2, 2048>, - M, N, K, inner_repeat); + M, N, K, inner_repeat, warmup); max_sec = max(max_sec, this_sec); min_sec = min(min_sec, this_sec); total_sec += this_sec; @@ -2120,13 +2140,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages( constexpr int MMA_TILE_N = 4; constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; - // s_a 4 ways bank 
conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; @@ -2250,13 +2268,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem( constexpr int MMA_TILE_N = 4; constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; @@ -2381,13 +2397,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem( constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; constexpr int WARP_TILE_K = 2; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; @@ -2513,13 +2527,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4( constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; constexpr int WARP_TILE_K = 2; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. 
- // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; @@ -2646,13 +2658,11 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr( constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; constexpr int WARP_TILE_K = 2; - // s_a 4 ways bank conflicts within warp, after pad 8 -> 4 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 8 -> 8 ways bank conflicts. - // s_b 16 ways bank conflicts within warp, after pad 16 -> 4 ways bank conflicts. - // so, the best padding policy for s_a and s_b is A_PAD=0/8, B_PAD=16. Thus, - // improve B_PAD consume 8x~ less smem than A_PAD, 16xB_PAD vs 128xA_PAD. - constexpr int A_PAD = 0; // 0,8,16 - constexpr int B_PAD = 16; // 0,8,16 + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 8; // 0,8,16 + constexpr int B_PAD = 8; // 0,8,16 constexpr int NUM_THREADS= ( MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; diff --git a/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu b/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu new file mode 100644 index 00000000..184571d3 --- /dev/null +++ b/kernels/hgemm/mma/hgemm_mma_stage_swizzle.cu @@ -0,0 +1,853 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; + +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +// gmem -> smem +#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) +#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) +// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 
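+// Usage sketch for the CP_ASYNC_* wrappers defined below (illustrative only, names
+// such as `row`, `col`, `gmem_offset` are placeholders, not part of this kernel):
+//   uint32_t dst = __cvta_generic_to_shared(&s_a[row][col]); // 32-bit shared address
+//   CP_ASYNC_CG(dst, &A[gmem_offset], 16); // one 16-byte (8 x half) async copy per thread
+//   CP_ASYNC_COMMIT_GROUP();               // close the group of copies issued above
+//   CP_ASYNC_WAIT_GROUP(0);                // or (K_STAGE - 2) to overlap later stages
+//   __syncthreads();
+// cg bypasses L1 and only moves 16-byte chunks, which matches the 128-bit vectorized
+// loads used throughout these kernels; ca also allows 4/8-byte copies and fills L1.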
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +// smem -> gmem: requires sm_90 or higher. +#define CP_ASYNC_BULK_COMMIT_GROUP() asm volatile("cp.async.bulk.commit_group;\n" ::) +#define CP_ASYNC_BULK_WAIT_ALL() asm volatile("cp.async.bulk.wait_all;\n" ::) +#define CP_ASYNC_BULK_WAIT_GROUP(n) asm volatile("cp.async.bulk.wait_group %0;\n" ::"n"(n)) +#define CP_ASYNC_BULK(dst, src, bytes) asm volatile("cp.async.bulk.global.shared::cta.bulk_group.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes)) +// ldmatrix +#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr)) +#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr)) +#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr)) +// stmatrix: requires sm_90 or higher. +#define STMATRIX_X1(addr, R) asm volatile("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) +#define STMATRIX_X2(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) +#define STMATRIX_X4(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) +#define STMATRIX_X1_T(addr, R) asm volatile("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" :: "r"(addr), "r"(R)) +#define STMATRIX_X2_T(addr, R0, R1) asm volatile("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" :: "r"(addr), "r"(R0), "r"(R1)) +#define STMATRIX_X4_T(addr, R0, R1, R2, R3) asm volatile("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" :: "r"(addr), "r"(R0), "r"(R1), "r"(R2), "r"(R3)) +// mma m16n8k16 +#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1)) + +HOST_DEVICE_INLINE +int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); } + +// i: row index; j: col index. +// e.g kColStride = 64, kStep = 8 -> load 8 half as 128 bits memory issue. 
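+// The permutation below only swaps whole 8-half (16-byte) groups within a row:
+//   swizzled_j = (((j / kStep) ^ (i / 4)) % (kColStride / kStep)) * kStep
+// Worked example with kColStride = 64, kStep = 8: row i = 5 has i / 4 = 1, so column
+// group g maps to g ^ 1, i.e. 0<->8, 16<->24, 32<->40, 48<->56 -- exactly the "row 4~7"
+// entry of the layout table documented inside the function below and printed by
+// print_swizzle_layout.py under kernels/swizzle/.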
+template +static __device__ __forceinline__ int swizzle_permuted_j(int i, int j) { + // ------------------------------------------- + // --------------swizzle layout--------------- + // -------------col 0~64, step 8-------------- + // ------------------------------------------- + // | row 0 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 1 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 2 | (0, 8, 16, 24, 32, 40, 48, 56) | + // | row 3 | (0, 8, 16, 24, 32, 40, 48, 56) | + // ------------------------------------------- + // | row 4 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 5 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 6 | (8, 0, 24, 16, 40, 32, 56, 48) | + // | row 7 | (8, 0, 24, 16, 40, 32, 56, 48) | + // ------------------------------------------- + // | row 8 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 9 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 10 | (16, 24, 0, 8, 48, 56, 32, 40) | + // | row 11 | (16, 24, 0, 8, 48, 56, 32, 40) | + // ------------------------------------------- + // | row 12 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 13 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 14 | (24, 16, 8, 0, 56, 48, 40, 32) | + // | row 15 | (24, 16, 8, 0, 56, 48, 40, 32) | + // ------------------------------------------- + // swizzle: ((int(j / kStep) ^ int(i / 4)) % int(kColStride / kStep)) * kStep; + static_assert(kStep == 4 || kStep == 8, "kStep must be 8 or 4."); + static_assert(kColStride % kStep == 0, "kColStride must be multiple of kStep."); + if constexpr (kStep == 8) { + return (((j >> 3) ^ (i >> 2)) % (kColStride >> 3)) << 3; + } else { + static_assert(kStep == 4); + return (((j >> 2) ^ (i >> 2)) % (kColStride >> 2)) << 2; + } +} + +// i: row index; j: col index +template +static __device__ __forceinline__ int swizzle_permuted_A_j(int i, int j) { + // ------------------- + // -col 0~16, step 8-- + // ------------------- + // | row 0 | (0, 8) | + // | row 1 | (0, 8) | + // | row 2 | (0, 8) | + // | row 3 | (0, 8) | + // ------------------- + // | row 4 | (8, 0) | + // | row 5 | (8, 0) | + // | row 6 | (8, 0) | + // | row 7 | (8, 0) | + // ------------------- + // | row 8 | (0, 8) | + // | row 9 | (0, 8) | + // | row 10 | (0, 8) | + // | row 11 | (0, 8) | + // ------------------- + // | row 12 | (8, 0) | + // | row 13 | (8, 0) | + // | row 14 | (8, 0) | + // | row 15 | (8, 0) | + // ------------------- + return swizzle_permuted_j(i, j); +} + +// In order to reduce bank conflicts, we will save the K(16x2=32) +// dimension by half according to the stage dimension. For example, +// stages=3, warp_tile_k=2, it will be saved as [3*2][BM][16]. +// 128x128, mma2x4, warp4x4(64,32,32), stages, block swizzle, dsmem, +// k32 with reg double buffers +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel( + const half* __restrict__ A, const half* __restrict__ B, half* __restrict__ C, + int M, int N, int K) { + // BLOCK_SWIZZLE 0/1 control use block swizzle or not. 
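+  // When BLOCK_SWIZZLE is enabled, the launcher splits the N dimension into N_SWIZZLE
+  // slices of `swizzle_stride` columns and encodes the slice index in blockIdx.z, so the
+  // real block column is recovered as bx = blockIdx.z * gridDim.x + blockIdx.x below.
+  // Rough example (illustrative numbers): N = 8192, BN = 128, stride = 2048 gives
+  // N_SWIZZLE = 4 and gridDim.x = ceil(64 / 4) = 16; blocks inside one slice touch the
+  // same 2048-column strip of B, which improves L2 reuse between neighboring blocks.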
+ const int bx = ((int) BLOCK_SWIZZLE) * blockIdx.z * gridDim.x + blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K * WARP_TILE_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16x2=32 + + extern __shared__ half smem[]; + half* s_a = smem; + half* s_b = smem + K_STAGE * BM * (BK + A_PAD) * WARP_TILE_K; + constexpr int s_a_stage_offset = BM * (BK + A_PAD); // 128x16 + constexpr int s_b_stage_offset = BK * (BN + B_PAD); // 16x128 + constexpr int s_a_mma_k_store_offset = K_STAGE * BM * (BK + A_PAD); + constexpr int s_b_mma_k_store_offset = K_STAGE * BK * (BN + B_PAD); + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,16,... + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + uint32_t smem_a_base_ptr = __cvta_generic_to_shared(s_a); + uint32_t smem_b_base_ptr = __cvta_generic_to_shared(s_b); + + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); ++k) { // 0, 1 + // k * WMMA_K, WMMA_K=16 -> (k << 4) + int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (k * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j( + load_smem_a_m, load_smem_a_k)) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0 + uint32_t load_smem_a_mma_k_ptr = ( + smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) + + (k * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j(load_smem_a_m, load_smem_a_k)) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1 + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (k * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k; + int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n; + uint32_t load_smem_b_mma_k_ptr = ( + smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) + + (k * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16); + + CP_ASYNC_COMMIT_GROUP(); + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); // s2->0, s3->1, s4->2 + __syncthreads(); + + uint32_t RA[2][WARP_TILE_M][4]; + uint32_t RB[2][WARP_TILE_N][2]; + + int reg_store_idx = 0; + int reg_load_idx = 1; + + { + // 
ldmatrix for s_a, ldmatrix.trans for s_b. + // smem -> reg buffers 0, first MMA_K, 0~15 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + + (0 * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half) + ); + LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1], + RA[reg_store_idx][i][2], RA[reg_store_idx][i][3], + lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15, 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + + (0 * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + // may use .x4.trans to load 4 matrix for reg double buffers at once? + LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1], + lane_smem_b_ptr); + } + } + + #pragma unroll + for (int k = (K_STAGE - 1); k < NUM_K_TILES; ++k) { + reg_store_idx ^= 1; // 0->1 + reg_load_idx ^= 1; // 1->0 + int smem_sel = (k + 1) % K_STAGE; // s3 k 2->0, k 3->1, k 4->2... + int smem_sel_next = k % K_STAGE; // s3 k 2->2, k 3->0, k 4->1... + + // stage gmem -> smem + int load_gmem_a_k = k * BK * WARP_TILE_K + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK * WARP_TILE_K + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + + uint32_t load_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_next * s_a_stage_offset + + load_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j( + load_smem_a_m, load_smem_a_k)) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_ptr, &A[load_gmem_a_addr], 16); // MMA_K 0 + uint32_t load_smem_a_mma_k_ptr = ( + smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) + + (smem_sel_next * s_a_stage_offset + load_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j(load_smem_a_m, load_smem_a_k)) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_a_mma_k_ptr, &A[load_gmem_a_addr + 16], 16); // MMA_K 1 + + uint32_t load_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_next * s_b_stage_offset + + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_ptr, &B[load_gmem_b_addr], 16); + + int load_gmem_b_k_mma_k = k * BK * WARP_TILE_K + MMA_K + load_smem_b_k; + int load_gmem_b_addr_mma_k = load_gmem_b_k_mma_k * N + load_gmem_b_n; + uint32_t load_smem_b_mma_k_ptr = ( + smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) + + (smem_sel_next * s_b_stage_offset + load_smem_b_k * (BN + B_PAD) + + load_smem_b_n) * sizeof(half) + ); + CP_ASYNC_CG(load_smem_b_mma_k_ptr, &B[load_gmem_b_addr_mma_k], 16); + CP_ASYNC_COMMIT_GROUP(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. 
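+    // Background on the two ldmatrix flavors used here: ldmatrix.x4 loads four 8x8 b16
+    // tiles, so every lane supplies one shared-memory row address (row = lane_id % 16,
+    // 8-half group = (lane_id / 16) * 8, then permuted by swizzle_permuted_A_j so the
+    // 16-byte chunks of one phase hit disjoint banks). ldmatrix.x2.trans loads two 8x8
+    // tiles transposed, producing the col-major B fragment that mma.m16n8k16.row.col
+    // expects; only the addresses supplied by lanes 0~15 are consumed.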
+ // smem -> reg buffers 1, second MMA_K, 16~31 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) + + (smem_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half) + ); + LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1], + RA[reg_store_idx][i][2], RA[reg_store_idx][i][3], + lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) + + (smem_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + // may use .x4.trans to load 4 matrix for reg double buffers at once? + LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1], + lane_smem_b_ptr); + } + + // MMA compute, first MMA_K + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // Warp swizzle: Right -> Left -> Right -> Left + int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j; + HMMA16816(RC[i][j_s][0], RC[i][j_s][1], + RA[reg_load_idx][i][0], RA[reg_load_idx][i][1], + RA[reg_load_idx][i][2], RA[reg_load_idx][i][3], + RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1], + RC[i][j_s][0], RC[i][j_s][1]); + } + } + + reg_store_idx ^= 1; // 1 -> 0 + reg_load_idx ^= 1; // 0 -> 1 + // MMA compute, second MMA_K + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // Warp swizzle: Right -> Left -> Right -> Left + int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j; + HMMA16816(RC[i][j_s][0], RC[i][j_s][1], + RA[reg_load_idx][i][0], RA[reg_load_idx][i][1], + RA[reg_load_idx][i][2], RA[reg_load_idx][i][3], + RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1], + RC[i][j_s][0], RC[i][j_s][1]); + } + } + + CP_ASYNC_WAIT_GROUP(K_STAGE-2); + __syncthreads(); + + // load next k iters to reg buffers. 
+ // smem -> reg buffers 0, first MMA_K, 0~15 + // int smem_sel_reg = (k + 2) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2 + int smem_sel_reg = (smem_sel + 1) % K_STAGE; // vs smem_sel k=2->(0)1, k=3->(1)2 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + (smem_sel_reg * s_a_stage_offset + + lane_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j( + lane_smem_a_m, lane_smem_a_k)) * sizeof(half) + ); + LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1], + RA[reg_store_idx][i][2], RA[reg_store_idx][i][3], + lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15, 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + (smem_sel_reg * s_b_stage_offset + + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + // may use .x4.trans to load 4 matrix for reg double buffers at once? + LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1], + lane_smem_b_ptr); + } + } + + // make sure all memory issues ready. + if constexpr ((K_STAGE - 2) > 0) { + CP_ASYNC_WAIT_GROUP(0); + __syncthreads(); + } + + // processing last (K_STAGE-1) k iters. + { + #pragma unroll + for (int k = 0; k < (K_STAGE - 1); k++) { + reg_store_idx ^= 1; // 0->1 + reg_load_idx ^= 1; // 1->0 + + int stage_sel = ((NUM_K_TILES - (K_STAGE - 1) + k) % K_STAGE); + // ldmatrix for s_a, ldmatrix.trans for s_b. + // smem -> reg buffers 1, second MMA_K + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + s_a_mma_k_store_offset * sizeof(half) + + (stage_sel * s_a_stage_offset + lane_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j(lane_smem_a_m, lane_smem_a_k)) * sizeof(half) + ); + LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1], + RA[reg_store_idx][i][2], RA[reg_store_idx][i][3], + lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + s_b_mma_k_store_offset * sizeof(half) + + (stage_sel * s_b_stage_offset + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1], + lane_smem_b_ptr); + } + + // MMA compute, first MMA_K + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // Warp swizzle: Right -> Left -> Right -> Left + int j_s = ((i % 2) && WARP_SWIZZLE)? 
(WARP_TILE_N - j - 1) : j; + HMMA16816(RC[i][j_s][0], RC[i][j_s][1], + RA[reg_load_idx][i][0], RA[reg_load_idx][i][1], + RA[reg_load_idx][i][2], RA[reg_load_idx][i][3], + RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1], + RC[i][j_s][0], RC[i][j_s][1]); + } + } + + reg_store_idx ^= 1; // 1 -> 0 + reg_load_idx ^= 1; // 0 -> 1 + + // MMA compute, second MMA_K + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // Warp swizzle: Right -> Left -> Right -> Left + int j_s = ((i % 2) && WARP_SWIZZLE)? (WARP_TILE_N - j - 1) : j; + HMMA16816(RC[i][j_s][0], RC[i][j_s][1], + RA[reg_load_idx][i][0], RA[reg_load_idx][i][1], + RA[reg_load_idx][i][2], RA[reg_load_idx][i][3], + RB[reg_load_idx][j_s][0], RB[reg_load_idx][j_s][1], + RC[i][j_s][0], RC[i][j_s][1]); + } + } + + // load next k iters to reg buffers. + // smem -> reg buffers 0, first MMA_K, 0~15 + // int stage_sel_reg = ((NUM_K_TILES - K_STAGE + k) % K_STAGE); + int stage_sel_reg = (stage_sel + 1) % K_STAGE; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + uint32_t lane_smem_a_ptr = ( + smem_a_base_ptr + (stage_sel_reg * s_a_stage_offset + + lane_smem_a_m * (BK + A_PAD) + + swizzle_permuted_A_j( + lane_smem_a_m, lane_smem_a_k)) * sizeof(half) + ); + LDMATRIX_X4(RA[reg_store_idx][i][0], RA[reg_store_idx][i][1], + RA[reg_store_idx][i][2], RA[reg_store_idx][i][3], + lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15, 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = ( + smem_b_base_ptr + (stage_sel_reg * s_b_stage_offset + + lane_smem_b_k * (BN + B_PAD) + + lane_smem_b_n) * sizeof(half) + ); + LDMATRIX_X2_T(RB[reg_store_idx][j][0], RB[reg_store_idx][j][1], + lane_smem_b_ptr); + } + } + } + + // collective store with reg reuse & warp shuffle + for (int i = 0; i < WARP_TILE_M; ++i) { + // reuse RA[2][4][4] reg here, this may boost 0.3~0.5 TFLOPS up. + // may not put 'if' in N loop, it will crash the 'pragma unroll' hint ? + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + // How to use LDST128BITS here? __shfl_sync -> lane 0 -> store 8 half. + // thus, we only need 8 memory issues with 128 bits after shfl_sync. 
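+      // Each lane of an m16n8 f16 accumulator owns 2 halves of one row (row = lane_id / 4
+      // for RC[..][0] and row + 8 for RC[..][1], cols = (lane_id % 4) * 2), so the 8 halves
+      // of an output row live in 4 consecutive lanes. The shuffles below gather them into
+      // the lane with lane_id % 4 == 0, which then issues one 128-bit store per row and
+      // MMA_N tile instead of four 32-bit stores.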
+ RA[0][j][0] = RC[i][j][0]; + RA[1][j][0] = RC[i][j][1]; + RA[0][j][1] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 1); + RA[0][j][2] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 2); + RA[0][j][3] = __shfl_sync((0xffffffff), RC[i][j][0], lane_id + 3); + RA[1][j][1] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 1); + RA[1][j][2] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 2); + RA[1][j][3] = __shfl_sync((0xffffffff), RC[i][j][1], lane_id + 3); + } + + if (lane_id % 4 == 0) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + LDST128BITS(C[store_gmem_c_addr_0]) = LDST128BITS(RA[0][j][0]); + LDST128BITS(C[store_gmem_c_addr_1]) = LDST128BITS(RA[1][j][0]); + } + } + } +} + +// build cpp binary +#ifndef NO_MMA_HGEMM_BIN + +#include "utils.h" + +// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers +#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(stages, stride) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + a, b, c, \ + M, N, K \ + ); \ +} + +// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers +template +void lanunch_hgemm_mma_m16n8k16_swizzle_nn( + const half* a, const half* b, half* c, int M, int N, int K) { + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int WARP_TILE_K = 2; + // bank conflicts free via pad = 8, reject fantasy, trust the profile. 
+ // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.debug.89.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.debug.89.bin + constexpr int A_PAD = 0; // apply smem swizzle + constexpr int B_PAD = 8; // 0,8,16 + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; + constexpr int BK = MMA_K; + // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~35KB + // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB + // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB + // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(K_STAGE, BLOCK_SWIZZLE_STRIDE); +} + +#ifdef HGEMM_MMA_DEBUG +#include +#endif + +int main(int argc, char *argv[]) { +#ifdef HGEMM_MMA_DEBUG + const int test_num = 1; +#else + const int test_num = 64; +#endif + int M_list[test_num]; + int N_list[test_num]; + int K_list[test_num]; + + for (int i = 0; i < test_num; i++) { + M_list[i] = (i + 1) * 256; + N_list[i] = (i + 1) * 256; + K_list[i] = (i + 1) * 256; + } + +#ifdef HGEMM_MMA_DEBUG + if (argc > 1) M_list[0] = std::stoi(argv[1]); + if (argc > 2) N_list[0] = std::stoi(argv[2]); + if (argc > 3) K_list[0] = std::stoi(argv[3]); +#endif + +#ifdef HGEMM_MMA_DEBUG + int outer_repeat = 1, inner_repeat = 1, warmup = 1; + if (argc > 4) warmup = std::stoi(argv[4]); + if (argc > 5) inner_repeat = std::stoi(argv[5]); +#else + int outer_repeat = 10, inner_repeat = 1, warmup = 1; +#endif + + printf("ALGO = MMA16816 HGEMM NN MMA=2x4 WARP=4x4x2 STAGES=2 BLOCK SWIZZLE=2048 + A SMEM SWIZZLE\n"); +#ifndef HGEMM_MMA_DEBUG + for (int j = 0; j < 5; j++) { + int M = M_list[j], N = N_list[j], K = K_list[j]; + float max_error = gemm_error_check_nn( + lanunch_hgemm_mma_m16n8k16_swizzle_nn<2, 2048>, + M, N, K); + printf("M N K = %6d %6d %6d, ", M, N, K); + printf("Max Error = %f\n", max_error); + } +#endif + + for (int j = 0; j < test_num; j++) { + int M = M_list[j], N = N_list[j], K = K_list[j]; + + double max_sec = 0.0; + double min_sec = DBL_MAX; + double total_sec = 0.0; + + for (int k = 0; k < outer_repeat; k++) { + double this_sec = perf_gemm( + lanunch_hgemm_mma_m16n8k16_swizzle_nn<2, 2048>, + M, N, K, inner_repeat, warmup); + max_sec = max(max_sec, this_sec); + min_sec = min(min_sec, this_sec); + total_sec += this_sec; + } + + // 1 TFLOPS = 10^12 FLOPS + // ref: https://imgtec.eetrend.com/blog/2021/100062210.html. 
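+    // Worked example: M = N = K = 4096 -> 2 * 4096^3 ~= 1.37e11 FLOPs; at 1 ms per GEMM
+    // that is 1.37e11 / 1e-3 ~= 1.37e14 FLOP/s, i.e. ~137 TFLOPS, matching the avg_Tflops
+    // formula below (the factor 2 counts one multiply plus one add per MAC).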
+ double avg_sec = total_sec / outer_repeat; + double avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + + printf("M N K = %6d %6d %6d, ", M, N, K); + printf("Time = %12.8lf %12.8lf %12.8lf s, ", min_sec, avg_sec, max_sec); + printf("AVG Performance = %10.4lf Tflops\n", avg_Tflops); + } + + return 0; +} + +#else + +// --------------------- PyTorch bindings for custom kernel ----------------------- +#include +#include +#define STRINGFY(str) #str +#define TORCH_BINDING_COMMON_EXTENSION(func) \ + m.def(STRINGFY(func), &func, STRINGFY(func)); + +#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ +if(((T).options().dtype() != (th_type))) { \ + std::cout << "Tensor Info:" << (T).options() << std::endl; \ + throw std::runtime_error("values must be "#th_type); \ +} + +#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ +if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ + throw std::runtime_error("Tensor size mismatch!"); \ +} + +// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block&smem swizzle, dsmem, reg double buffers +#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(stages, stride) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + const int N_SWIZZLE = (N + (stride) - 1) / (stride); \ + dim3 block(NUM_THREADS); \ + dim3 grid((div_ceil(N, BN) + N_SWIZZLE - 1) / N_SWIZZLE, \ + div_ceil(M, BM), \ + N_SWIZZLE); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), true><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +// no block swizzle, but have A smem swizzle +#define LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(stages) \ +{ \ + const int smem_max_size = ( \ + (stages) * BM * (BK + A_PAD) * WARP_TILE_K * sizeof(half) + \ + (stages) * BK * (BN + B_PAD) * WARP_TILE_K * sizeof(half)); \ + cudaFuncSetAttribute( \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), false>, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + 98304); \ + dim3 block(NUM_THREADS); \ + dim3 grid(div_ceil(N, BN), div_ceil(M, BM)); \ + hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle_kernel< \ + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, \ + WARP_TILE_M, WARP_TILE_N, WARP_TILE_K, A_PAD, B_PAD, (stages), false><<< \ + grid, block, smem_max_size>>>( \ + reinterpret_cast(a.data_ptr()), \ + reinterpret_cast(b.data_ptr()), \ + reinterpret_cast(c.data_ptr()), \ + M, N, K \ + ); \ +} + +// 128x128, mma2x4, warp4x4x2(64,32,32), stages, block swizzle, dsmem, reg double buffers +void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle( + torch::Tensor a, torch::Tensor b, torch::Tensor c, + int stages, bool swizzle, int swizzle_stride) { + CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) + CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) + const int M = a.size(0); + const int K = 
a.size(1); + const int N = b.size(1); + CHECK_TORCH_TENSOR_SHAPE(a, M, K) + CHECK_TORCH_TENSOR_SHAPE(b, K, N) + CHECK_TORCH_TENSOR_SHAPE(c, M, N) + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int WARP_TILE_K = 2; + // bank conflicts free via pad = 8, reject fantasy, trust the profile. + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_stage.89.debug.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_stage.89.debug.bin + constexpr int A_PAD = 0; // apply smem swizzle + constexpr int B_PAD = 8; // 0,8,16 + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; + constexpr int BK = MMA_K; + // s2: 2*128*(32)*2=16KB, 2*32*(128+16)*2=18KB, ~35KB + // s3: 3*128*(32)*2=24KB, 3*32*(128+16)*2=27KB, ~51KB + // s4: 4*128*(32)*2=32KB, 4*32*(128+16)*2=36KB, ~68KB + // s5: 5*128*(32)*2=40KB, 5*32*(128+16)*2=45KB, ~85KB + if (swizzle) { + // assert(swizzle_stride % 256 == 0); + switch (stages) + { + case 2: // ~35KB + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(2, swizzle_stride); + break; + case 3: // ~51KB + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(3, swizzle_stride); + break; + case 4: // ~68KB + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(4, swizzle_stride); + break; + case 5: // ~85KB + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(5, swizzle_stride); + break; + default: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_SWIZZLE_KERNEL(2, swizzle_stride); + break; + } + } else { + switch (stages) + { + case 2: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(2); + break; + case 3: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(3); + break; + case 4: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(4); + break; + case 5: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(5); + break; + default: + LAUNCH_16816_STAGE_MMA2x4_WARP4x4x2_DSMEM_NO_SWIZZLE_KERNEL(2); + break; + } + } +} + +#endif diff --git a/kernels/hgemm/pybind/hgemm.cc b/kernels/hgemm/pybind/hgemm.cc index 36551d0b..0c14d3fd 100644 --- a/kernels/hgemm/pybind/hgemm.cc +++ b/kernels/hgemm/pybind/hgemm.cc @@ -49,6 +49,8 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr(torch::Tensor a, torch: void hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); // from hgemm_mma_stage_tn_cute.cu void hgemm_mma_stages_block_swizzle_tn_cute(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); +// from hgemm_mma_stage_swizzle.cu +void hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle(torch::Tensor a, torch::Tensor b, torch::Tensor c, int stages, bool swizzle, int swizzle_stride); PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { @@ -93,6 +95,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem) TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_x4) TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_rr) + // smem swizzle + 
TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4x2_stages_dsmem_swizzle) // TN: A row major MxK, B col major NxK, C row major MxN TORCH_BINDING_COMMON_EXTENSION(hgemm_mma_m16n8k16_mma2x4_warp4x4_stages_dsmem_tn) // TN: cute hgemm with smem & block swizzle diff --git a/kernels/hgemm/tools/utils.py b/kernels/hgemm/tools/utils.py index 4f0456e9..7f45ccb2 100644 --- a/kernels/hgemm/tools/utils.py +++ b/kernels/hgemm/tools/utils.py @@ -24,6 +24,7 @@ def get_build_sources(): build_sources.append('wmma/hgemm_wmma_stage.cu') build_sources.append('mma/hgemm_mma.cu') build_sources.append('mma/hgemm_mma_stage.cu') + build_sources.append('mma/hgemm_mma_stage_swizzle.cu') build_sources.append('mma/hgemm_mma_stage_tn.cu') build_sources.append('cutlass/hgemm_mma_stage_tn_cute.cu') build_sources.append('pybind/hgemm.cc') diff --git a/kernels/hgemm/utils/utils.h b/kernels/hgemm/utils/utils.h index 86d0d617..23a501dd 100644 --- a/kernels/hgemm/utils/utils.h +++ b/kernels/hgemm/utils/utils.h @@ -6,7 +6,7 @@ template float perf_gemm( void (*gpu_hgemm) (const T *, const T *, T *, int, int, int), - int M, int N, int K, int repeat) { + int M, int N, int K, int repeat, int warmup = 1) { size_t size_a = M * K * sizeof(T); size_t size_b = K * N * sizeof(T); @@ -19,7 +19,7 @@ float perf_gemm( cudaMalloc(&d_c, size_c); // warmup - for (int i = 0; i < 10; ++i){ + for (int i = 0; i < warmup; ++i){ gpu_hgemm(d_a, d_b, d_c, M, N, K); } cudaDeviceSynchronize(); @@ -52,7 +52,7 @@ float perf_gemm( template float perf_gemm_swizzle( void (*gpu_hgemm) (const T *, const T *, T *, int, int, int, int), - int M, int N, int K, int swizzle_stride, int repeat) { + int M, int N, int K, int swizzle_stride, int repeat, int warmup = 1) { size_t size_a = M * K * sizeof(T); size_t size_b = K * N * sizeof(T); @@ -65,7 +65,7 @@ float perf_gemm_swizzle( cudaMalloc(&d_c, size_c); // warmup - for (int i = 0; i < 10; ++i){ + for (int i = 0; i < warmup; ++i){ gpu_hgemm(d_a, d_b, d_c, M, N, K, swizzle_stride); } cudaDeviceSynchronize(); diff --git a/kernels/swizzle/README.md b/kernels/swizzle/README.md new file mode 100644 index 00000000..7faaef39 --- /dev/null +++ b/kernels/swizzle/README.md @@ -0,0 +1,125 @@ +# 📖 Learn how to apply SMEM Swizzle for bank conflicts free + +## 📚 build bin + +```bash +make +``` + +## 📚 ncu profile + +Achieve 0 bank conflicts for LDSM via smem swizzle. 
+ +```bash +ncu --metrics l1tex__data_bank_reads ./mat_trans_swizzle.bin +ncu --metrics l1tex__data_bank_writes ./mat_trans_swizzle.bin +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./mat_trans_swizzle.bin +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_st ./mat_trans_swizzle.bin + +ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1 +ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1 +``` + +log: (achieve 0 bank conflicts for LDSM via smem swizzle) + +```bash +ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin 1024 1024 1024 0 1 +[1542675] hgemm_mma_swizzle.bin@127.0.0.1 + void hgemm_mma_m16n8k16_naive_kernel<16, 8, 16>(__half *, __half *, __half *, int, int, int) (128, 64, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + ------------------------------------------------------------------ ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------------------------------------------ ----------- ------------ + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 22795.13 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 24576 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 18432 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 2097152 + ------------------------------------------------------------------ ----------- ------------ + + void hgemm_mma_m16n8k16_naive_smem_swizzle_kernel<16, 8, 16>(__half *, __half *, __half *, int, int, int) (128, 64, 1)x(32, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + ------------------------------------------------------------------ ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------------------------------------------ ----------- ------------ + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0 + ------------------------------------------------------------------ ----------- ------------ + + void hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel<16, 8, 16, 2, 4, 4, 4, 0, 0>(__half *, __half *, __half *, int, int, int) (8, 8, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + ------------------------------------------------------------------ ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------------------------------------------ ----------- ------------ + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 25644.52 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 36864 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 2359296 + ------------------------------------------------------------------ ----------- ------------ + + void hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel<16, 8, 16, 2, 4, 4, 4, 0, 8>(__half *, __half *, __half *, int, int, int) (8, 8, 1)x(256, 1, 1), Context 1, Stream 7, Device 0, CC 8.9 + Section: Command line profiler metrics + 
------------------------------------------------------------------ ----------- ------------ + Metric Name Metric Unit Metric Value + ------------------------------------------------------------------ ----------- ------------ + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.avg 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.max 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.min 0 + sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm.sum 0 + ------------------------------------------------------------------ ----------- ------------ +``` +## 📚 print swizzle layout +```bash +python3 print_swizzle_layout.py --col 64 +------------------------------------------- +--------------swizzle layout--------------- +-------------col 0~64, step 8-------------- +------------------------------------------- +| row 0 | (0, 8, 16, 24, 32, 40, 48, 56) | +| row 1 | (0, 8, 16, 24, 32, 40, 48, 56) | +| row 2 | (0, 8, 16, 24, 32, 40, 48, 56) | +| row 3 | (0, 8, 16, 24, 32, 40, 48, 56) | +------------------------------------------- +| row 4 | (8, 0, 24, 16, 40, 32, 56, 48) | +| row 5 | (8, 0, 24, 16, 40, 32, 56, 48) | +| row 6 | (8, 0, 24, 16, 40, 32, 56, 48) | +| row 7 | (8, 0, 24, 16, 40, 32, 56, 48) | +------------------------------------------- +| row 8 | (16, 24, 0, 8, 48, 56, 32, 40) | +| row 9 | (16, 24, 0, 8, 48, 56, 32, 40) | +| row 10 | (16, 24, 0, 8, 48, 56, 32, 40) | +| row 11 | (16, 24, 0, 8, 48, 56, 32, 40) | +------------------------------------------- +| row 12 | (24, 16, 8, 0, 56, 48, 40, 32) | +| row 13 | (24, 16, 8, 0, 56, 48, 40, 32) | +| row 14 | (24, 16, 8, 0, 56, 48, 40, 32) | +| row 15 | (24, 16, 8, 0, 56, 48, 40, 32) | +------------------------------------------- + +python3 print_swizzle_layout.py --col 16 +------------------- +--swizzle layout--- +-col 0~16, step 8-- +------------------- +| row 0 | (0, 8) | +| row 1 | (0, 8) | +| row 2 | (0, 8) | +| row 3 | (0, 8) | +------------------- +| row 4 | (8, 0) | +| row 5 | (8, 0) | +| row 6 | (8, 0) | +| row 7 | (8, 0) | +------------------- +| row 8 | (0, 8) | +| row 9 | (0, 8) | +| row 10 | (0, 8) | +| row 11 | (0, 8) | +------------------- +| row 12 | (8, 0) | +| row 13 | (8, 0) | +| row 14 | (8, 0) | +| row 15 | (8, 0) | +------------------- +``` diff --git a/kernels/swizzle/hgemm_mma_swizzle.cu b/kernels/swizzle/hgemm_mma_swizzle.cu index 37fbd6e4..cbb4f068 100644 --- a/kernels/swizzle/hgemm_mma_swizzle.cu +++ b/kernels/swizzle/hgemm_mma_swizzle.cu @@ -1,15 +1,15 @@ #include #include +#include #include #include #include +#include #include #include #include #include #include -#include -#include using namespace nvcuda; #define WARP_SIZE 32 @@ -249,63 +249,270 @@ hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel( } } +// i: row index; j: col index +__device__ __host__ __forceinline__ int swizzle_A_j(int i, int j) { + // >>> sw(0,0),sw(0,8),sw(1,0),sw(1,8),sw(2,0),sw(2,8),sw(3,0),sw(3,8) + // (0, 8, 0, 8, 0, 8, 0, 8) + // >>> sw(4,0),sw(4,8),sw(5,0),sw(5,8),sw(6,0),sw(6,8),sw(7,0),sw(7,8) + // (8, 0, 8, 0, 8, 0, 8, 0) + // >>> sw(8,0),sw(8,8),sw(9,0),sw(9,8),sw(10,0),sw(10,8),sw(11,0),sw(11,8) + // (0, 8, 0, 8, 0, 8, 0, 8) + // >>> sw(12,0),sw(12,8),sw(13,0),sw(13,8),sw(14,0),sw(14,8),sw(15,0),sw(15,8) + // (8, 0, 8, 0, 8, 0, 8, 0) + return ((int(j / 8) ^ int(i / 4)) % 2) * 8; +} + + +// TODO: hgemm_mma_m16n8k16_naive_smem_swizzle_kernel +// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 
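+// Why the swizzle_A_j permutation above helps: one s_a row of 16 halves is 32 bytes, so
+// with bank = (byte_addr / 4) % 32, rows i and i + 4 cover exactly the same banks for a
+// given 8-half group. An ldmatrix phase reads eight 16-byte row chunks, so rows 0/4, 1/5,
+// 2/6 and 3/7 collide pairwise (2-way LDSM conflicts). XOR-ing the group index with
+// (i / 4) moves rows 4~7 to the other 16-byte group, and the eight chunks then land in
+// disjoint banks -- consistent with the 0-conflict *_smem_swizzle_kernel entries in the
+// ncu log in this folder's README.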
+template +__global__ void hgemm_mma_m16n8k16_naive_smem_swizzle_kernel( + half* A, half* B, half* C, int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M; // 16 + constexpr int BN = MMA_N; // 8 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[MMA_M][MMA_K]; // 16x16 + __shared__ half s_b[MMA_K][MMA_N]; // 16x8 + __shared__ half s_c[MMA_M][MMA_N]; // 16x8 + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int lane_id = tid % WARP_SIZE; // 0~31 + + // s_a[16][16], 每行16,每线程load 8,需要2线程,共16行,需2x16=32线程 + const int load_smem_a_m = tid / 2; // row 0~15 + const int load_smem_a_k = (tid % 2) * 8; // col 0,8 + // s_b[16][8], 每行8,每线程load 8,需要1线程,共16行,需16线程,只需一半线程加载 + const int load_smem_b_k = tid; // row 0~31, but only use 0~15 + const int load_smem_b_n = 0; // col 0 + const int load_gmem_a_m = by * BM + load_smem_a_m; // global m + const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n + if (load_gmem_a_m >= M && load_gmem_b_n >= N) return; + + uint32_t RC[2] = {0, 0}; -// --------------------- PyTorch bindings for custom kernel ----------------------- -#define STRINGFY(str) #str -#define TORCH_BINDING_COMMON_EXTENSION(func) \ - m.def(STRINGFY(func), &func, STRINGFY(func)); + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem_a -> smem_a + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + // LDST128BITS(A[load_gmem_a_addr])); + LDST128BITS(s_a[load_smem_a_m][swizzle_A_j( + load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr])); + + // gmem_b -> smem_b + if (lane_id < MMA_K) { + int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + } + __syncthreads(); -#define CHECK_TORCH_TENSOR_DTYPE(T, th_type) \ -if(((T).options().dtype() != (th_type))) { \ - std::cout << "Tensor Info:" << (T).options() << std::endl; \ - throw std::runtime_error("values must be "#th_type); \ + uint32_t RA[4]; + uint32_t RB[2]; + + // ldmatrix for s_a, ldmatrix.trans for s_b. + // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)] + // uint32_t load_smem_a_ptr = __cvta_generic_to_shared( + // &s_a[lane_id % 16][(lane_id / 16) * 8]); + uint32_t load_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_id % 16][swizzle_A_j(lane_id % 16, (lane_id / 16) * 8)]); + LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr); + uint32_t load_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_id % 16][0]); + LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr); + + HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]); + + __syncthreads(); + } + + // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + LDST32BITS(s_c[lane_id / 4 ][(lane_id % 4) * 2]) = LDST32BITS(RC[0]); + LDST32BITS(s_c[lane_id / 4 + 8][(lane_id % 4) * 2]) = LDST32BITS(RC[1]); + + __syncthreads(); + + // store s_c[16][8] + if (lane_id < MMA_M) { + // store 128 bits per memory issue. 
+ int store_gmem_c_m = by * BM + lane_id; + int store_gmem_c_n = bx * BN; + int store_gmem_c_addr = store_gmem_c_m * N + store_gmem_c_n; + LDST128BITS(C[store_gmem_c_addr]) = (LDST128BITS(s_c[lane_id][0])); + } } -#define CHECK_TORCH_TENSOR_SHAPE(T, S0, S1) \ -if (((T).size(0) != (S0)) || ((T).size(1) != (S1))) { \ - throw std::runtime_error("Tensor size mismatch!"); \ +// 128x128, mma2x4, warp4x4(64,32,16) +template +__global__ void __launch_bounds__(256) +hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel( + half* A, half* B, half* C, int M, int N, int K) { + const int bx = blockIdx.x; + const int by = blockIdx.y; + const int NUM_K_TILES = div_ceil(K, MMA_K); + constexpr int BM = MMA_M * MMA_TILE_M * WARP_TILE_M; // 16*2*4=128 + constexpr int BN = MMA_N * MMA_TILE_N * WARP_TILE_N; // 8*4*4=128 + constexpr int BK = MMA_K; // 16 + + __shared__ half s_a[BM][BK+A_PAD]; // 128*16*2=4KB + __shared__ half s_b[BK][BN+B_PAD]; // 16*128*2=4KB, 16*(128+16)*2=4.5KB + + const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block + const int warp_id = tid / WARP_SIZE; // 0~7 warp_id within block + const int lane_id = tid % WARP_SIZE; // 0~31 + const int warp_m = warp_id % 2; // 0,1 + const int warp_n = warp_id / 2; // 0,1,2,3 + + // 先计算shared memory中的索引 + // tid和需要加载的smem s_a[BM][BK] 之间的索引关系 BM=128 BK=16 按行读取 A行主序 + // 对于s_a每行16个数据,每个线程读取8个,需要2个线程;总共128行,需要128x2刚好256线程 + int load_smem_a_m = tid / 2; // row 0~127 + int load_smem_a_k = (tid % 2 == 0) ? 0 : 8; // col 0,8 + // tid和需要加载的smem s_b[BK][BN] 之间的索引关系 BK=16 BN=128 按行读取 B行主序 + // 对于s_b每行128个数据,每个线程读8个数据,需要16个线程;总共16行,需要16x16=256个线程 + int load_smem_b_k = tid / 16; // row 0~15 + int load_smem_b_n = (tid % 16) * 8; // col 0,8,...,120 + // 再计算全局内存中的索引 + // 要加载到s_a中的元素对应到A全局内存中的行数 每个block负责出C中大小为BM*BN的块 + int load_gmem_a_m = by * BM + load_smem_a_m; // global row of a and c + int load_gmem_b_n = bx * BN + load_smem_b_n; // global col of b and c + if (load_gmem_a_m >= M || load_gmem_b_n >= N) return; + + uint32_t RC[WARP_TILE_M][WARP_TILE_N][2]; + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + RC[i][j][0] = 0; + RC[i][j][1] = 0; + } + } + + #pragma unroll + for (int k = 0; k < NUM_K_TILES; ++k) { + // gmem -> smem + int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a + int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k; + int load_gmem_b_k = k * BK + load_smem_b_k; // global row of b + int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n; + LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = ( + LDST128BITS(B[load_gmem_b_addr])); + // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = ( + // LDST128BITS(A[load_gmem_a_addr])); + LDST128BITS(s_a[load_smem_a_m][swizzle_A_j( + load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr])); + __syncthreads(); + + // ldmatrix for s_a, ldmatrix.trans for s_b. 
+ uint32_t RA[WARP_TILE_M][4]; + uint32_t RB[WARP_TILE_N][2]; + + // smem -> reg + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + int warp_smem_a_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int lane_smem_a_m = warp_smem_a_m + lane_id % 16; // 0~15 + int lane_smem_a_k = (lane_id / 16) * 8; // 0,8 + // uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + // &s_a[lane_smem_a_m][lane_smem_a_k]); + uint32_t lane_smem_a_ptr = __cvta_generic_to_shared( + &s_a[lane_smem_a_m][swizzle_A_j(lane_smem_a_m, lane_smem_a_k)]); + LDMATRIX_X4(RA[i][0], RA[i][1], RA[i][2], RA[i][3], lane_smem_a_ptr); + } + + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int warp_smem_b_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + int lane_smem_b_k = lane_id % 16; // 0~15 + int lane_smem_b_n = warp_smem_b_n; // 0, MMA_N=8 + uint32_t lane_smem_b_ptr = __cvta_generic_to_shared( + &s_b[lane_smem_b_k][lane_smem_b_n]); + LDMATRIX_X2_T(RB[j][0], RB[j][1], lane_smem_b_ptr); + } + + // MMA compute + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + HMMA16816(RC[i][j][0], RC[i][j][1], + RA[i][0], RA[i][1], RA[i][2], RA[i][3], + RB[j][0], RB[j][1], + RC[i][j][0], RC[i][j][1]); + } + } + __syncthreads(); + } + + // reg -> gmem, MMA_MxMMA_N=16x8 + #pragma unroll + for (int i = 0; i < WARP_TILE_M; ++i) { + #pragma unroll + for (int j = 0; j < WARP_TILE_N; ++j) { + int store_warp_smem_c_m = warp_m * (MMA_M * WARP_TILE_M) + i * MMA_M; + int store_warp_smem_c_n = warp_n * (MMA_N * WARP_TILE_N) + j * MMA_N; + // mapping lane smem index -> global index. + // [16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html + // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type + // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16] + int store_lane_gmem_c_m = by * BM + store_warp_smem_c_m + lane_id / 4; + int store_lane_gmem_c_n = bx * BN + store_warp_smem_c_n + (lane_id % 4) * 2; + int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n; + int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n; + // TODO: how to use LDST128BITS here ? reverse the loop order ? + LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[i][j][0]); + LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[i][j][1]); + } + } } -// only 1 warp per block(32 threads), m16n8k16. A, B, C: all row_major. 
-void hgemm_mma_m16n8k16_naive( - torch::Tensor a, torch::Tensor b, torch::Tensor c) { - CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) - const int M = a.size(0); - const int K = a.size(1); - const int N = b.size(1); - CHECK_TORCH_TENSOR_SHAPE(a, M, K) - CHECK_TORCH_TENSOR_SHAPE(b, K, N) - CHECK_TORCH_TENSOR_SHAPE(c, M, N) +// launcher +void launch_hgemm_mma_m16n8k16_naive( + half* a, half* b, half* c, int M, int N, int K) { constexpr int MMA_M = 16; constexpr int MMA_N = 8; - constexpr int MMA_K = 16; - + constexpr int MMA_K = 16; dim3 block(WARP_SIZE); dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); hgemm_mma_m16n8k16_naive_kernel< MMA_M, MMA_N, MMA_K><<>>( - reinterpret_cast(a.data_ptr()), - reinterpret_cast(b.data_ptr()), - reinterpret_cast(c.data_ptr()), - M, N, K + a, b, c, M, N, K ); } -// 128x128, mma2x4, warp4x4(64,32,16) -void hgemm_mma_m16n8k16_mma2x4_warp4x4( - torch::Tensor a, torch::Tensor b, torch::Tensor c) { - CHECK_TORCH_TENSOR_DTYPE(a, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(b, torch::kHalf) - CHECK_TORCH_TENSOR_DTYPE(c, torch::kHalf) - const int M = a.size(0); - const int K = a.size(1); - const int N = b.size(1); - CHECK_TORCH_TENSOR_SHAPE(a, M, K) - CHECK_TORCH_TENSOR_SHAPE(b, K, N) - CHECK_TORCH_TENSOR_SHAPE(c, M, N) +void launch_hgemm_mma_m16n8k16_naive_smem_swizzle( + half* a, half* b, half* c, int M, int N, int K) { + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + dim3 block(WARP_SIZE); + dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M)); + + hgemm_mma_m16n8k16_naive_smem_swizzle_kernel< + MMA_M, MMA_N, MMA_K><<>>( + a, b, c, M, N, K + ); +} + +void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4( + half* a, half* b, half* c, int M, int N, int K) { constexpr int MMA_M = 16; constexpr int MMA_N = 8; constexpr int MMA_K = 16; @@ -313,21 +520,137 @@ void hgemm_mma_m16n8k16_mma2x4_warp4x4( constexpr int MMA_TILE_N = 4; constexpr int WARP_TILE_M = 4; constexpr int WARP_TILE_N = 4; + // bank conflicts free via pad = 8, 拒绝幻想,相信profile + // ncu --metrics l1tex__data_bank_conflicts_pipe_lsu_mem_shared_op_ld ./hgemm_mma_swizzle.bin + // ncu --metrics sm__sass_l1tex_data_bank_conflicts_pipe_lsu_mem_shared_op_ldsm ./hgemm_mma_swizzle.bin + // constexpr int A_PAD = 8; + // constexpr int B_PAD = 8; constexpr int A_PAD = 0; - constexpr int B_PAD = 16; + constexpr int B_PAD = 0; constexpr int NUM_THREADS= ( - MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 - + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 dim3 block(NUM_THREADS); dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); - + hgemm_mma_m16n8k16_mma2x4_warp4x4_kernel< MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, - WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<>>( - reinterpret_cast(a.data_ptr()), - reinterpret_cast(b.data_ptr()), - reinterpret_cast(c.data_ptr()), - M, N, K + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<< + grid, block>>>( + a, b, c, M, N, K ); } + +void launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle( + half* a, half* b, half* c, int M, int N, int K) { + constexpr int MMA_M = 16; + constexpr int MMA_N = 8; + constexpr int MMA_K = 16; + constexpr int MMA_TILE_M = 2; + constexpr int MMA_TILE_N = 4; + constexpr int WARP_TILE_M = 4; + constexpr int WARP_TILE_N = 4; + constexpr int A_PAD = 0; + constexpr int B_PAD = 8; + constexpr int NUM_THREADS= ( + MMA_TILE_M * MMA_TILE_N * WARP_SIZE); // 2 * 4 * 32 = 256 + dim3 
block(NUM_THREADS); + dim3 grid(div_ceil(N, MMA_N * MMA_TILE_N * WARP_TILE_N), + div_ceil(M, MMA_M * MMA_TILE_M * WARP_TILE_M)); + + hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle_kernel< + MMA_M, MMA_N, MMA_K, MMA_TILE_M, MMA_TILE_N, + WARP_TILE_M, WARP_TILE_N, A_PAD, B_PAD><<< + grid, block>>>( + a, b, c, M, N, K + ); +} + +template +float perf_gemm( + void (*gpu_hgemm) (T *, T *, T *, int, int, int), + int M, int N, int K, int warmup, int repeat) { + + size_t size_a = M * K * sizeof(T); + size_t size_b = K * N * sizeof(T); + size_t size_c = M * N * sizeof(T); + + T *d_a, *d_b; + T *d_c; + cudaMalloc(&d_a, size_a); + cudaMalloc(&d_b, size_b); + cudaMalloc(&d_c, size_c); + + // warmup + for (int i = 0; i < warmup; ++i){ + gpu_hgemm(d_a, d_b, d_c, M, N, K); + } + cudaDeviceSynchronize(); + + cudaEvent_t start, end; + cudaEventCreate(&start); + cudaEventCreate(&end); + cudaEventRecord(start); + for (int i = 0; i < repeat; i++) { + gpu_hgemm(d_a, d_b, d_c, M, N, K); + } + cudaEventRecord(end); + cudaDeviceSynchronize(); + cudaEventSynchronize(end); + + float msec, sec; + cudaEventElapsedTime(&msec, start, end); + sec = msec / 1000.0 / repeat; + + cudaFree(d_a); + cudaFree(d_b); + cudaFree(d_c); + cudaEventDestroy(start); + cudaEventDestroy(end); + + return sec; +} + +int main(int argc, char *argv[]) { + int M = 1024; + int N = 1024; + int K = 1024; + int W = 1; + int R = 10; + if (argc > 1) M = std::stoi(argv[1]); + if (argc > 2) N = std::stoi(argv[2]); + if (argc > 3) K = std::stoi(argv[3]); + if (argc > 4) W = std::stoi(argv[4]); + if (argc > 5) R = std::stoi(argv[5]); + double avg_sec, avg_Tflops; + + printf("\nALGO = HGEMM MMA NAIVE\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_naive, + M, N, K, W, R); + avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); + printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); + + printf("\nALGO = HGEMM MMA NAIVE + SMEM SWIZZLE\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_naive_smem_swizzle, + M, N, K, W, R); + avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); + printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); + + printf("\nALGO = HGEMM mma2x4_warp4x4\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4, + M, N, K, W, R); + avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); + printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); + + printf("\nALGO = HGEMM mma2x4_warp4x4 + SMEM SWIZZLE\n"); + avg_sec = perf_gemm(launch_hgemm_mma_m16n8k16_mma2x4_warp4x4_smem_swizzle, + M, N, K, W, R); + avg_Tflops = ((double)M) * N * K * 2 * 1e-12 / avg_sec; + printf("M N K = %6d %6d %6d, W = %d, R = %d, ", M, N, K, W, R); + printf("Time = %12.8lf s, AVG Performance = %10.4lf Tflops\n", avg_sec, avg_Tflops); + + return 0; +} diff --git a/kernels/swizzle/makefile b/kernels/swizzle/makefile new file mode 100644 index 00000000..f08b1c59 --- /dev/null +++ b/kernels/swizzle/makefile @@ -0,0 +1,17 @@ +INCLUDE_DIRS=-I ../../third-party/cutlass/include -I ../../third-party/cutlass/tools/util/include +ARCHS=-gencode arch=compute_80,code=sm_80 -gencode arch=compute_89,code=sm_89 +ARCHS_89=-gencode arch=compute_89,code=sm_89 +DEFAULT_FLAGS=-O2 $(ARCHS) -std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas +DEFAULT_FLAGS_89=-O2 $(ARCHS_89) 
-std=c++17 $(INCLUDE_DIRS) --expt-relaxed-constexpr -lcublas +default: + nvcc hgemm_mma_swizzle.cu -o hgemm_mma_swizzle.bin $(DEFAULT_FLAGS) + nvcc mat_trans_swizzle.cu -o mat_trans_swizzle.bin $(DEFAULT_FLAGS) + nvcc mma_simple_swizzle.cu -o mma_simple_swizzle.bin $(DEFAULT_FLAGS) +hgemm_89: + nvcc hgemm_mma_swizzle.cu -o hgemm_mma_swizzle.89.bin $(DEFAULT_FLAGS_89) +mma_89: + nvcc mma_simple_swizzle.cu -o mma_simple_swizzle.89.bin $(DEFAULT_FLAGS_89) +mat_89: + nvcc mat_trans_swizzle.cu -o mat_trans_swizzle.89.bin $(DEFAULT_FLAGS_89) +clean: + rm -rf *.bin diff --git a/kernels/swizzle/mat_trans_swizzle.cu b/kernels/swizzle/mat_trans_swizzle.cu new file mode 100644 index 00000000..c526771d --- /dev/null +++ b/kernels/swizzle/mat_trans_swizzle.cu @@ -0,0 +1,108 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + + +// reference: https://zhuanlan.zhihu.com/p/4746910252 +// 转置前的矩阵存储在dev_A中,矩阵大小为M*N,转置后的数据存储在dev_B中 +__global__ void mat_trans_smem_naive_kernel(int* dev_A, int M, int N, int* dev_B) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // 每个block处理32*32的矩阵块 + __shared__ int s_data[32][32]; + + if (row < M && col < N) { + // 从全局内存中加载数据,转置后写到共享内存中 + s_data[threadIdx.x][threadIdx.y] = dev_A[row * N + col]; + __syncthreads(); + int n_col = blockIdx.y * blockDim.y + threadIdx.x; + int n_row = blockIdx.x * blockDim.x + threadIdx.y; + if (n_col < M && n_row < N) { + // 从转置后的共享内存按行写到全局内存结果中 + dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x]; + } + } +} + +// reference: https://zhuanlan.zhihu.com/p/4746910252 +__global__ void mat_trans_smem_padding_kernel(int* dev_A, int M, int N, int* dev_B) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + // 每个block处理32*32的矩阵块,尾部padding来避免bank conflict + __shared__ int s_data[32][33]; + + if (row < M && col < N) { + s_data[threadIdx.x][threadIdx.y] = dev_A[row * N + col]; + __syncthreads(); + int n_col = blockIdx.y * blockDim.y + threadIdx.x; + int n_row = blockIdx.x * blockDim.x + threadIdx.y; + if (n_col < M && n_row < N) { + dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x]; + } + } +} + +// reference: https://zhuanlan.zhihu.com/p/4746910252 +__global__ void mat_trans_smem_swizzle_kernel(int* dev_A, int M, int N, int* dev_B) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ int s_data[32][32]; + + if (row < M && col < N) { + // 从全局内存读取数据写入共享内存的逻辑坐标(row=x,col=y) + // 其映射的物理存储位置位置(row=x,col=x^y) + s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col]; + __syncthreads(); + int n_col = blockIdx.y * blockDim.y + threadIdx.x; + int n_row = blockIdx.x * blockDim.x + threadIdx.y; + if (n_row < N && n_col < M) { + // 从共享内存的逻辑坐标(row=y,col=x)读取数据 + // 其映射的物理存储位置(row=y,col=x^y) + dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y]; + } + } +} + +int main(int argc, char *argv[]) { + int M = 1024; + int N = 1024; + if (argc > 1) M = std::stoi(argv[1]); + if (argc > 2) N = std::stoi(argv[2]); + size_t size_a = M * N * sizeof(int); + size_t size_b = M * N * sizeof(int); + + int* dev_A; + int* dev_B; + cudaMalloc(&dev_A, size_a); + cudaMalloc(&dev_B, size_b); + cudaDeviceSynchronize(); + + dim3 block(32, 32); + dim3 grid(N/32, M/32); + + mat_trans_smem_naive_kernel<<>>(dev_A, M, N, dev_B); + cudaDeviceSynchronize(); + + mat_trans_smem_padding_kernel<<>>(dev_A, M, N, 
dev_B); + cudaDeviceSynchronize(); + + mat_trans_smem_swizzle_kernel<<>>(dev_A, M, N, dev_B); + cudaDeviceSynchronize(); + + printf("Done.\n"); + cudaFree(dev_A); + cudaFree(dev_B); + + return 0; +} diff --git a/kernels/swizzle/matrix_trans_swizzle.cu b/kernels/swizzle/matrix_trans_swizzle.cu deleted file mode 100644 index 105dbf23..00000000 --- a/kernels/swizzle/matrix_trans_swizzle.cu +++ /dev/null @@ -1,35 +0,0 @@ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -using namespace nvcuda; - -// reference: https://zhuanlan.zhihu.com/p/4746910252 -__global__ void matrix_trans_swizzling(int* dev_A, int M, int N, int* dev_B) { - int row = blockIdx.y * blockDim.y + threadIdx.y; - int col = blockIdx.x * blockDim.x + threadIdx.x; - - __shared__ int s_data[32][32]; - - if (row < M && col < N) { - // 从全局内存读取数据写入共享内存的逻辑坐标(row=x,col=y) - // 其映射的物理存储位置位置(row=x,col=x^y) - s_data[threadIdx.x][threadIdx.x ^ threadIdx.y] = dev_A[row * N + col]; - __syncthreads(); - int n_col = blockIdx.y * blockDim.y + threadIdx.x; - int n_row = blockIdx.x * blockDim.x + threadIdx.y; - if (n_row < N && n_col < M) { - // 从共享内存的逻辑坐标(row=y,col=x)读取数据 - // 其映射的物理存储位置(row=y,col=x^y) - dev_B[n_row * M + n_col] = s_data[threadIdx.y][threadIdx.x ^ threadIdx.y]; - } - } -} diff --git a/kernels/swizzle/mma_simple_swizzle.cu b/kernels/swizzle/mma_simple_swizzle.cu new file mode 100644 index 00000000..01bb8939 --- /dev/null +++ b/kernels/swizzle/mma_simple_swizzle.cu @@ -0,0 +1,202 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +using namespace nvcuda; + +#define WARP_SIZE 32 +#define DEVICE_INLINE __device__ inline +#define HOST_DEVICE_INLINE __device__ __host__ inline +#define INT4(value) (reinterpret_cast(&(value))[0]) +#define FLOAT4(value) (reinterpret_cast(&(value))[0]) +#define HALF2(value) (reinterpret_cast(&(value))[0]) +#define BFLOAT2(value) (reinterpret_cast<__nv_bfloat162*>(&(value))[0]) +#define LDST32BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST64BITS(value) (reinterpret_cast(&(value))[0]) +#define LDST128BITS(value) (reinterpret_cast(&(value))[0]) +#define CP_ASYNC_COMMIT_GROUP() asm volatile("cp.async.commit_group;\n" ::) +#define CP_ASYNC_WAIT_ALL() asm volatile("cp.async.wait_all;\n" ::) +#define CP_ASYNC_WAIT_GROUP(n) asm volatile("cp.async.wait_group %0;\n" ::"n"(n)) +// ca(cache all, L1 + L2): support 4, 8, 16 bytes, cg(cache global, L2): only support 16 bytes. 
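+// note: this standalone demo loads shared memory with plain 128-bit LDST;
+// the cp.async macros below are defined but not used by the kernel.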
+#define CP_ASYNC_CA(dst, src, bytes) asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define CP_ASYNC_CG(dst, src, bytes) asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" ::"r"(dst), "l"(src), "n"(bytes))
+#define LDMATRIX_X1(R, addr) asm volatile("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define LDMATRIX_X1_T(R, addr) asm volatile("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" : "=r"(R) : "r"(addr))
+#define LDMATRIX_X2_T(R0, R1, addr) asm volatile("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" : "=r"(R0), "=r"(R1) : "r"(addr))
+#define LDMATRIX_X4_T(R0, R1, R2, R3, addr) asm volatile("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" : "=r"(R0), "=r"(R1), "=r"(R2), "=r"(R3) : "r"(addr))
+#define HMMA16816(RD0, RD1, RA0, RA1, RA2, RA3, RB0, RB1, RC0, RC1) asm volatile("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%8, %9};\n" : "=r"(RD0), "=r"(RD1) : "r"(RA0), "r"(RA1), "r"(RA2), "r"(RA3), "r"(RB0), "r"(RB1), "r"(RC0), "r"(RC1))
+
+HOST_DEVICE_INLINE
+int div_ceil(int a, int b) { return (a % b != 0) ? (a / b + 1) : (a / b); }
+
+// i: row index; j: col index
+__device__ __host__ __forceinline__ int swizzle_j(int i, int j) {
+ // >>> sw(0,0),sw(0,8),sw(1,0),sw(1,8),sw(2,0),sw(2,8),sw(3,0),sw(3,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(4,0),sw(4,8),sw(5,0),sw(5,8),sw(6,0),sw(6,8),sw(7,0),sw(7,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ // >>> sw(8,0),sw(8,8),sw(9,0),sw(9,8),sw(10,0),sw(10,8),sw(11,0),sw(11,8)
+ // (0, 8, 0, 8, 0, 8, 0, 8)
+ // >>> sw(12,0),sw(12,8),sw(13,0),sw(13,8),sw(14,0),sw(14,8),sw(15,0),sw(15,8)
+ // (8, 0, 8, 0, 8, 0, 8, 0)
+ return ((int(j / 8) ^ int(i / 4)) % 2) * 8;
+}
+
+
+template<const int MMA_M=16, const int MMA_N=8, const int MMA_K=16>
+__global__ void mma_simple_swizzle_kernel(
+ half* A, half* B, half* C, int M, int N, int K) {
+ const int bx = blockIdx.x;
+ const int by = blockIdx.y;
+ const int NUM_K_TILES = div_ceil(K, MMA_K);
+ constexpr int BM = MMA_M; // 16
+ constexpr int BN = MMA_N; // 8
+ constexpr int BK = MMA_K; // 16
+
+ __shared__ half s_a[MMA_M][MMA_K]; // 16x16
+ __shared__ half s_b[MMA_K][MMA_N]; // 16x8
+
+ const int tid = threadIdx.y * blockDim.x + threadIdx.x; // within block
+ const int lane_id = tid % WARP_SIZE; // 0~31
+
+ // s_a[16][16]: 16 halves per row, 8 per thread -> 2 threads per row; 16 rows -> 2x16=32 threads
+ const int load_smem_a_m = tid / 2; // row 0~15
+ const int load_smem_a_k = (tid % 2) * 8; // col 0,8
+ // s_b[16][8]: 8 halves per row, 8 per thread -> 1 thread per row; 16 rows -> 16 threads, only half the warp loads
+ const int load_smem_b_k = tid; // row 0~31, but only use 0~15
+ const int load_smem_b_n = 0; // col 0
+ const int load_gmem_a_m = by * BM + load_smem_a_m; // global m
+ const int load_gmem_b_n = bx * BN + load_smem_b_n; // global n
+ if (load_gmem_a_m >= M && load_gmem_b_n >= N) return;
+
+ uint32_t RC[2] = {0, 0};
+
+ #pragma unroll
+ for (int k = 0; k < NUM_K_TILES; ++k) {
+ // gmem_a -> smem_a
+ int load_gmem_a_k = k * BK + load_smem_a_k; // global col of a
+ int load_gmem_a_addr = load_gmem_a_m * K + load_gmem_a_k;
+ // LDST128BITS(s_a[load_smem_a_m][load_smem_a_k]) = (
+ // LDST128BITS(A[load_gmem_a_addr]));
+ 
LDST128BITS(s_a[load_smem_a_m][swizzle_j(
+ load_smem_a_m, load_smem_a_k)]) = (LDST128BITS(A[load_gmem_a_addr]));
+
+ // gmem_b -> smem_b
+ if (lane_id < MMA_K) {
+ int load_gmem_b_k = k * MMA_K + load_smem_b_k; // global row of b
+ int load_gmem_b_addr = load_gmem_b_k * N + load_gmem_b_n;
+ LDST128BITS(s_b[load_smem_b_k][load_smem_b_n]) = (
+ LDST128BITS(B[load_gmem_b_addr]));
+ }
+ __syncthreads();
+ if (tid == 0) {
+ printf("\n");
+ for (int i = 0; i < MMA_M; i++) {
+ for (int j = 0; j < MMA_K; j++) {
+ printf("A[%2d][%2d]=%4d, ", i, j, __half2int_rz(s_a[i][j]));
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+
+ if (tid == 0) {
+ printf("\n");
+ for (int i = 0; i < MMA_K; i++) {
+ for (int j = 0; j < MMA_N; j++) {
+ printf("B[%2d][%2d]=%4d, ", i, j, __half2int_rz(s_b[i][j]));
+ }
+ printf("\n");
+ }
+ }
+ __syncthreads();
+
+ uint32_t RA[4];
+ uint32_t RB[2];
+
+ // ldmatrix for s_a, ldmatrix.trans for s_b.
+ // s_a: (0,1)*8 -> 0,8 -> [(0~15),(0,8)]
+ // uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ // &s_a[lane_id % 16][(lane_id / 16) * 8]);
+ uint32_t load_smem_a_ptr = __cvta_generic_to_shared(
+ &s_a[lane_id % 16][swizzle_j(lane_id % 16, (lane_id / 16) * 8)]);
+ LDMATRIX_X4(RA[0], RA[1], RA[2], RA[3], load_smem_a_ptr);
+ uint32_t load_smem_b_ptr = __cvta_generic_to_shared(
+ &s_b[lane_id % 16][0]);
+ LDMATRIX_X2_T(RB[0], RB[1], load_smem_b_ptr);
+
+ HMMA16816(RC[0], RC[1], RA[0], RA[1], RA[2], RA[3], RB[0], RB[1], RC[0], RC[1]);
+
+ __syncthreads();
+ }
+
+ // s_c[16][8], https://docs.nvidia.com/cuda/parallel-thread-execution/index.html
+ // #matrix-fragments-for-mma-m16n8k16-with-floating-point-type
+ // [0~7][0~3 u32 -> 0~7 f16], [8~15][0~3 u32 -> 0~7 f16]
+ int store_lane_gmem_c_m = by * BM + lane_id / 4;
+ int store_lane_gmem_c_n = bx * BN + (lane_id % 4) * 2;
+ int store_gmem_c_addr_0 = store_lane_gmem_c_m * N + store_lane_gmem_c_n;
+ int store_gmem_c_addr_1 = (store_lane_gmem_c_m + 8) * N + store_lane_gmem_c_n;
+ LDST32BITS(C[store_gmem_c_addr_0]) = LDST32BITS(RC[0]);
+ LDST32BITS(C[store_gmem_c_addr_1]) = LDST32BITS(RC[1]);
+}
+
+int main(int argc, char *argv[]) {
+ int M = 16;
+ int N = 8;
+ int K = 16;
+ if (argc > 1) M = std::stoi(argv[1]);
+ if (argc > 2) N = std::stoi(argv[2]);
+ if (argc > 3) K = std::stoi(argv[3]);
+
+ size_t size_a = M * K * sizeof(half);
+ size_t size_b = K * N * sizeof(half);
+ size_t size_c = M * N * sizeof(half);
+
+ half *h_a, *h_b, *h_c;
+ half *d_a, *d_b, *d_c;
+ h_a = (half *)malloc(size_a);
+ h_b = (half *)malloc(size_b);
+ h_c = (half *)malloc(size_c);
+
+ cudaMalloc(&d_a, size_a);
+ cudaMalloc(&d_b, size_b);
+ cudaMalloc(&d_c, size_c);
+
+ for (int i = 0; i < M * K; i++)
+ h_a[i] = __float2half((float)i); // 0~255 16x16=256
+ for (int i = 0; i < K * N; i++)
+ h_b[i] = __float2half((float)i); // 0~127 16x8=128
+
+ cudaMemcpy(d_a, h_a, size_a, cudaMemcpyHostToDevice);
+ cudaMemcpy(d_b, h_b, size_b, cudaMemcpyHostToDevice);
+
+ constexpr int MMA_M = 16;
+ constexpr int MMA_N = 8;
+ constexpr int MMA_K = 16;
+ dim3 block(WARP_SIZE);
+ dim3 grid(div_ceil(N, MMA_N), div_ceil(M, MMA_M));
+
+ mma_simple_swizzle_kernel<
+ MMA_M, MMA_N, MMA_K><<<grid, block>>>(
+ d_a, d_b, d_c, M, N, K
+ );
+ cudaFree(d_a);
+ cudaFree(d_b);
+ cudaFree(d_c);
+ free(h_a);
+ free(h_b);
+ free(h_c);
+
+ return 0;
+}
diff --git a/kernels/swizzle/print_swizzle_layout.py b/kernels/swizzle/print_swizzle_layout.py
new file mode 100644
index 00000000..dcb9d48e
--- /dev/null
+++ b/kernels/swizzle/print_swizzle_layout.py
@@ -0,0 +1,46 @@
+import argparse
+
+
+def 
pretty_print_line(m: str = "", sep: str = "-", width: int = 130): + res_len = width - len(m) + left_len = int(res_len / 2) + right_len = res_len - left_len + pretty_line = sep * left_len + m + sep * right_len + print(pretty_line) + + +def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8): + # i: row index; j: col index. + return ((int(j / step) ^ int(i / 4)) % int(col_stride / step)) * step + + +def print_swizzle_layout(rows: int = 16, col_stride: int = 64, step: int = 8): + str_len = 0 + for i in range(rows): + layout = tuple(swizzle_permuted_j(i, j, col_stride, step) + for j in range(0, col_stride, step)) + layout_str = (f"| row {i:<2} | {layout} |") + str_len = len(layout_str) + if (i == 0): + print("-" * str_len) + pretty_print_line(f"swizzle layout", width=str_len) + pretty_print_line(f"col 0~{col_stride}, step {step}", width=str_len) + print("-" * str_len) + print(layout_str) + if ((i + 1) % 4 == 0 and i != (rows - 1)): + print("-" * str_len) + print("-" * str_len) + + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--col_stride", "--col", type=int, default=64) + parser.add_argument("--step", type=int, default=8) + parser.add_argument("--rows", type=int, default=16) + return parser.parse_args() + + +if __name__ == "__main__": + args = get_args() + print_swizzle_layout(args.rows, args.col_stride, args.step) +
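
A quick numerical cross-check of the layout printed by this tool, as a minimal sketch: check_bijective is an illustrative helper that re-implements the swizzle_permuted_j formula above (the col_stride=16 case mirrors swizzle_j in mma_simple_swizzle.cu) and asserts that the XOR swizzle only permutes the 16-byte chunks within each logical row, so no element of a row ever aliases another in shared memory.

# sanity check: the XOR swizzle is a per-row bijection on 16-byte chunks
def swizzle_permuted_j(i: int, j: int, col_stride: int = 64, step: int = 8) -> int:
    # same mapping as the tool above: XOR the chunk index with (row / 4)
    return ((j // step) ^ (i // 4)) % (col_stride // step) * step

def check_bijective(rows: int = 16, col_stride: int = 64, step: int = 8) -> None:
    expected = set(range(0, col_stride, step))
    for i in range(rows):
        mapped = {swizzle_permuted_j(i, j, col_stride, step)
                  for j in range(0, col_stride, step)}
        assert mapped == expected, f"row {i}: chunks collide -> {sorted(mapped)}"
    print(f"OK: {col_stride // step} chunks per row are permuted, never overwritten")

if __name__ == "__main__":
    check_bijective()                # col 0~64, step 8: the layout printed above
    check_bijective(col_stride=16)   # the 16x16 s_a tile in mma_simple_swizzle.cu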