Enable headdim 256 backward on consumer GPUs (Ampere, Ada)
tridao committed Feb 21, 2024
1 parent 43950dd commit 2406f28
Showing 6 changed files with 53 additions and 37 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -70,7 +70,7 @@ FlashAttention-2 currently supports:
GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing
GPUs for now.
2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs).
3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800.
3. All head dimensions up to 256. ~~Head dim > 192 backward requires A100/A800 or H100/H800~~. Head dim 256 backward now works on consumer GPUs (if there's no dropout) as of flash-attn 2.5.5.


## How to use FlashAttention
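A minimal sketch of what the updated README bullet enables, assuming flash-attn >= 2.5.5 (or this commit) installed on an Ampere/Ada consumer GPU such as an RTX 3090/4090; the shapes below are illustrative, not taken from this repository:

```python
# Head dim 256 forward + backward on a consumer GPU (sm86/sm89).
# Per this commit, the backward pass works here only when dropout_p == 0.0.
import torch
from flash_attn import flash_attn_func

batch, seqlen, nheads, headdim = 2, 1024, 8, 256
q, k, v = [
    torch.randn(batch, seqlen, nheads, headdim,
                device="cuda", dtype=torch.float16, requires_grad=True)
    for _ in range(3)
]

out = flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
out.sum().backward()   # previously required A100/A800 or H100/H800 for head dim > 192
print(q.grad.shape)    # torch.Size([2, 1024, 8, 256])
```

With dropout_p > 0.0 on these GPUs, head dim 256 backward still raises the error from the TORCH_CHECK shown in flash_api.cpp below.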
8 changes: 4 additions & 4 deletions csrc/flash_attn/flash_api.cpp
@@ -783,8 +783,8 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

@@ -1020,8 +1020,8 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
TORCH_CHECK(batch_size > 0, "batch size must be positive");
TORCH_CHECK(head_size % 8 == 0, "head_size should be a multiple of 8");
TORCH_CHECK(head_size <= 256, "FlashAttention backward only supports head dimension at most 256");
if (head_size > 192) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim > 192 requires A100/A800 or H100/H800");
if (head_size > 192 && (head_size <= 224 || is_dropout)) {
TORCH_CHECK(is_sm80 || is_sm90, "FlashAttention backward for head dim 256 with dropout, or head dim 224 with/without dropout requires A100/A800 or H100/H800");
}
TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");

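The two updated TORCH_CHECKs encode the relaxed support matrix: head dims in (192, 224] still require sm80/sm90 unconditionally, while head dims in (224, 256] now require them only when dropout is enabled. A hypothetical Python restatement of that guard (not part of the flash-attn API), for readability:

```python
# Readable form of the new guard in mha_bwd / mha_varlen_bwd.
def bwd_needs_sm80_or_sm90(head_size: int, is_dropout: bool) -> bool:
    """True if the backward pass still requires A100/A800 (sm80) or H100/H800 (sm90)."""
    assert head_size % 8 == 0 and head_size <= 256   # mirrors the surrounding TORCH_CHECKs
    # (192, 224]: always restricted; (224, 256]: restricted only with dropout.
    return head_size > 192 and (head_size <= 224 or is_dropout)

assert bwd_needs_sm80_or_sm90(224, False)       # unchanged: still A100/H100 only
assert bwd_needs_sm80_or_sm90(256, True)        # dropout keeps the old restriction
assert not bwd_needs_sm80_or_sm90(256, False)   # new: allowed on sm86/sm89
```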
2 changes: 1 addition & 1 deletion csrc/flash_attn/src/flash_bwd_kernel.h
@@ -521,7 +521,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
flash::scale_apply_exp2</*scale_max=*/false>(scores, lse, params.scale_softmax_log2);
if (Is_dropout) {
if constexpr (Is_dropout) {
int warp_id = tidx / 32;
int block_row_idx = m_block * (kBlockM / 16) + warp_id % AtomLayoutMS;
// Need col to be multiples of 32, since we're doing dropout with block of 16 x 32
6 changes: 5 additions & 1 deletion csrc/flash_attn/src/flash_bwd_launch_template.h
@@ -296,8 +296,12 @@ void run_mha_bwd_hdim256(Flash_bwd_params &params, cudaStream_t stream) {
DROPOUT_SWITCH(params.p_dropout < 1.f, Is_dropout, [&] {
if (max_smem_per_block >= 176 * 1024) { // H100
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, false, T>, Is_dropout>(params, stream);
} else { // A100, we don't do double buffering to save smem
} else if (max_smem_per_block >= 144 * 1024) { // A100, we don't do double buffering to save smem
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 64, 8, 4, 2, 2, false, true, T>, Is_dropout>(params, stream);
} else { // sm86 and sm89, max smem is 99 KB. Only works without dropout. V in regs and no double buffering.
if constexpr (!Is_dropout) {
run_flash_bwd<Flash_bwd_kernel_traits<Headdim, 64, 32, 8, 4, 1, 2, true, true, T>, false>(params, stream);
}
}
});
}
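run_mha_bwd_hdim256 now dispatches three ways on the shared memory available per block: at least 176 KB (H100) keeps double buffering, at least 144 KB (A100) drops it, and anything smaller (sm86/sm89, 99 KB) falls back to a 64x32 tile with V held in registers and no double buffering, and only when dropout is off. A rough, hypothetical restatement of that dispatch; the tuple labels are illustrative, chosen to match the comments in the diff, not names from the code:

```python
# Sketch of the new hdim-256 backward dispatch, keyed on smem per block in KB.
# Returns (kBlockM, kBlockN, double_buffer, V_in_regs), or None if unsupported.
def pick_hdim256_bwd_config(max_smem_per_block_kb: int, is_dropout: bool):
    if max_smem_per_block_kb >= 176:      # H100: 64x64 tiles, double buffering
        return (64, 64, True, False)
    elif max_smem_per_block_kb >= 144:    # A100: 64x64 tiles, no double buffering (saves smem)
        return (64, 64, False, False)
    elif not is_dropout:                  # sm86/sm89: only 99 KB of smem per block
        return (64, 32, False, True)      # smaller kBlockN, V kept in registers
    return None                           # hdim 256 backward with dropout: still unsupported here

print(pick_hdim256_bwd_config(99, False))   # -> (64, 32, False, True)
```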
6 changes: 4 additions & 2 deletions csrc/flash_attn/src/kernel_traits.h
@@ -231,9 +231,11 @@ struct Flash_bwd_kernel_traits : public Base {
// TODO: generalize to other values of kBlockN
// TODO: what should be the Swizzle here? 3 is faster than 1, and 1 is faster than 2
// static constexpr int kPBlockN = kBlockN;
static_assert(kBlockN >= 64);
// Temporarily disabling this for hdim 256 on sm86 and sm89
// static_assert(kBlockN >= 64);
static_assert(kBlockN >= 32);
// TD [2023-03-19]: Idk why kPBlockN = 16 and kSwizzlePdS=3 is the fastest.
static constexpr int kPBlockN = 64;
static constexpr int kPBlockN = kBlockN >= 64 ? 64 : 32;
static_assert(kPBlockN == 16 || kPBlockN == 32 || kPBlockN == 64);
// static constexpr int kSwizzlePdS = kPBlockN == 16 ? 1 : (kPBlockN == 32 ? 2 : 3);
static constexpr int kSwizzlePdS = 3;
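The trait changes relax the old kBlockN >= 64 requirement so the 64x32 configuration used for hdim 256 on sm86/sm89 can instantiate, and kPBlockN now follows kBlockN down to 32. An illustrative restatement (hypothetical helper, not library code):

```python
# How kPBlockN is now derived from kBlockN in Flash_bwd_kernel_traits.
def pick_kpblockn(k_block_n: int) -> int:
    assert k_block_n >= 32              # was: static_assert(kBlockN >= 64)
    k_p_block_n = 64 if k_block_n >= 64 else 32
    assert k_p_block_n in (16, 32, 64)  # mirrors the static_assert on kPBlockN
    return k_p_block_n

assert pick_kpblockn(32) == 32          # new sm86/sm89 hdim-256 backward config
assert pick_kpblockn(64) == 64          # existing configs unchanged
```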
66 changes: 38 additions & 28 deletions tests/test_flash_attn.py
@@ -664,7 +664,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
# do_o = (g.float() * out.float()).sum(-1)
# dv_tmp = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, :64], g[:, :64])
# dv_tmp1 = torch.einsum('bhts,bthd->bshd', attn_pt[:, :, 64:], g[:, 64:])
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(dqkv,) = torch.autograd.grad(out, qkv, g)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
(dqkv_pt,) = torch.autograd.grad(out_pt, qkv, g)
@@ -687,7 +687,7 @@ def test_flash_attn_qkvpacked(seqlen, d, dropout_p, causal, local, alibi, determ
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@@ -811,7 +811,7 @@ def test_flash_attn_varlen_qkvpacked(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(dqkv_unpad,) = torch.autograd.grad(out, qkv_unpad, g)
dqkv = dqkv_pad_fn(dqkv_unpad)
(dqkv_ref,) = torch.autograd.grad(out_ref, qkv, g)
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dqkv - dqkv_ref).abs().max().item() <= 2 * (dqkv_pt - dqkv_ref).abs().max().item()


@@ -1036,7 +1036,7 @@ def test_flash_attn_output(

g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if kvpacked:
(
dq,
@@ -1092,7 +1092,7 @@ def test_flash_attn_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
@@ -1339,7 +1339,7 @@ def test_flash_attn_varlen_output(
print(f"Attention Pytorch max diff: {(attn_pt - attn_ref).abs().max().item()}")

g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
if kvpacked:
(
dq_unpad,
@@ -1398,7 +1398,7 @@ def test_flash_attn_varlen_output(
if not alibi:
assert abs(dropout_fraction - dropout_p) <= (0.01 if not local else 0.025)

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
assert (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
assert (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
@@ -1476,7 +1476,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):

g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
(
dq,
dk,
@@ -1509,7 +1509,7 @@ def test_flash_attn_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtype):
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1625,7 +1625,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp

g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
(
dq_unpad,
dk_unpad,
@@ -1661,7 +1661,7 @@ def test_flash_attn_varlen_causal(seqlen_q, seqlen_k, swap_sq_sk, d, local, dtyp
# of a Pytorch implementation.
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item() + 1e-5
assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item() + 1e-5
assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item() + 1e-5
@@ -1755,7 +1755,7 @@ def test_flash_attn_splitkv(

g = torch.randn_like(out)
do_o = (g.float() * out.float()).sum(-1)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
(
dq,
dk,
@@ -1789,7 +1789,7 @@ def test_flash_attn_splitkv(
assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item() + 1e-5

mult = 2 if not alibi else 8
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
assert (dq - dq_ref).abs().max().item() <= mult * (dq_pt - dq_ref).abs().max().item() + 2e-4
assert (dk - dk_ref).abs().max().item() <= mult * (dk_pt - dk_ref).abs().max().item() + 2e-4
assert (dv - dv_ref).abs().max().item() <= mult * (dv_pt - dv_ref).abs().max().item() + 2e-4
@@ -1815,8 +1815,9 @@ def test_flash_attn_splitkv(
# @pytest.mark.parametrize("rotary_interleaved", [False])
@pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
# @pytest.mark.parametrize("rotary_fraction", [0.0])
# @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [256, 512])
@pytest.mark.parametrize("paged_kv_block_size", [None, 256])
# @pytest.mark.parametrize("paged_kv_block_size", [256, 512])
# @pytest.mark.parametrize("paged_kv_block_size", [256])
@pytest.mark.parametrize("has_batch_idx", [False, True])
# @pytest.mark.parametrize("has_batch_idx", [False])
@pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
@@ -1900,12 +1901,13 @@ def test_flash_attn_kvcache(
b=batch_size,
)
k_cache = rearrange(
k_cache_paged[block_table.flatten()],
# pytorch 1.12 doesn't have indexing with int32
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache = rearrange(
v_cache_paged[block_table.flatten()],
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
@@ -1972,8 +1974,12 @@ def test_flash_attn_kvcache(
cos, sin = None, None
q_ro, k_ro = q, k
# k_cache[:, 64:] = -1
k_cache_ref = (k_cache if not has_batch_idx else k_cache[cache_batch_idx]).clone()
v_cache_ref = (v_cache if not has_batch_idx else v_cache[cache_batch_idx]).clone()
k_cache_ref = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
v_cache_ref = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
).clone()
if new_kv:
update_mask = torch.logical_and(
cache_seqlens_expanded <= arange, arange < cache_seqlens_expanded + seqlen_new
@@ -2044,16 +2050,20 @@ def test_flash_attn_kvcache(
# of a Pytorch implementation.
if new_kv:
if paged_kv_block_size is None:
k_cache_select = k_cache if not has_batch_idx else k_cache[cache_batch_idx]
v_cache_select = v_cache if not has_batch_idx else v_cache[cache_batch_idx]
k_cache_select = (
k_cache if not has_batch_idx else k_cache[cache_batch_idx.to(dtype=torch.long)]
)
v_cache_select = (
v_cache if not has_batch_idx else v_cache[cache_batch_idx.to(dtype=torch.long)]
)
else:
k_cache_select = rearrange(
k_cache_paged[block_table.flatten()],
k_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
v_cache_select = rearrange(
v_cache_paged[block_table.flatten()],
v_cache_paged[block_table.to(dtype=torch.long).flatten()],
"(b nblocks) block_size ... -> b (nblocks block_size) ...",
b=batch_size,
)[:, :seqlen_k]
@@ -2104,7 +2114,7 @@ def test_flash_attn_race_condition(seqlen_q, seqlen_k, d, dropout_p, causal, dty
torch.random.manual_seed(42)
out0, lse0, _ = flash_attn_func(q, k, v, dropout_p, causal=causal, return_attn_probs=True)
g = torch.randn_like(out0)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(
dq0,
dk0,
assert torch.equal(out, out0)
assert torch.equal(lse, lse0)

if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90):
(
dq,
dk,
@@ -2326,7 +2336,7 @@ def test_flash_attn_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, causal, loc
out = flash_attn_func(q, k, v, 0.0, causal=causal, window_size=window_size, deterministic=True)

g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq0, dk0, dv0 = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q, k, v), g, retain_graph=True)
@@ -2414,7 +2424,7 @@ def test_flash_attn_varlen_deterministic(seqlen_q, seqlen_k, swap_sq_sk, d, caus
)

g = torch.randn_like(out)
if d <= MAX_HEADDIM_SM8x or (is_sm80 or is_sm90):
if (d <= MAX_HEADDIM_SM8x or d > 224) or (is_sm80 or is_sm90):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
for _ in range(50):
dq, dk, dv = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g, retain_graph=True)
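The repeated gating change in the tests can be read as a single predicate: run the backward checks when the head dim fits the sm8x limit, when the GPU is sm80/sm90, or (new in this commit) when the head dim exceeds 224 and dropout is off; the causal, splitkv, and deterministic tests do not use dropout, so they drop the dropout clause. A hypothetical helper equivalent to the inline condition, assuming MAX_HEADDIM_SM8x is 192 as in the test file:

```python
MAX_HEADDIM_SM8x = 192  # assumed value, matching the head-dim-192 limit quoted in the README

def backward_supported(d: int, dropout_p: float, is_sm80: bool, is_sm90: bool) -> bool:
    # Same expression the tests now use inline.
    return (d <= MAX_HEADDIM_SM8x or (d > 224 and dropout_p == 0)) or (is_sm80 or is_sm90)

assert backward_supported(256, 0.0, False, False)        # sm86/sm89: backward now exercised
assert not backward_supported(256, 0.17, False, False)   # with dropout: still skipped off sm80/sm90
assert not backward_supported(224, 0.0, False, False)    # 192 < d <= 224: still sm80/sm90 only
assert backward_supported(224, 0.17, True, False)        # A100: unchanged
```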
