Add flash decoding (flash attention with split_kv) (#17)
* Add flash decoding and integrate it into flash_attention
* use online logsumexp, add doc & references
iclementine authored Feb 5, 2024
1 parent c9c0d76 commit 1641d0c
Showing 5 changed files with 512 additions and 29 deletions.
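
Flash decoding targets the inference case where the query holds a single token while the key/value context is long: the key/value sequence is split across thread blocks and the partial results are reduced afterwards. A minimal usage sketch against the public flag_attn.flash_attention API (shapes follow the benchmark below; the split-kv path is selected internally by the heuristic):

import torch
import flag_attn

B, H, D, N_CTX = 2, 16, 128, 8192
q = torch.randn(B, H, 1, D, device="cuda", dtype=torch.float16)      # one decoding step
k = torch.randn(B, H, N_CTX, D, device="cuda", dtype=torch.float16)  # long KV context
v = torch.randn(B, H, N_CTX, D, device="cuda", dtype=torch.float16)

o = flag_attn.flash_attention(q, k, v, causal=False)  # (B, H, 1, D)
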
73 changes: 73 additions & 0 deletions benchmark/flash_decoding_benchmark.py
@@ -0,0 +1,73 @@
import datetime
import logging
import pathlib
import torch
import triton

import flag_attn


try:
from flash_attn import flash_attn_func
FLASH_VER = 2
except BaseException:
try:
from flash_attn.flash_attn_interface import flash_attn_func
FLASH_VER = 1
except BaseException:
FLASH_VER = None
HAS_FLASH = FLASH_VER is not None


configs = [triton.testing.Benchmark(
x_names=['N_CTX'],
x_vals=[2**i for i in range(9, 20)],
line_arg='provider',
line_vals=['flag_attn', 'torch', ] + (['flash'] if HAS_FLASH else []),
line_names=['flag_attn', 'torch', ] + ([f'flash-{FLASH_VER}'] if HAS_FLASH else []),
styles=[('red', '-'), ('green', '-'), ('blue', '-'), ('cyan', '-')],
ylabel='ms',
plot_name=f'attention_d-{D_HEAD}_dtype-{dtype} (ms)',
args={'D_HEAD': D_HEAD, 'dtype': dtype}
) for D_HEAD in [64, 128]
for dtype in [torch.float16]]

@triton.testing.perf_report(configs)
def bench_flash_attention(N_CTX, D_HEAD, provider, dtype=torch.float16):
BATCH = 2
H = 2048 // D_HEAD
causal = False
if provider == "flag_attn":
q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, device="cuda")
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
fn = lambda: flag_attn.flash_attention(q, k, v, causal=causal)
ms = triton.testing.do_bench(fn)
if provider == "torch":
q = torch.randn((BATCH, H, 1, D_HEAD), dtype=dtype, device="cuda")
k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
v = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=dtype, device="cuda")
try:
fn = lambda: flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False)
ms = triton.testing.do_bench(fn)
except torch.cuda.OutOfMemoryError as e:
logging.info(f"torch OOM for batch_size: {BATCH}, num_heads: {H}, seqlen: {N_CTX}, headdim: {D_HEAD}")
ms = float("inf")
if provider == "flash":
q = torch.randn((BATCH, 1, H, D_HEAD), dtype=dtype, device="cuda")
k = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
v = torch.randn((BATCH, N_CTX, H, D_HEAD), dtype=dtype, device="cuda")
fn = lambda: flash_attn_func(q, k, v, causal=causal)
ms = triton.testing.do_bench(fn)

return ms
# # total TFLOPS: following Flash Attention v2, only gemms are counted.
# macs = 2. * BATCH * H * N_CTX * D_HEAD # Q@K, P@V
# total_flops = 2 * macs
# return total_flops / ms * 1e-9

# only works on post-Ampere GPUs right now
today = datetime.date.today().strftime("%Y%m%d")
output_dir = pathlib.Path(f"results_flash_attention_with_split_kv_{today}")
output_dir.mkdir(exist_ok=True)
bench_flash_attention.run(save_path=output_dir, print_data=True)
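
If throughput rather than latency is wanted, the commented-out FLOP count above can be applied to a measured time. A small illustrative conversion (the latency value here is hypothetical):

BATCH, H, N_CTX, D_HEAD = 2, 16, 2**19, 128
ms = 0.5  # hypothetical measured latency in milliseconds

macs = 2.0 * BATCH * H * N_CTX * D_HEAD  # MACs of Q@K^T and P@V for a single query token
total_flops = 2 * macs                   # 2 FLOPs per multiply-accumulate
tflops = total_flops / ms * 1e-9         # FLOPs/ms * 1e-9 == TFLOP/s
print(f"{tflops:.1f} TFLOP/s")           # ~17.2 TFLOP/s for these numbers
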
1 change: 1 addition & 0 deletions src/flag_attn/__init__.py
@@ -8,5 +8,6 @@

from flag_attn.piecewise import attention as piecewise_attention # noqa: F401
from flag_attn.flash import attention as flash_attention # noqa: F401
from flag_attn.split_kv import attention as flash_attention_split_kv # noqa: F401

from flag_attn import testing # noqa: F401
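
The new export also makes the split-kv implementation callable directly. Its signature is not shown in this diff; assuming it mirrors flash_attention's (q, k, v, causal, sm_scale), forcing the split-kv path would look roughly like:

import torch
import flag_attn

q = torch.randn(2, 16, 1, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 16, 8192, 128, device="cuda", dtype=torch.float16)
v = torch.randn(2, 16, 8192, 128, device="cuda", dtype=torch.float16)

# Bypass the heuristic in flag_attn.flash_attention and call the split-kv launcher directly
# (signature assumed, see note above).
o = flag_attn.flash_attention_split_kv(q, k, v, causal=False)
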
93 changes: 70 additions & 23 deletions src/flag_attn/flash.py
@@ -3,6 +3,8 @@
import triton
import triton.language as tl
from flag_attn.total import _total_attention_kernel
from flag_attn.split_kv import _fwd_split_kv_kernel, _fwd_combine_kv_splits, num_splits_herustic
from flag_attn.split_kv import get_fwd_config as get_fwd_config_kv_split

__all__ = ["attention"]

@@ -24,30 +26,76 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_

# to work around https://github.com/openai/triton/issues/2441
device = torch.cuda.device_of(q)
num_sms = torch.cuda.get_device_properties(device).multi_processor_count

with torch.cuda.device(device):
config = get_fwd_config(B, H, M, N, D, causal)
BLOCK_M, BLOCK_N, num_stages, num_warps = config
config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal)
S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128)
split_kv: bool = S > 1

if not split_kv:
config = get_fwd_config(B, H, M, N, D, causal)
BLOCK_M, BLOCK_N, num_stages, num_warps = config

divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0
# consider using 3d grid to avoid div & rem
grid = (triton.cdiv(M, BLOCK_M), H, B)
o = torch.empty_like(q)
L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v, sm_scale,
L, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
B, H, M, N, P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_warps=num_warps, num_stages=num_stages,
)
else: # split kv
BLOCK_M, BLOCK_N, num_stages, num_warps = config_for_split_kv

divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0

# consider using 3d grid to avoid div & rem
multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda")
multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda")
grid = (triton.cdiv(M, BLOCK_M), S, H * B)
N_SPLIT_SIZE = triton.cdiv(triton.cdiv(N, BLOCK_N), S) * BLOCK_N
_fwd_split_kv_kernel[grid](
q, k, v, sm_scale,
multiple_l, multiple_o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
B, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_stages=num_stages, num_warps=num_warps,
)

divisible_m = M % BLOCK_M == 0
divisible_n = N % BLOCK_N == 0
# consider using 3d grid to avoid div & rem
grid = (triton.cdiv(M, BLOCK_M), H, B)
o = torch.empty_like(q)
L = torch.empty((B, H, M), device=q.device, dtype=torch.float32)
_fwd_kernel[grid](
q, k, v, sm_scale,
L, o,
q.stride(0), q.stride(1), q.stride(2), q.stride(3),
k.stride(0), k.stride(1), k.stride(2), k.stride(3),
v.stride(0), v.stride(1), v.stride(2), v.stride(3),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
B, H, M, N, P_SEQ,
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
IS_CAUSAL=causal, LARGER_M=larger_m,
DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
num_warps=num_warps, num_stages=num_stages,
)
L = torch.empty((B, H, M), dtype=torch.float32, device="cuda")
o = torch.empty_like(q)
grid = (triton.cdiv(M, BLOCK_M), H, B)
_fwd_combine_kv_splits[grid](
multiple_o, multiple_l,
o, L,
multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
o.stride(0), o.stride(1), o.stride(2), o.stride(3),
B, H, M, S,
BLOCK_M=BLOCK_M, BLOCK_DMODEL=D,
DIVISIBLE_M=divisible_m,
num_stages=num_stages, num_warps=num_warps,
)

# total attention
if return_total_attention:
tot_attn = torch.empty((B, H, N), device=q.device, dtype=torch.float32)
grid = (triton.cdiv(N, BLOCK_N), H, B)
@@ -75,8 +123,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
tot_attn if return_total_attention else None
)
return outs
else:
return o
return o

@staticmethod
def backward(ctx, do, *ignored):
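
For reference, the split-kv path above launches _fwd_split_kv_kernel once per split over a chunk of N_SPLIT_SIZE keys, where N_SPLIT_SIZE rounds each split up to a multiple of BLOCK_N (e.g. N=1000, BLOCK_N=128, S=4 gives cdiv(1000, 128)=8 blocks, 2 blocks per split, so N_SPLIT_SIZE=256). The S partial results are then reduced by _fwd_combine_kv_splits using the per-split log-sum-exps. A minimal PyTorch sketch of that reduction (the math only, not the Triton kernel; it assumes each partial output is already normalized within its own split, as in standard flash decoding, and tensor names follow the launcher above):

import torch

def combine_splits_reference(multiple_o, multiple_l):
    # multiple_o: (B, H, S, M, D) partial attention outputs, one per KV split
    # multiple_l: (B, H, S, M) per-split log-sum-exp of the attention scores
    l = torch.logsumexp(multiple_l, dim=2)                 # (B, H, M) global log normalizer
    w = torch.exp(multiple_l - l.unsqueeze(2))             # (B, H, S, M) per-split weights
    o = (w.unsqueeze(-1) * multiple_o.float()).sum(dim=2)  # (B, H, M, D) combined output
    return o, l
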
