Unify with new forward tests and set num_stages
groenenboomj committed Aug 12, 2024
1 parent 51d0d92 commit ae4633c
Showing 2 changed files with 525 additions and 360 deletions.
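The commit title matches two changes in the diff below: use_cuda_graph is commented out in the @triton.autotune decorator of attn_fwd, and num_stages is pinned to 1 in backward() and passed to the kernel launch there. As a minimal sketch of where num_stages normally enters a Triton kernel, assuming only the standard triton.autotune / triton.Config API (the kernel name and config values below are illustrative, not from this repository):

import triton
import triton.language as tl


@triton.autotune(
    configs=[
        # num_stages sets the software-pipelining depth tried for this config.
        triton.Config({'BLOCK_M': 64, 'BLOCK_N': 64}, num_stages=1, num_warps=4),
    ],
    key=['N_CTX'],
)
@triton.jit
def _toy_kernel(X, N_CTX, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):
    pass


# For a kernel launched without autotuned configs, num_stages can instead be
# passed as a launch-time keyword argument, which is what this commit adds to
# the backward launch:
#     kernel[grid](..., num_stages=1)
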
132 changes: 63 additions & 69 deletions python/perf-kernels/flash-attention.py
@@ -316,7 +316,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, k_ptrs, v_ptrs, bias_ptrs, stride_kn, stri
num_warps=4),
],
key=['IS_CAUSAL', 'dropout_p', 'MAX_SEQLENS_Q', 'MAX_SEQLENS_K', 'ACTUAL_BLOCK_DMODEL', 'VARLEN', 'HQ', 'HK'],
use_cuda_graph=True,
#use_cuda_graph=True,
)
@triton.jit
def attn_fwd(Q, K, V, bias, SM_SCALE: tl.constexpr, L, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz,
@@ -639,7 +639,7 @@ def _bwd_kernel_dk_dv(dk, dv, Q, k, v, sm_scale, alibi_slope, DO, M, D,
# Compute dP and dS.
dpT = tl.dot(v, tl.trans(do))
dsT = pT * (dpT - Di[None, :])
dsT = dsT.to(qT.dype)
dsT = dsT.to(qT.dtype)
dk += tl.dot(dsT, tl.trans(qT))
# Increment pointers.
curr_m += step_m
@@ -695,13 +695,12 @@ def _bwd_kernel_dq(dq, q, K, V, do, m, D, alibi_slope,
VT_block_ptr = tl.advance(VT_block_ptr, (0, step_n))
return dq


@triton.jit
def _attn_bwd(Q, K, V, sm_scale, alibi_slopes, DO, DQ, DK, DV, M, D,
# shared by Q/K/V/DO.
stride_z, stride_h, stride_tok, stride_d,
# H = 16, N_CTX = 1024
H, N_CTX, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
H, N_CTX, CAUSAL: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_M1: tl.constexpr, BLOCK_N1: tl.constexpr,
BLOCK_M2: tl.constexpr, BLOCK_N2: tl.constexpr, BLK_SLICE_FACTOR: tl.constexpr, USE_ALIBI: tl.constexpr):
LN2: tl.constexpr = 0.6931471824645996 # = ln(2)

@@ -943,6 +942,7 @@ def backward(ctx, do, _):
BLOCK = 64
else:
BLOCK = 128
num_stages = 1
q, k, v, o, M = ctx.saved_tensors
assert do.is_contiguous()
assert q.stride() == k.stride() == v.stride() == o.stride() == do.stride()
@@ -1007,6 +1007,7 @@ def backward(ctx, do, _):
BLOCK_N2=BLOCK_N2,
BLK_SLICE_FACTOR=BLK_SLICE_FACTOR,
USE_ALIBI=False if ctx.alibi_slopes is None else True,
num_stages = 1,
)

return dq, dk, dv, None, None
@@ -1260,100 +1261,93 @@ def test_op_varlen_mqa_fwd(Z, HQ, HK, N_CTX, D_HEAD, causal, dtype=torch.float16
#(1, 16, 8192, 63),
#(1, 16, 1022, 64),
])
@pytest.mark.parametrize('qseqlen_not_equal_kseqlen', [None])
@pytest.mark.parametrize('torch_sdpa_test', [False, True])
@pytest.mark.parametrize('causal', [False,True])
@pytest.mark.parametrize('causal', [False, True])
@pytest.mark.parametrize('use_alibi', [False, True])
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op_bwd(Z, H, N_CTX, D_HEAD, qseqlen_not_equal_kseqlen, causal, torch_sdpa_test, use_alibi,
dtype):
pytest.skip()
torch.manual_seed(20)
if qseqlen_not_equal_kseqlen is not None:
seqlen_q = qseqlen_not_equal_kseqlen
else:
seqlen_q = N_CTX
seqlen_k = N_CTX

if causal and ((N_CTX - 1) & N_CTX):
pytest.skip()
if causal and seqlen_q != seqlen_k:
pytest.skip()

sm_scale = D_HEAD**-0.5
input_metadata = MetaData(sm_scale=sm_scale)
input_metadata.max_seqlens_q = seqlen_q
input_metadata.max_seqlens_k = seqlen_k

dropout_p = 0
q = (torch.empty((Z, H, seqlen_q, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
k = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
v = (torch.empty((Z, H, seqlen_k, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.0, std=0.5).requires_grad_())
o = torch.empty_like(q)

@pytest.mark.parametrize('layout', ['bhsd'])
def test_op_bwd(Z, H, N_CTX, D_HEAD, causal, use_alibi,
layout, dtype):
torch.manual_seed(20)

N_CTX_Q = N_CTX_K = N_CTX
HQ = HK = H

q, k, v, input_metadata = input_helper(Z, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, layout)
dout = torch.randn_like(q)

if causal:
input_metadata.need_causal()

if use_alibi and not torch_sdpa_test:
if use_alibi:
# for n heads the set of slopes is the geometric sequence that starts 2^(-8/n)
alibi_slopes = torch.tensor([2**(-8 / H * i) for i in range(1, H + 1)], dtype=torch.float32,
alibi_slopes = torch.tensor([2**(-8 / HQ * i) for i in range(1, HQ + 1)], dtype=torch.float32,
device="cuda").repeat(Z, 1)
input_metadata.need_alibi(alibi_slopes, Z, H)
dout = torch.randn_like(q)
# reference implementation
if torch_sdpa_test:
ref_out, ref_softmax = torch.ops.aten._scaled_dot_product_attention_math(q, k, v, dropout_p=dropout_p,
is_causal=causal, scale=sm_scale,
dropout_mask=None)
ref_out.backward(dout.to(device=ref_out.device, dtype=ref_out.dtype))
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
input_metadata.need_alibi(alibi_slopes, Z, HQ)
else:
M = torch.tril(torch.ones((seqlen_q, seqlen_k), device="cuda"))
p = torch.matmul(q, k.transpose(2, 3)) * sm_scale
if use_alibi:
p += compute_alibi_tensor(alibi_slopes, N_CTX, N_CTX)
if causal:
p[:, :, M == 0] = float("-inf")
alibi_slopes = None
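# Worked example for the slope formula above (illustrative, not part of the diff):
# with HQ = 4, [2**(-8 / 4 * i) for i in range(1, 5)] evaluates to
# [0.25, 0.0625, 0.015625, 0.00390625], a geometric sequence whose first term
# and common ratio are both 2**(-8 / HQ).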

p = torch.softmax(p.float(), dim=-1).type(dtype=p.dtype)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None
o = torch.empty_like(q)

# # triton implementation
# triton implementation
tri_out, _ = attention(q, k, v, o, input_metadata)
tri_out.backward(dout)
tri_dv, v.grad = v.grad.clone(), None
tri_dk, k.grad = k.grad.clone(), None
tri_dq, q.grad = q.grad.clone(), None
# test
#print("reference")
#print(ref_dv)
#print("tri")
#print(tri_dv)

# Transpose here if layout is bshd so we have same reference code for all layouts
if layout == 'bshd':
q = q.transpose(1, 2).clone()
k = k.transpose(1, 2).clone()
v = v.transpose(1, 2).clone()
# Replicate K and V if using MQA/GQA
if HQ != HK:
k = k.view(k.shape[0], k.shape[1], -1, k.shape[2],
k.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(k.shape[0], -1, k.shape[2], k.shape[3])
v = v.view(v.shape[0], v.shape[1], -1, v.shape[2],
v.shape[3]).expand(-1, -1, HQ // HK, -1, -1).reshape(v.shape[0], -1, v.shape[2], v.shape[3])
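# Shape walk-through of the replication above (illustrative): with k of shape
# (Z, HK, N_CTX_K, D_HEAD), view gives (Z, HK, 1, N_CTX_K, D_HEAD), expand gives
# (Z, HK, HQ // HK, N_CTX_K, D_HEAD), and reshape gives (Z, HQ, N_CTX_K, D_HEAD),
# so every K/V head is repeated HQ // HK times to line up with the query heads.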

scores = torch.einsum('bhqd,bhkd->bhqk', q, k).float() * input_metadata.sm_scale
if causal:
mask = torch.tril(torch.ones(N_CTX_Q, N_CTX_K, device="cuda"), diagonal=N_CTX_K - N_CTX_Q)
scores[:, :, mask == 0] = float("-inf")
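# The diagonal offset aligns the causal mask to the bottom-right corner when
# N_CTX_Q != N_CTX_K. Illustrative example (not from the diff): N_CTX_Q = 2 and
# N_CTX_K = 4 give diagonal = 2, so torch.tril(torch.ones(2, 4), diagonal=2) is
# [[1, 1, 1, 0], [1, 1, 1, 1]] and the last query row attends to every key.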
if use_alibi:
scores += compute_alibi_tensor(alibi_slopes, N_CTX_Q, N_CTX_K)

p = torch.softmax(scores, dim=-1)
if causal:
# If N_CTX_Q > N_CTX_K, there is at least one row of all -infs going into
# the softmax. This produces a row of NaNs as -inf - -inf == NaN. So we fix
# this by converting the NaNs to 0s, which is what they should be out of the softmax.
nan_mask = torch.isnan(p)
p = torch.where(nan_mask == 1,0,p)
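# Quick check of the comment above (illustrative): torch.softmax on a row that
# is entirely float('-inf') returns all NaNs, since -inf - -inf is NaN inside
# the softmax; the torch.where above maps those NaNs back to zeros.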
ref_out = torch.einsum('bhqk,bhkd->bhqd', p.to(dtype), v)
# compare
torch.testing.assert_close(ref_out, tri_out, atol=1e-2, rtol=0)
if layout == 'bshd':
ref_out = ref_out.transpose(1, 2).clone()
ref_out.backward(dout)
ref_dv, v.grad = v.grad.clone(), None
ref_dk, k.grad = k.grad.clone(), None
ref_dq, q.grad = q.grad.clone(), None

torch.testing.assert_close(ref_out, tri_out, atol=2e-2, rtol=2e-2)

# The current block size for MI200 series is 64x64. This results in
# larger differences in float results due to rounding.

if dtype == torch.bfloat16:
ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
if dtype == torch.float32:
ATOL = 1e-3 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-3 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)
else:
ATOL = 1e-1 * max(1.0, (seqlen_q + D_HEAD) / 64.0)
ATOL = 1e-1 * max(1.0, (N_CTX_Q + D_HEAD) / 64.0)

RTOL = 0
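# Worked example of the tolerance scaling above (illustrative): for float16
# with N_CTX_Q = 1024 and D_HEAD = 64, ATOL = 1e-1 * max(1.0, 1088 / 64.0) = 1.7.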

torch.testing.assert_close(ref_dv, tri_dv, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(ref_dk, tri_dk, atol=ATOL, rtol=RTOL)
torch.testing.assert_close(ref_dq, tri_dq, atol=ATOL, rtol=RTOL)


def nonvarlen_benchmark_configs():
configs = [
(16, 16, 16, 1024, 1024),