[FA2] split-q + tiling-qk D=512 performance🎉 (#177)
* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update flash_attn_mma.py

* Update README.md

* Update README.md

* Update README.md
DefTruth authored Dec 23, 2024
1 parent 697e06f commit d474791
Showing 3 changed files with 64 additions and 14 deletions.
35 changes: 29 additions & 6 deletions README.md
@@ -55,8 +55,9 @@ I have also implemented **FlashAttention-2** using pure MMA PTX instructions, wh
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~ Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop):
Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, this implementation can run faster than the official FA2/SDPA on some devices. For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION). However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
@@ -72,6 +73,27 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDI
(flash): ['-0.00516129 ', '0.05783081 ', '-0.00027728 '], time:3.776550ms, TFLOPS:37.10
----------------------------------------------------------------------------------------------------------------------------------
```
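
For reference, these TFLOPS figures are consistent with the usual attention FLOP count of `4 * B * H * N^2 * D`: here `4 * 1 * 8 * 8192^2 * 64 ≈ 137.4 GFLOP`, so a runtime of ~3.7-3.8 ms corresponds to roughly 36-37 TFLOPS, in line with the numbers reported above.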

- Example: B=1, H=8, N=8192, `D=512` (NVIDIA RTX 3080 Laptop), FA2 not supported, `QK Tiling` Faster than SDPA~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
mma(split-q+tiling-qk+stage1): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:48.775554ms, TFLOPS:22.60 (+0.00%)
mma(split-q+tiling-qk+stage2): ['-0.00433731 ', '0.02165222 ', '-0.01544189 '], time:47.503424ms, TFLOPS:23.20 (+2.68%)
(sdpa): ['-0.00438309 ', '0.02174377 ', '-0.01551056 '], time:66.486573ms, TFLOPS:16.58
----------------------------------------------------------------------------------------------------------------------------------
```

- Example: B=1, H=48, N=16384, `D=512` (NVIDIA L20), FA2 not supported, `QK Tiling` Faster than SDPA~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
----------------------------------------------------------------------------------------------------------------------------------
```

The `Split KV` and `Split Q` implementations have been carried out in [flash-attention-mma⚡️⚡️](./kernels/flash-attn) for performance comparison. The `Split KV` method, which splits all of Q, K and V across MMA tiles (warps), is slower than the `Split Q` policy, which splits only Q across MMA tiles (warps) while keeping the full K/V accessible to all warps.
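
As a rough illustration, the hypothetical sketch below (not one of this repository's kernels) prints how a block of 4 warps could divide one attention tile under each policy, assuming a 2x2 warp layout for `Split KV` and tile sizes `Br = Bc = 64`; only the index mapping is shown, with no attention math:

```C++
#include <cstdio>
#include <cuda_runtime.h>

// Illustrative only: which Q rows / KV columns each of the 4 warps would own.
__global__ void warp_mapping_demo(int Br, int Bc) {
  const int warp_id = threadIdx.x / 32;
  if (threadIdx.x % 32 != 0) return;  // one print per warp
  // Split KV (FlashAttention-1 style): warps are laid out 2x2 over the (Br x Bc)
  // tile, i.e. both Q rows and KV columns are split across warps, so row-wise
  // max/sum must later be reduced across warps.
  const int wy = warp_id / 2, wx = warp_id % 2;
  printf("Split-KV: warp %d -> Q rows [%d,%d), KV cols [%d,%d)\n",
         warp_id, wy * (Br / 2), (wy + 1) * (Br / 2),
         wx * (Bc / 2), (wx + 1) * (Bc / 2));
  // Split Q (FlashAttention-2 style): only the Q tile is split across warps and
  // each warp streams the full KV tile, so no cross-warp reduction is needed.
  printf("Split-Q : warp %d -> Q rows [%d,%d), KV cols [0,%d)\n",
         warp_id, warp_id * (Br / 4), (warp_id + 1) * (Br / 4), Bc);
}

int main() {
  warp_mapping_demo<<<1, 128>>>(64, 64);  // 4 warps, Br = Bc = 64 (assumed)
  cudaDeviceSynchronize();
  return 0;
}
```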

- 📚 Split KV (Basic, FlashAttention-1)
@@ -128,9 +150,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
<div id="mma-tiling-qk"></div>

```C++
// Fine-grained tiling (MMA level) for Q/K, it cause constant SRAM size 64*kMmaAtomK for Q/K,
// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
// Thus, we can extend D(headdim) to 1024. Performance is stay tuned for updates ~
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimizations are ongoing; stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
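
For a rough sense of why this allows `D` up to 1024, the hypothetical back-of-the-envelope calculation below shows the resulting SMEM footprint, assuming `kMmaAtomK = 16` (as in `m16n8k16` MMA), one 64 x kMmaAtomK half tile each for Q and K, and a single-stage kMmaAtomK x d half tile for V; the actual kernels may buffer differently:

```C++
#include <cstdio>

int main() {
  constexpr int kMmaAtomK = 16;  // assumed MMA K dimension (m16n8k16)
  constexpr int kBytes = 2;      // half precision
  for (int d : {256, 512, 1024}) {
    const int q_smem = 64 * kMmaAtomK * kBytes;  // constant, independent of d
    const int k_smem = 64 * kMmaAtomK * kBytes;  // constant, independent of d
    const int v_smem = kMmaAtomK * d * kBytes;   // O(kMmaAtomK * d)
    printf("d=%4d: Q=%d B, K=%d B, V=%d B, total=%.1f KB\n",
           d, q_smem, k_smem, v_smem, (q_smem + k_smem + v_smem) / 1024.0);
  }
  return 0;
}
```

Even at `d = 1024` this single-stage footprint is only about 36 KB, which fits a typical 48 KB per-block SMEM budget; multi-stage pipelining (stage2) increases the K/V buffering accordingly.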
@@ -150,14 +173,14 @@ flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half*

<div id="cuda-kernel"></div>

The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The **Workflow** will look like: custom **CUDA** kernel impl -> **PyTorch** Python bindings -> Run tests. 👉TIPS: `*` = Tensor Cores (WMMA, MMA, CuTe), otherwise, CUDA Cores; `/` = not supported; `✔️` = supported; `` = TODO. Contents:
The kernels listed here will guide you through a step-by-step progression, ranging from easy to very challenging topics. The **Workflow** for each topic will look like: custom **CUDA** kernel implementation -> **PyTorch** Python bindings -> run tests. 👉TIPS: `*` = Tensor Cores (WMMA, MMA, CuTe), otherwise CUDA Cores; `/` = not supported; `✔️` = supported; `` = TODO. Contents are listed below:

- [📚 Easy ⭐️](#cuda-kernel-easy-medium)
- [📚 Medium ⭐️⭐️](#cuda-kernel-easy-medium)
- [📚 Hard ⭐️⭐️⭐️](#cuda-kernel-hard)
- [📚 Hard++ ⭐⭐⭐️⭐️⭐️](#cuda-kernel-hard)

[📚 Easy](#cuda-kernel-easy-medium) and [📚 Medium](#cuda-kernel-easy-medium) sections cover fundamental operations such as element-wise, mat_trans, warp/block reduce, online-softmax, nms, layer-norm, rms-norm, dot-prod etc. [📚 Hard](#cuda-kernel-hard) and [📚 Hard++](#cuda-kernel-hard) sections delve deeper into advanced topics, primarily focusing on operations like `sgemv, sgemm, hgemv, hgemm and flash-attention`. These sections also provide numerous kernels implemented using Tensor Cores with pure MMA PTX instructions.
[📚 Easy](#cuda-kernel-easy-medium) and [📚 Medium](#cuda-kernel-easy-medium) sections cover operations such as `element-wise, mat_trans, warp/block reduce, online-softmax, nms, layer-norm, rms-norm, dot-prod, relu, gelu, swish, embedding` and basic usage of `FP32/FP16/BF16/FP8`. [📚 Hard](#cuda-kernel-hard) and [📚 Hard++](#cuda-kernel-hard) sections delve deeper into advanced topics, primarily focusing on operations like `sgemv, sgemm, hgemv, hgemm and flash-attention`. These sections also provide numerous kernels implemented with Tensor Cores using pure MMA PTX.
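
As a generic illustration of the "custom CUDA kernel implementation -> PyTorch Python bindings -> run tests" workflow mentioned above, a hypothetical minimal binding might look like the sketch below (illustrative names only; the repository's own bindings differ):

```C++
// bindings.cc -- hypothetical sketch of the PyTorch binding step.
#include <torch/extension.h>

// In a real extension this would launch a custom CUDA kernel declared in a .cu
// file; here it simply clones the input so the sketch stays self-contained.
torch::Tensor my_op(torch::Tensor x) {
  TORCH_CHECK(x.is_cuda(), "input must be a CUDA tensor");
  return x.clone();
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("my_op", &my_op, "placeholder custom CUDA op (sketch)");
}

// Python side (assumed usage):
//   ext = torch.utils.cpp_extension.load(name="my_ext", sources=["bindings.cc"])
//   out = ext.my_op(x)  # then compare against a reference implementation in tests
```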

### 📚 Easy ⭐️ & Medium ⭐️⭐️ ([©️back👆🏻](#cuda-kernel))
<div id="cuda-kernel-easy-medium"></div>
28 changes: 21 additions & 7 deletions kernels/flash-attn/README.md
@@ -12,9 +12,12 @@
|**Shared QKV/KV** SMEM|**Prefetch Q** s2r|**Prefetch K/V** g2s|SMEM/Block Swizzle|
|✔️|✔️|✔️|?|

This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H <=48, SeqLen <= 8192)` can run faster than offical FA2 on some Devices, for example, NVIDIA RTX 3080 Laptop. However, for large-scale attention computations, there remains a performance gap. Performance optimizations are ongoing; stay tuned for updates.
This repository's implementation of FlashAttention is intended solely for learning CUDA programming. For optimal performance, please use the official [flash-attention](https://github.com/Dao-AILab/flash-attention). Currently, for small-scale attention `(B<=4, H<=48, SeqLen<=8192)`, this implementation can run faster than the official FA2/SDPA on some devices. However, for large-scale attention, there remains a performance gap. Performance is continuously being optimized. Stay tuned for updates ~

- Example: B=1, H=8, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
For example, on an NVIDIA RTX 3080 Laptop, [📚 Split Q + Fully Shared QKV SMEM](#mma-share-qkv) can achieve **55 TFLOPS (D=64)**, almost **~1.5x** 🎉 faster than FA2. Moreover, on an NVIDIA L20, [📚 Split Q + QK Fine-grained Tiling](#mma-tiling-qk) can achieve **81 TFLOPS (D=512)**, almost **~1.4x** 🎉 faster than SDPA (EFFICIENT_ATTENTION).


- Example: B=1, H=8, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
-------------------------------------------B=1, H=8, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
@@ -31,7 +34,7 @@ python3 flash_attn_mma.py --B 1 --H 8 --D 64 --N 8192 --iters 10 --torch # NVIDI
----------------------------------------------------------------------------------------------------------------------------------
```

- Example: B=1, H=48, N=8192, D=64 (NVIDIA RTX 3080 Laptop)
- Example: B=1, H=48, N=8192, `D=64` (NVIDIA RTX 3080 Laptop), Faster than FA2~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVIDIA RTX 3080 Laptop
------------------------------------------B=1, H=48, N=8192, D=64, Warmup: 1, Iters: 10-------------------------------------------
Expand All @@ -47,7 +50,7 @@ python3 flash_attn_mma.py --B 1 --H 48 --D 64 --N 8192 --iters 10 --torch # NVI
(flash): ['-0.00041986 ', '0.03292847 ', '0.01330566 '], time:22.468138ms, TFLOPS:37.42
----------------------------------------------------------------------------------------------------------------------------------
```
- Example: B=1, H=48, N=8192, D=512 (NVIDIA RTX 3080 Laptop), FA2 not supported.
- Example: B=1, H=8, N=8192, `D=512` (NVIDIA RTX 3080 Laptop), FA2 not supported, `QK Tiling` Faster than SDPA~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D 512 # NVIDIA RTX 3080 Laptop, Faster than SDPA
------------------------------------------B=1, H=8, N=8192, D=512, Warmup: 1, Iters: 10-------------------------------------------
@@ -57,6 +60,16 @@ python3 flash_attn_mma.py --B 1 --H 8 --N 8192 --iters 10 --show-all --sdpa --D
----------------------------------------------------------------------------------------------------------------------------------
```

- Example: B=1, H=48, N=16384, `D=512` (NVIDIA L20), FA2 not supported, `QK Tiling` Faster than SDPA~🎉🎉
```bash
python3 flash_attn_mma.py --B 1 --H 48 --D 512 --N 16384 --show-all --check --iters 10 --sdpa
-----------------------------------------B=1, H=48, N=16384, D=512, Warmup: 1, Iters: 10------------------------------------------
mma(split-q+tiling-qk+stage1): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:387.384224ms, TFLOPS:68.28 (+0.00%)
mma(split-q+tiling-qk+stage2): ['0.0079422 ', '-0.02334595 ', '0.00881958 '], time:325.593209ms, TFLOPS:81.24 (+18.98%)
(sdpa): ['0.00790405 ', '-0.02330017 ', '0.00875854 '], time:452.067018ms, TFLOPS:58.51
----------------------------------------------------------------------------------------------------------------------------------
```

## 📖 Contents

- [📖 FlashAttention MMA Kernels](#mma)
@@ -114,9 +127,10 @@ flash_attn_mma_stages_split_q_shared_qkv_kernel(half* Q, half* K, half* V, half*
<div id="mma-tiling-qk"></div>
```C++
// Fine-grained tiling (MMA level) for Q/K, it cause constant SRAM size 64*kMmaAtomK for Q/K,
// and O(kMmaAtomK*d) SRAM complexity for V, thus, the SRAM complexity is O(kMmaAtomK*d).
// Thus, we can extend D(headdim) to 1024. Performance is stay tuned for updates ~
// Fine-grained tiling at the MMA level for Q and K results in a constant SRAM usage of
// 64 * kMmaAtomK for Q and K. For V, the SRAM complexity is O(kMmaAtomK * d), leading to
// an overall SRAM complexity of O(kMmaAtomK * d). Consequently, this approach allows us to
// extend D (head dimension) up to 1024. Performance optimizations are ongoing; stay tuned for updates ~
__global__ void // Q, K, V, O -> [B, H, N, D]
flash_attn_mma_stages_split_q_tiling_qk_kernel(half* Q, half* K, half* V, half* O, ...);
```
15 changes: 14 additions & 1 deletion kernels/flash-attn/flash_attn_mma.py
@@ -2,10 +2,13 @@
import math
import time
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.utils.cpp_extension import load
from typing import Optional
from torch.nn.attention import sdpa_kernel, SDPBackend
from flash_attn import flash_attn_func
from functools import partial
import argparse
import random
import numpy as np
@@ -263,6 +266,16 @@ def unfused_standard_attn(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
return y


def sdpa(q: Tensor, k: Tensor, v: Tensor, use_flash: bool = False):
    # Pin SDPA to an explicit backend so the benchmark compares against a known
    # implementation: FLASH_ATTENTION where the head dim allows it, otherwise
    # EFFICIENT_ATTENTION (e.g. for D=512, which FA2 does not support).
    if not use_flash:
        with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
            out: Tensor = F.scaled_dot_product_attention(q, k, v)
    else:
        with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
            out: Tensor = F.scaled_dot_product_attention(q, k, v)
    return out


def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
tag: str = "out_mma", check_all: bool = False,
is_flash: bool = True):
Expand Down Expand Up @@ -330,7 +343,7 @@ def check_all_close(out_flash_or_sdpa: torch.Tensor, out_mma: torch.Tensor,
if D <= 256:
out_flash, _ = run_benchmark(flash_attn_func, fq, fk, fv, "(flash)")
if args.run_torch_sdpa:
out_sdpa, _ = run_benchmark(F.scaled_dot_product_attention, q, k, v, "(sdpa)")
out_sdpa, _ = run_benchmark(partial(sdpa, use_flash=(D<=256)), q, k, v, "(sdpa)")
pretty_print_line()

torch.cuda.synchronize()
