From 13664fc33785524f66d22c657386ad93d8bd5f3a Mon Sep 17 00:00:00 2001
From: Clement Chan
Date: Mon, 27 May 2024 16:31:00 +0800
Subject: [PATCH] support grouped query attention (MQA & GQA) for flash_attn (#22)

* support grouped query attention (GQA) for flash_attn (fwd, bwd, split_kv, total_attention)

* add MQA/GQA to the feature list; update documentation and tests for flash attention.
---
 README.md                               |   1 +
 src/flag_attn/flash.py                  |  73 +++++++++----
 src/flag_attn/split_kv.py               |  13 ++-
 src/flag_attn/testing/flash.py          |   9 ++
 src/flag_attn/total.py                  |  11 +-
 tests/flag_attn/test_flash_attention.py | 139 +++++++++++++++---------
 6 files changed, 165 insertions(+), 81 deletions(-)

diff --git a/README.md b/README.md
index c89a9df..c8414cb 100644
--- a/README.md
+++ b/README.md
@@ -233,6 +233,7 @@ The performance of piecewise_attention has improved compared to that in v0.1. In
 - the sequence length of k/v can be different from that of q;
 - support computation of total attention of each `k` gets from all `q`'s;
 - supports returning accumulative attention of each keys.
+- supports [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245).
 
 #### Limitations
 
diff --git a/src/flag_attn/flash.py b/src/flag_attn/flash.py
index a32564b..73c4bac 100644
--- a/src/flag_attn/flash.py
+++ b/src/flag_attn/flash.py
@@ -8,22 +8,40 @@ __all__ = ["attention"]
 
+
+def maybe_contiguous(x):
+    # only when the innermost dimension is contiguous can LDGSTS be used,
+    # so inner-dimension contiguity is enforced.
+    return x.contiguous() if x.stride(-1) != 1 else x
+
+def rounded_multiple(a, b):
+    return (a + b - 1) // b * b
+
 # --------------------------- public API ---------------------------
 class FlashAttention(torch.autograd.Function):
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention):
+        # size, stride, dtype checking
         Dq, Dk, Dv = q.shape[-1], k.shape[-1], v.shape[-1]
-        assert Dq == Dk == Dv
+        assert Dq == Dk == Dv, "feature size of q, k, v should be equal"
         assert Dk in {16, 32, 64, 128}
         B, H, M, D = q.shape
         N = k.shape[2]
+        Hk, Hv = k.shape[1], v.shape[1]
+        assert Hk == Hv, "num of heads in k and v should be equal"
+        assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
+        num_groups = H // Hk
+
         P_SEQ = N - M
         larger_m = M > N
 
         if sm_scale is None:
            sm_scale = 1. / math.sqrt(D)
 
+        # contiguity
+        q, k, v = maybe_contiguous(q), maybe_contiguous(k), maybe_contiguous(v)
+
         # to work around https://github.com/openai/triton/issues/2441
         device = torch.cuda.device_of(q)
         num_sms = torch.cuda.get_device_properties(device).multi_processor_count
@@ -32,6 +50,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
         config_for_split_kv = get_fwd_config_kv_split(B, H, M, N, D, causal)
         S = num_splits_herustic(B, H, M, N, config_for_split_kv[0], config_for_split_kv[1], num_sms, 128)
         split_kv: bool = S > 1
+        # print(f"flag_attn choose {S} splits")
 
         if not split_kv:
             config = get_fwd_config(B, H, M, N, D, causal)
@@ -50,7 +69,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
                 k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                 v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                 o.stride(0), o.stride(1), o.stride(2), o.stride(3),
-                B, H, M, N, P_SEQ,
+                B, H, M, N, P_SEQ, num_groups,
                 BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=D,
                 IS_CAUSAL=causal, LARGER_M=larger_m,
                 DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -61,7 +80,6 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
             divisible_m = M % BLOCK_M == 0
             divisible_n = N % BLOCK_N == 0
 
-            # consider using 3d grid to avoid div & rem
             multiple_l = torch.empty((B, H, S, M), dtype=torch.float32, device="cuda")
             multiple_o = torch.empty((B, H, S, M, D), dtype=torch.float16, device="cuda")
 
@@ -74,7 +92,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
                 k.stride(0), k.stride(1), k.stride(2), k.stride(3),
                 v.stride(0), v.stride(1), v.stride(2), v.stride(3),
                 multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
-                B, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
+                B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
                 BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
                 IS_CAUSAL=causal, LARGER_M=larger_m,
                 DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -103,7 +121,7 @@ def forward(ctx, q, k, v, causal, sm_scale, return_log_normalizer, return_total_
                 q, k, L, tot_attn, sm_scale,
                 q.stride(0), q.stride(1), q.stride(2), q.stride(3),
                 k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-                B, H, M, N, P_SEQ,
+                B, H, M, N, P_SEQ, num_groups,
                 BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
                 CAUSAL=causal,
                 DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -133,6 +151,8 @@ def backward(ctx, do, *ignored):
 
         B, H, M, D = q.shape
         N = k.shape[2]
+        Hk = k.shape[1]
+        num_groups = H // Hk
         P_SEQ = N - M
         larger_m = M > N
 
@@ -161,8 +181,9 @@ def backward(ctx, do, *ignored):
             DIVISIBLE_M=divisible_m,
         )
 
-        dk = torch.empty_like(k)
-        dv = torch.empty_like(v)
+        # NOTE that dk & dv always have the same number of heads as q, instead of k & v.
+        dk = torch.empty((B, H, N, D), dtype=k.dtype, device=q.device)
+        dv = torch.empty((B, H, N, D), dtype=v.dtype, device=q.device)
         grid = (triton.cdiv(N, BLOCK_N), H, B)
         _bwd_kv_kernel[grid](
             q, k, v, sm_scale, do,
@@ -175,6 +196,7 @@ def backward(ctx, do, *ignored):
             dk.stride(0), dk.stride(1), dk.stride(2), dk.stride(3),
             dv.stride(0), dv.stride(1), dv.stride(2), dv.stride(3),
             B, H, M, N, P_SEQ,
+            num_groups,
             BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N, CAUSAL=causal,
             DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
             num_stages=num_stages, num_warps=num_warps,
@@ -192,12 +214,14 @@ def backward(ctx, do, *ignored):
             do.stride(0), do.stride(1), do.stride(2), do.stride(3),
             dq.stride(0), dq.stride(1), dq.stride(2), dq.stride(3),
             B, H, M, N, P_SEQ,
+            num_groups,
             BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
             CAUSAL=causal, LARGER_M=larger_m,
             DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
             num_stages=num_stages, num_warps = num_warps,
         )
-
+        dk = dk.reshape((B, Hk, num_groups, N, D)).sum(2)
+        dv = dv.reshape((B, Hk, num_groups, N, D)).sum(2)
         return dq, dk, dv, None, None, None, None
 
 
@@ -208,21 +232,24 @@ def attention(q, k, v, causal=False, sm_scale=None,
     An implementation of FlashAttention v2(https://arxiv.org/abs/2307.08691).
 
     Arguments:
-        q(torch.Tensor): The first queries. The shape is (batch_size, nheads, seqlen_q, headdim).
-        k(torch.Tensor): The first keys. The shape is (batch_size, nheads, seqlen_k, headdim).
-        v(torch.Tensor): The values. The shape is (batch_size, nheads, seqlen_k, headdim).
+        q(torch.Tensor): The queries. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
+        k(torch.Tensor): The keys. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
+        v(torch.Tensor): The values. The shape is (batch_size, num_heads_k, seqlen_k, headdim).
         causal(bool): Whether causal masking is applied to attention scores before applying softmax.
         sm_scale(float): The scaling of attention scores before applying softmax.
         return_log_normalizer(bool): Whether to return the log normalizer of softmax inside attention.
         return_total_attention(bool): Whether to return the sum of attention along q's sequence dimension.
 
     Returns:
-        out(torch.Tensor): The output. The shape is (batch_size, nheads, seqlen_q, headdim).
+        out(torch.Tensor): The output. The shape is (batch_size, num_heads_q, seqlen_q, headdim).
 
         If `return_log_normalizer` or `return_total_attention`, return the following results in addition.
 
-        log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, nheads, seqlen_q).
-        total_attention(torch.Tensor): The total attention. The shape is (batch_size, nheads, seqlen_k).
+        log_normalizer(torch.Tensor): The log normalizer. The shape is (batch_size, num_heads_q, seqlen_q).
+        total_attention(torch.Tensor): The total attention. The shape is (batch_size, num_heads_q, seqlen_k).
+
+    Notes:
+        `num_heads_q` must be a multiple of `num_heads_k`.
""" return FlashAttention.apply(q, k, v, causal, sm_scale, return_log_normalizer, return_total_attention) @@ -272,6 +299,7 @@ def _fwd_kernel( stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_om, stride_ok, Z, H, M, N, P_SEQ, + num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -289,9 +317,10 @@ def _fwd_kernel( qk_scale = sm_scale * log2e # offset pointers for (batch, head) + off_hk = off_h // num_groups Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh O += off_z * stride_oz + off_h * stride_oh L += (off_z * H + off_h) * M # l's shape is (B, H, M) @@ -491,6 +520,7 @@ def _bwd_kv_kernel( stride_dkz, stride_dkh, stride_dkn, stride_dkk, stride_dvz, stride_dvh, stride_dvn, stride_dvk, Z, H, M, N, P_SEQ, + num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -504,9 +534,10 @@ def _bwd_kv_kernel( qk_scale = sm_scale * log2e # offset pointers for (batch, head) + off_hk = off_h // num_groups Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh DO += off_z * stride_doz + off_h * stride_doh # offset pointers for batch/head @@ -637,6 +668,7 @@ def _bwd_q_kernel( stride_doz, stride_doh, stride_dom, stride_dok, stride_dqz, stride_dqh, stride_dqm, stride_dqk, Z, H, M, N, P_SEQ, + num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -654,9 +686,10 @@ def _bwd_q_kernel( qk_scale = sm_scale * log2e # offset pointers for (batch, head) + off_hk = off_h // num_groups Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh + K += off_z * stride_kz + off_hk * stride_kh + V += off_z * stride_vz + off_hk * stride_vh DO += off_z * stride_doz + off_h * stride_doh D += (off_z * H + off_h) * M L += (off_z * H + off_h) * M diff --git a/src/flag_attn/split_kv.py b/src/flag_attn/split_kv.py index c3fc0cd..85def19 100644 --- a/src/flag_attn/split_kv.py +++ b/src/flag_attn/split_kv.py @@ -19,7 +19,7 @@ def _fwd_split_kv_kernel( stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vn, stride_vk, stride_oz, stride_oh, stride_os, stride_om, stride_ok, - Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S, + Z, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, LARGER_M: tl.constexpr, DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr, @@ -31,6 +31,7 @@ def _fwd_split_kv_kernel( off_zh = tl.program_id(2) off_h = off_zh % H off_z = off_zh // H + off_hk = off_h // num_groups # scale sm_scale by log_2(e) and use # 2^x instead of exp in the loop because CSE and LICM @@ -40,8 +41,8 @@ def _fwd_split_kv_kernel( # offset pointers for (batch & head) Q += off_z * stride_qz + off_h * stride_qh - K += off_z * stride_kz + off_h * stride_kh - V += off_z * stride_vz + off_h * stride_vh + K += off_z * stride_kz + off_hk * 
+    V += off_z * stride_vz + off_hk * stride_vh
 
     # offset pointers for (batch & head, split)
     O += off_z * stride_oz + off_h * stride_oh + n_split_id * stride_os # o's shape is (B, H, S, M, D)
@@ -269,6 +270,10 @@ def attention(q, k, v, causal=False, sm_scale=None):
 
     B, H, M, D = q.shape
     N = k.shape[2]
+    Hk, Hv = k.shape[1], v.shape[1]
+    assert Hk == Hv, "num of heads in k and v should be equal"
+    assert H % Hk == 0, "number of heads in q must be a multiple of that in k & v"
+    num_groups = H // Hk
     P_SEQ = N - M
     larger_m = M > N
 
@@ -299,7 +304,7 @@ def attention(q, k, v, causal=False, sm_scale=None):
         k.stride(0), k.stride(1), k.stride(2), k.stride(3),
         v.stride(0), v.stride(1), v.stride(2), v.stride(3),
         multiple_o.stride(0), multiple_o.stride(1), multiple_o.stride(2), multiple_o.stride(3), multiple_o.stride(4),
-        B, H, M, N, P_SEQ, N_SPLIT_SIZE, S,
+        B, H, M, N, P_SEQ, N_SPLIT_SIZE, S, num_groups,
         BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
         IS_CAUSAL=causal, LARGER_M=larger_m,
         DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
diff --git a/src/flag_attn/testing/flash.py b/src/flag_attn/testing/flash.py
index 7f272ed..4314019 100644
--- a/src/flag_attn/testing/flash.py
+++ b/src/flag_attn/testing/flash.py
@@ -16,6 +16,15 @@ def attention(q,
     D = q.shape[-1]
     if sm_scale is None:
         sm_scale = 1. / math.sqrt(D)
+
+    num_heads_q = q.shape[1]
+    num_heads_k = k.shape[1]
+    assert num_heads_q % num_heads_k == 0
+    num_groups = num_heads_q // num_heads_k
+
+    if num_groups > 1:
+        k = torch.repeat_interleave(k, repeats=num_groups, dim=1)
+        v = torch.repeat_interleave(v, repeats=num_groups, dim=1)
     kv_seq_len = k.shape[-2]
     q_seq_len = q.shape[-2]
     p_seq = kv_seq_len - q_seq_len
diff --git a/src/flag_attn/total.py b/src/flag_attn/total.py
index 318be34..bdede6a 100644
--- a/src/flag_attn/total.py
+++ b/src/flag_attn/total.py
@@ -14,6 +14,10 @@ def total_attention(q, k, l, causal=False, sm_scale=None):
 
     B, H, M, D = q.shape
     N = k.shape[2]
+    Hk = k.shape[1]
+    assert H % Hk == 0, "number of heads in q must be a multiple of that in k"
+    num_groups = H // Hk
+
     P_SEQ = N - M
 
     if sm_scale is None:
@@ -34,7 +38,7 @@ def total_attention(q, k, l, causal=False, sm_scale=None):
         q, k, l, tot_attn, sm_scale,
         q.stride(0), q.stride(1), q.stride(2), q.stride(3),
         k.stride(0), k.stride(1), k.stride(2), k.stride(3),
-        B, H, M, N, P_SEQ,
+        B, H, M, N, P_SEQ, num_groups,
         BLOCK_M=BLOCK_M, BLOCK_DMODEL=D, BLOCK_N=BLOCK_N,
         CAUSAL=causal,
         DIVISIBLE_M=divisible_m, DIVISIBLE_N=divisible_n,
@@ -48,7 +52,7 @@ def _total_attention_kernel(
     Q, K, L, TA, sm_scale,
     stride_qz, stride_qh, stride_qm, stride_qk,
     stride_kz, stride_kh, stride_kn, stride_kk,
-    Z, H, M, N, P_SEQ,
+    Z, H, M, N, P_SEQ, num_groups,
     BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr,
     CAUSAL: tl.constexpr,
     DIVISIBLE_M: tl.constexpr, DIVISIBLE_N: tl.constexpr,
@@ -61,8 +65,9 @@ def _total_attention_kernel(
     qk_scale = sm_scale * log2e
 
     # offset pointers for (batch, head)
+    off_hk = off_h // num_groups
     Q += off_z * stride_qz + off_h * stride_qh
-    K += off_z * stride_kz + off_h * stride_kh
+    K += off_z * stride_kz + off_hk * stride_kh
     L += (off_z * H + off_h) * M
     TA += (off_z * H + off_h) * N # (b, h, n)
 
diff --git a/tests/flag_attn/test_flash_attention.py b/tests/flag_attn/test_flash_attention.py
index 8c6d06b..df6ca20 100644
--- a/tests/flag_attn/test_flash_attention.py
+++ b/tests/flag_attn/test_flash_attention.py
@@ -19,30 +19,40 @@ def report(name, actual, expected):
 
 @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
 @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0])
-@pytest.mark.parametrize('B, H, M, N, D', [
-    (2, 4, 512, 612, 128),
-    (2, 4, 1024, 1034, 64),
-    (2, 4, 2048, 2048, 32),
-    (2, 4, 4096, 4096, 16),
-    (2, 4, 4001, 4001, 32),
-    (2, 4, 4001, 4096, 64),
-    (2, 4, 4096, 4000, 128),
-    (1, 2, 8192, 8202, 16),
-    (1, 2, 8192, 8192, 32),
+@pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
+    (2, 4, 4, 512, 612, 128),
+    (2, 4, 4, 1024, 1034, 64),
+    (2, 4, 4, 2048, 2048, 32),
+    (2, 4, 4, 4096, 4096, 16),
+    (2, 4, 4, 4001, 4001, 32),
+    (2, 4, 4, 4001, 4096, 64),
+    (2, 4, 4, 4096, 4000, 128),
+    (1, 2, 2, 8192, 8202, 16),
+    (1, 2, 2, 8192, 8192, 32),
+    # test for mqa/gqa
+    (2, 4, 2, 512, 612, 128),
+    (2, 4, 1, 1024, 1034, 64),
+    (2, 4, 2, 2048, 2048, 32),
+    (2, 4, 1, 4096, 4096, 16),
+    (2, 4, 2, 4001, 4001, 32),
+    (2, 4, 1, 4001, 4096, 64),
+    (2, 4, 2, 4096, 4000, 128),
+    (1, 2, 1, 8192, 8202, 16),
+    (1, 2, 1, 8192, 8192, 32),
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
-def test_attention_fwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id):
+def test_attention_fwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
     device = f"cuda:{device_id}"
     if stride_order == "BHTD":
-        q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
-        k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
-        v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
     else:
-        q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
-        k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
-        v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
 
     o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True)
     o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False)
@@ -57,30 +67,40 @@ def test_attention_fwd(B, H, M, N, D, causal, stride_order, dtype, scale, device
 
 @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
 @pytest.mark.parametrize('scale', [10.0])
-@pytest.mark.parametrize('B, H, M, N, D', [
-    (2, 4, 1, 612, 128),
-    (2, 4, 1, 1034, 64),
-    (2, 4, 1, 2048, 32),
-    (2, 4, 1, 4096, 16),
-    (2, 4, 1, 4001, 32),
-    (2, 4, 1, 4096, 64),
-    (2, 4, 2, 4000, 128),
-    (1, 2, 4, 8202, 16),
-    (1, 2, 1, 8192, 32),
+@pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
+    (2, 4, 4, 1, 612, 128),
+    (2, 4, 4, 1, 1034, 64),
+    (2, 4, 4, 1, 2048, 32),
+    (2, 4, 4, 1, 4096, 16),
+    (2, 4, 4, 1, 4001, 32),
+    (2, 4, 4, 1, 4096, 64),
+    (2, 4, 4, 2, 4000, 128),
+    (1, 2, 2, 4, 8202, 16),
+    (1, 2, 2, 1, 8192, 32),
+    # test for mqa/gqa
+    (2, 4, 2, 1, 612, 128),
+    (2, 4, 1, 1, 1034, 64),
+    (2, 4, 2, 1, 2048, 32),
+    (2, 4, 1, 1, 4096, 16),
+    (2, 4, 2, 1, 4001, 32),
+    (2, 4, 1, 1, 4096, 64),
+    (2, 4, 2, 2, 4000, 128),
+    (1, 2, 1, 4, 8202, 16),
+    (1, 2, 1, 1, 8192, 32),
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
-def test_attention_splitkv(B, H, M, N, D, causal, stride_order, dtype, scale, device_id):
+def test_attention_splitkv(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
     device = f"cuda:{device_id}"
     if stride_order == "BHTD":
-        q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
-        k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
-        v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
+        v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale)
     else:
-        q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
-        k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
-        v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
+        v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2)
 
     o_ref = flag_attn.testing.flash_attention(q, k, v, causal, upcast=True)
     o_torch = flag_attn.testing.flash_attention(q, k, v, causal, upcast=False)
@@ -94,33 +114,44 @@ def test_attention_splitkv(B, H, M, N, D, causal, stride_order, dtype, scale, de
 
 @pytest.mark.parametrize('device_id', list(range(torch.cuda.device_count())))
 @pytest.mark.parametrize('scale', [1.0, 2.0, 3.0, 4.0])
-@pytest.mark.parametrize('B, H, M, N, D', [
-    (2, 4, 512, 612, 128),
-    (2, 4, 1024, 1034, 64),
-    (2, 4, 2048, 2048, 32),
-    (2, 4, 4096, 4096, 16),
-    (2, 4, 4001, 4001, 32),
-    (2, 4, 4001, 4096, 64),
-    (2, 4, 4096, 4001, 128),
-    (1, 2, 8192, 8202, 16),
-    (1, 2, 8192, 8192, 32),
-    (2, 4, 10006, 10, 128),
+@pytest.mark.parametrize('B, Hq, Hk, M, N, D', [
+    (2, 4, 4, 512, 612, 128),
+    (2, 4, 4, 1024, 1034, 64),
+    (2, 4, 4, 2048, 2048, 32),
+    (2, 4, 4, 4096, 4096, 16),
+    (2, 4, 4, 4001, 4001, 32),
+    (2, 4, 4, 4001, 4096, 64),
+    (2, 4, 4, 4096, 4001, 128),
+    (1, 2, 2, 8192, 8202, 16),
+    (1, 2, 2, 8192, 8192, 32),
+    (2, 4, 4, 10006, 10, 128),
+    # test for mqa/gqa
+    (2, 4, 2, 512, 612, 128),
+    (2, 4, 1, 1024, 1034, 64),
+    (2, 4, 2, 2048, 2048, 32),
+    (2, 4, 1, 4096, 4096, 16),
+    (2, 4, 2, 4001, 4001, 32),
+    (2, 4, 1, 4001, 4096, 64),
+    (2, 4, 2, 4096, 4001, 128),
+    (1, 2, 1, 8192, 8202, 16),
+    (1, 2, 1, 8192, 8192, 32),
+    (2, 4, 2, 10006, 10, 128),
 ])
 @pytest.mark.parametrize('causal', [True, False])
 @pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
 @pytest.mark.parametrize('stride_order', ['BHTD', 'BTHD'])
-def test_attention_bwd(B, H, M, N, D, causal, stride_order, dtype, scale, device_id):
+def test_attention_bwd(B, Hq, Hk, M, N, D, causal, stride_order, dtype, scale, device_id):
     device = f"cuda:{device_id}"
     if stride_order == "BHTD":
-        q = torch.empty((B, H, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
-        k = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
-        v = torch.empty((B, H, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
-        do = torch.randn((B, H, M, D), dtype=dtype, device=device)
+        q = torch.empty((B, Hq, M, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
+        k = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
+        v = torch.empty((B, Hk, N, D), dtype=dtype, device=device).normal_(mean=0., std=scale).requires_grad_()
+        do = torch.randn((B, Hq, M, D), dtype=dtype, device=device)
     else:
-        q = torch.empty((B, M, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
-        k = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
-        v = torch.empty((B, N, H, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
-        do = torch.randn((B, M, H, D), dtype=dtype, device=device).transpose(1, 2)
+        q = torch.empty((B, M, Hq, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
+        k = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
+        v = torch.empty((B, N, Hk, D), dtype=dtype, device=device).normal_(mean=0., std=scale).transpose(1, 2).requires_grad_()
+        do = torch.randn((B, M, Hq, D), dtype=dtype, device=device).transpose(1, 2)
 
     o_ref = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=True)
     o_torch = flag_attn.testing.flash_attention(q, k, v, causal=causal, upcast=False)
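
A minimal usage sketch of the grouped-query path introduced by this patch (illustrative shapes; it assumes the patched op is exported as `flag_attn.flash_attention` and that a CUDA device is available; the reference expansion mirrors the `repeat_interleave` fallback added to `flag_attn.testing.flash_attention` above):

```python
import torch
import flag_attn  # assumed export name for the patched op: flag_attn.flash_attention

B, Hq, Hk, M, N, D = 2, 8, 2, 1024, 1024, 64   # Hq must be a multiple of Hk
q = torch.randn(B, Hq, M, D, dtype=torch.float16, device="cuda")
k = torch.randn(B, Hk, N, D, dtype=torch.float16, device="cuda")
v = torch.randn(B, Hk, N, D, dtype=torch.float16, device="cuda")

# Grouped-query attention: each of the Hk k/v heads is shared by Hq // Hk query heads.
o = flag_attn.flash_attention(q, k, v, causal=True)

# Expanding k/v to Hq heads with repeat_interleave should give the same result;
# this is how the testing reference in this patch models MQA/GQA.
num_groups = Hq // Hk
k_rep = torch.repeat_interleave(k, repeats=num_groups, dim=1)
v_rep = torch.repeat_interleave(v, repeats=num_groups, dim=1)
o_rep = flag_attn.flash_attention(q, k_rep, v_rep, causal=True)
torch.testing.assert_close(o, o_rep, rtol=2e-3, atol=2e-3)
```

MQA is simply the `Hk == 1` special case of the same call.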
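The two `reshape(...).sum(2)` lines at the end of `backward` encode the gradient of head sharing: each of the `Hk` k/v heads is read by `num_groups = Hq // Hk` consecutive query heads (the kernels map query head `h` to k/v head `h // num_groups`), so the per-query-head gradients are folded back by summing within each group. A standalone sketch with made-up sizes:

```python
import torch

B, Hq, Hk, N, D = 2, 8, 2, 16, 8
num_groups = Hq // Hk

# Gradients are first accumulated per query head, as if k/v had Hq heads ...
dk_per_query_head = torch.randn(B, Hq, N, D)

# ... then reduced onto the Hk shared heads: query head h = hk * num_groups + g
# contributes to k/v head hk, so summing over the group axis yields dk.
dk = dk_per_query_head.reshape(B, Hk, num_groups, N, D).sum(dim=2)
assert dk.shape == (B, Hk, N, D)
```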