From 19ed7288105f786e7d34754012e748de5eb18c5e Mon Sep 17 00:00:00 2001 From: Yi Qian Date: Mon, 10 Jun 2024 21:30:57 +0000 Subject: [PATCH 1/2] Add a script for tuning flash attention kernels --- .../generated_fa_kernel16_16_1024_128.py | 1000 +++++++++++++++++ python/perf-kernels/tune_fa.py | 301 +++++ .../amd/gemm/flash_attention_fwd_kernel.py | 98 ++ scripts/amd/gemm/tune_fa.py | 282 +++++ 4 files changed, 1681 insertions(+) create mode 100644 python/perf-kernels/generated_fa_kernel16_16_1024_128.py create mode 100644 python/perf-kernels/tune_fa.py create mode 100644 scripts/amd/gemm/flash_attention_fwd_kernel.py create mode 100644 scripts/amd/gemm/tune_fa.py diff --git a/python/perf-kernels/generated_fa_kernel16_16_1024_128.py b/python/perf-kernels/generated_fa_kernel16_16_1024_128.py new file mode 100644 index 000000000000..00c7aa66cb49 --- /dev/null +++ b/python/perf-kernels/generated_fa_kernel16_16_1024_128.py @@ -0,0 +1,1000 @@ +import pytest +import torch +import sys + +import triton +import triton.language as tl +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz + + +@triton.jit +def _attn_fwd_BLOCKM16_BLOCKN16_PreloadvTrue( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + 
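+        # Note on this step: m_ij is the running row-wise maximum of the scaled
+        # scores. Since q was pre-multiplied by qk_scale = sm_scale * log2(e) and
+        # exp(x) == exp2(x * log2(e)), the exp2 used below computes the same softmax
+        # numerator as exp(sm_scale * (q @ k)) would, and subtracting m_ij first
+        # keeps every exponent <= 0, i.e. the numerically stable online softmax.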
qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM16_BLOCKN16_PreloadvFalse( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, 
stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM16_BLOCKN32_PreloadvTrue( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM16_BLOCKN32_PreloadvFalse( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = 
tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM32_BLOCKN16_PreloadvTrue( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + 
block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM32_BLOCKN16_PreloadvFalse( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to 
multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM32_BLOCKN32_PreloadvTrue( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + 
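+        # alpha = exp2(m_i - m_ij) <= 1 is the online-softmax correction factor:
+        # when the running max grows from m_i to m_ij, the partial results acc and
+        # l_i were computed against the old max, so both are rescaled by alpha
+        # before the contribution of the current K/V tile is added.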
alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +@triton.jit +def _attn_fwd_BLOCKM32_BLOCKN32_PreloadvFalse( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + 
order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + + +name_to_torch_types = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'fp8': float8 +} + +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal=False, dtype='fp16'): + if dtype == 'fp8' and not TORCH_HAS_FP8E4: + sys.exit("fp8 is not available") + init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + waves_per_eu = 2 + num_warps = 4 + num_stages = 1 + slice_k_tile = 32 + kpack = 1 + + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM16_BLOCKN16_PreloadvTrue[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 16, + BLOCK_N = 16, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = True, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM16_BLOCKN16_PreloadvFalse[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 16, + BLOCK_N = 16, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = False, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM16_BLOCKN32_PreloadvTrue[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 16, + BLOCK_N = 32, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = True, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM16_BLOCKN32_PreloadvFalse[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 16, + 
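+            # Besides the usual Triton launch options (num_warps, num_stages), the
+            # extra kwargs waves_per_eu, slice_k_tile and kpack are not parameters of
+            # the @triton.jit kernel above; they appear to be passed through the
+            # launch as AMD/ROCm backend tuning hints.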
BLOCK_N = 32, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = False, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM32_BLOCKN16_PreloadvTrue[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 32, + BLOCK_N = 16, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = True, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM32_BLOCKN16_PreloadvFalse[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 32, + BLOCK_N = 16, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = False, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM32_BLOCKN32_PreloadvTrue[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 32, + BLOCK_N = 32, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = True, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + for i in range(100): + grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM32_BLOCKN32_PreloadvFalse[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = 32, + BLOCK_N = 32, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = False, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + + +def main(): + bench_flash_attention(16, 16, 1024, 128) + +if __name__ == '__main__': + sys.exit(main()) + diff --git a/python/perf-kernels/tune_fa.py b/python/perf-kernels/tune_fa.py new file mode 100644 index 000000000000..9467c4e4cbe4 --- /dev/null +++ b/python/perf-kernels/tune_fa.py @@ -0,0 +1,301 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. 
Goucher for simplified vector math + +""" + +import argparse +import pytest +import torch +import sys + +import triton +import triton.language as tl + +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +def generate_one_fa_kernel_from_config(Batch, H, N_Ctx, D_Head, block_m, block_n, pre_load_v): + attn_fwd_str = f""" +@triton.jit +def _attn_fwd_BLOCKM{block_m}_BLOCKN{block_n}_Preloadv{pre_load_v}( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, 
BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) + """ + return attn_fwd_str + +def generate_wrapper(tuning_parms): + dri_str = """ +name_to_torch_types = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'fp8': float8 +} + +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal=False, dtype='fp16'): + if dtype == 'fp8' and not TORCH_HAS_FP8E4: + sys.exit("fp8 is not available") + init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + waves_per_eu = 2 + num_warps = 4 + num_stages = 1 + slice_k_tile = 32 + kpack = 1 + + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + """ + dri_str += '\n' + for tp in tuning_parms: + block_m = tp[0] + block_n = tp[1] + pre_load_v = tp[2] + dri_str += f""" + for i in range(100): + grid = ( triton.cdiv(q.shape[2], {block_m}), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM{block_m}_BLOCKN{block_n}_Preloadv{pre_load_v}[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = {block_m}, + BLOCK_N = {block_n}, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = {pre_load_v}, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + """ + + return dri_str + +def generate_main(Batch, H, N_Ctx, D_Head): + main_str = f""" +def main(): + bench_flash_attention({Batch}, {H}, {N_Ctx}, {D_Head}) + +if __name__ == '__main__': + sys.exit(main()) + """ + +def generate_fa_kernel(Batch, H, N_Ctx, D_Head): + # create the kernel file + file_name = f"{Batch}_{H}_{N_Ctx}_{D_Head}.py" + f_kernel = open("./generated_fa_kernel"+file_name, 'w') + + # import string + import_str = """import pytest +import torch +import sys + +import triton +import triton.language as tl +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz +""" + + f_kernel.write(import_str + '\n') + + # generate kernels with tuning parameters + tuning_parms = [] + block_m_range = [16, 32] + block_n_range = [16, 32] + pre_load_v_range = [True, False] + + for block_m in block_m_range: + for block_n in block_n_range: + for pre_load_v in pre_load_v_range: + tuning_parms.append((block_m, block_n, pre_load_v)) + kernel_str = generate_one_fa_kernel_from_config(Batch, H, N_Ctx, D_Head, block_m, block_n, pre_load_v) + f_kernel.write(kernel_str + "\n") + + # generate the driver + dri_str = generate_wrapper(tuning_parms) + f_kernel.write(dri_str + "\n") + + main_str = f""" 
+def main(): + bench_flash_attention({Batch}, {H}, {N_Ctx}, {D_Head}) + +if __name__ == '__main__': + sys.exit(main()) + """ + f_kernel.write(main_str+'\n') + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a flash attention kernel", + allow_abbrev=False, + ) + + parser.add_argument("-b", type=int, default=16, help='batch') + parser.add_argument("-H", type=int, default=16) + parser.add_argument("-n_ctx", type=int, default=1024) + parser.add_argument("-d_head", type=int, default=128) + parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", type=int, default=1, help="number of generated files") + parser.add_argument("--iters", type=int, default=1000, help="number of iterations") + parser.add_argument("--datatype", type=str, default='fp16', help="element type") + parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + + args = parser.parse_args() + return args + +def main(): + args = parse_args() + keepTmp = args.keep + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + + # Get element type + dtype = args.datatype + + mnks = [] + # TODO: make it more robust to get user input + batch = args.b + #h = args.h + h = 16 + n_ctx = args.n_ctx + d_head = args.d_head + generate_fa_kernel(batch, h, n_ctx, d_head) + +if __name__ == '__main__': + sys.exit(main()) diff --git a/scripts/amd/gemm/flash_attention_fwd_kernel.py b/scripts/amd/gemm/flash_attention_fwd_kernel.py new file mode 100644 index 000000000000..0e4b03fcf644 --- /dev/null +++ b/scripts/amd/gemm/flash_attention_fwd_kernel.py @@ -0,0 +1,98 @@ +import triton +import triton.language as tl + +@triton.jit +def _attn_fwd_kernel( + Q, K, V, sm_scale, M, Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vn, stride_vk, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, + N_CTX, + BLOCK_DMODEL: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, + pre_load_v: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qkv_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qkv_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(0, 1) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # scale sm_scale by log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = 
sm_scale * 1.44269504 + q = tl.load(Q_block_ptr) + # it's even better to multiply the qk_scale and convert to f16 + # than doing it inside the loop + # So conversion is quite cheap + q = (q * qk_scale).to(q.dtype) + lo, hi = 0, N_CTX + # loop over k, v and update accumulator + for start_n in range(lo, hi, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load(K_block_ptr) + if pre_load_v: + v = tl.load(V_block_ptr) + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + #qk = (qk * qk_scale) + m_ij = tl.maximum(m_i, tl.max(qk, 1)) + qk = qk - m_ij[:, None] + p = tl.math.exp2(qk) + # -- update output accumulator -- + alpha = tl.math.exp2(m_i - m_ij) + acc = acc * alpha[:, None] + if not pre_load_v: + v = tl.load(V_block_ptr) + acc += tl.dot(p.to(v.dtype), v) + # -- update m_i and l_i + l_ij = tl.sum(p, 1) + l_i = l_i * alpha + l_ij + # update m_i and l_i + m_i = m_ij + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + acc = acc / l_i[:, None] + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qkv_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(Out.type.element_ty)) diff --git a/scripts/amd/gemm/tune_fa.py b/scripts/amd/gemm/tune_fa.py new file mode 100644 index 000000000000..4b622fd3cade --- /dev/null +++ b/scripts/amd/gemm/tune_fa.py @@ -0,0 +1,282 @@ +""" +Fused Attention +=============== + +This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) + +Extra Credits: +- Original flash attention paper (https://arxiv.org/abs/2205.14135) +- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) +- Adam P. 
Goucher for simplified vector math + +""" + +import argparse +import pytest +import torch +import sys +import yaml +import csv +import re +import pandas as pd +import os + +import triton +import triton.language as tl + +from datetime import datetime +import subprocess + +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz + +def run_bash_command_wrapper(commandstring, capture=True): + try: + run_bash_command(commandstring, capture) + except subprocess.CalledProcessError as e: + if not capture: + print(f"running {commandstring} one more time") + run_bash_command(commandstring, capture) + +def run_bash_command(commandstring, capture=True): + if capture: + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash', stdout=subprocess.PIPE) + return proc.stdout.splitlines() + proc = subprocess.run(commandstring, shell=True, check=True, executable='/bin/bash') + return None + +def format_output(unformatted): + if unformatted < 0.0001: + formatted = "{:.3e}".format(unformatted) + elif unformatted > 1000: + formatted = "{:.1f}".format(unformatted) + else: + formatted = "{:.2f}".format(unformatted) + return formatted + +@triton.jit +def max_fn(x, y): + return tl.math.max(x, y) + +def generate_wrapper(tuning_parms): + dri_str = """ +name_to_torch_types = { + 'fp16': torch.float16, + 'bf16': torch.bfloat16, + 'fp8': float8 +} + +def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal=False, dtype='fp16'): + if dtype == 'fp8' and not TORCH_HAS_FP8E4: + sys.exit("fp8 is not available") + init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 + q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) + v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) + sm_scale = 1.3 + # q,k casting for partial fp8 + q = q.to(name_to_torch_types[dtype]) + k = k.to(name_to_torch_types[dtype]) + + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q, dtype=v.dtype) + waves_per_eu = 2 + num_warps = 4 + num_stages = 1 + slice_k_tile = 32 + kpack = 1 + + M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + """ + dri_str += '\n' + for tp in tuning_parms: + block_m = tp[0] + block_n = tp[1] + pre_load_v = tp[2] + dri_str += f""" + for i in range(100): + grid = ( triton.cdiv(q.shape[2], {block_m}), q.shape[0] * q.shape[1], 1) + _attn_fwd_BLOCKM_{block_m}_BLOCKN_{block_n}_Preloadv_{pre_load_v}[grid]( + q, k, v, sm_scale, M, o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], + N_CTX=q.shape[2], + BLOCK_DMODEL=Lk, + BLOCK_M = {block_m}, + BLOCK_N = {block_n}, + waves_per_eu = waves_per_eu, + num_warps = num_warps, + num_stages = num_stages, + pre_load_v = {pre_load_v}, + slice_k_tile = slice_k_tile, + kpack = kpack, + ) + """ + + return dri_str + +def generate_main(Batch, H, N_Ctx, D_Head): + main_str = f""" +def 
main(): + bench_flash_attention({Batch}, {H}, {N_Ctx}, {D_Head}) + +if __name__ == '__main__': + sys.exit(main()) + """ + +def generate_fa_kernel(Batch, H, N_Ctx, D_Head): + # create the kernel file + file_name = f"{Batch}_{H}_{N_Ctx}_{D_Head}.py" + f_kernel = open("./generated_fa_kernel_"+file_name, 'w') + + # import string + import_str = """import pytest +import torch +import sys + +import triton +import triton.language as tl +# Pick the fp8 data type + +# AMD E5M2B16 +# float8:tl.constexpr = torch.float8_e5m2fnuz + +# AMD E4M3B8 +# Note: When picking this f8 data type, scaling is required when using f8 +# for the second gemm +TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') +float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz +""" + + f_kernel.write(import_str + '\n') + + # generate kernels with tuning parameters + tuning_parms = [] + block_m_range = [16, 32] + block_n_range = [16, 32] + pre_load_v_range = [True, False] + with open(os.path.dirname(os.path.abspath(__file__))+"/flash_attention_fwd_kernel.py") as file: + fa_kernel_code = file.read() + + for block_m in block_m_range: + for block_n in block_n_range: + for pre_load_v in pre_load_v_range: + tuning_parms.append((block_m, block_n, pre_load_v)) + fa_kernel_str = fa_kernel_code.replace("attn_fwd_kernel", f"attn_fwd_BLOCKM_{block_m}_BLOCKN_{block_n}_Preloadv_{pre_load_v}") + fa_kernel_str = fa_kernel_str.replace("import triton.language as tl", "") + fa_kernel_str = fa_kernel_str.replace("import triton", "") + f_kernel.write(fa_kernel_str + "\n") + + # generate the driver + dri_str = generate_wrapper(tuning_parms) + f_kernel.write(dri_str + "\n") + + main_str = f""" +def main(): + bench_flash_attention({Batch}, {H}, {N_Ctx}, {D_Head}) + +if __name__ == '__main__': + sys.exit(main()) + """ + f_kernel.write(main_str+'\n') + f_kernel.close() + +def tune_fa_config(Batch, H, N_Ctx, D_Head, num_threads, verbose): + # create the kernel file + generate_fa_kernel(Batch, H, N_Ctx, D_Head) + run_bash_command("rm -rf ~/.triton/cache") + start_time = datetime.now() + + file_name = f"generated_fa_kernel_{Batch}_{H}_{N_Ctx}_{D_Head}.py" + run_bash_command(f"python {file_name} -n {num_threads}", capture=(verbose < 2)) + compile_end = datetime.now() + compile_time = compile_end - start_time + if verbose: + print(f"compile time: {compile_time}", flush=True) + run_bash_command_wrapper(f"rocprof --stats python {file_name}", capture=(verbose < 2)) + df_prof = pd.read_csv(f"results.stats.csv") + filtered_df = df_prof[df_prof['Name'].str.startswith('_attn_fwd_BLOCK')] + # Find the row with the minimal 'AverageNs' + min_row = filtered_df.loc[filtered_df['AverageNs'].idxmin()] + + splitted_config = min_row["Name"].split('_') + best_config = {'Batch':Batch, 'H':H, 'N_Ctx':N_Ctx, 'D_Head':D_Head} + best_config.update({'Block_M':splitted_config[4], 'Block_N':splitted_config[6], 'Preloadv':splitted_config[8]}) + return min_row['AverageNs'], best_config + +def parse_args(): + parser = argparse.ArgumentParser( + prog="tune a flash attention kernel", + allow_abbrev=False, + ) + + parser.add_argument("-batch", type=int, default=16) + parser.add_argument("-H", type=int, default=16) + parser.add_argument("-n_ctx", type=int, default=1024) + parser.add_argument("-d_head", type=int, default=128) + parser.add_argument("--o", type=str, default='tuning_fa.yaml', help='yaml file to store tuning results') + parser.add_argument("--fa_config_file", type=str, default="", help='yaml file to indicate flash attention configs') + 
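+    # A hypothetical example of the YAML accepted via --fa_config_file, matching
+    # the keys read in main() below (a list with one entry per problem size):
+    #   - {Batch: 16, H: 16, N_Ctx: 1024, D_Head: 128}
+    #   - {Batch: 8, H: 32, N_Ctx: 2048, D_Head: 64}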
parser.add_argument("--keep", action='store_true', default=False, help='keep generated files') + parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages") + parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing") + parser.add_argument("--jobs", type=int, default=1, help="number of generated files") + parser.add_argument("--iters", type=int, default=1000, help="number of iterations") + parser.add_argument("--datatype", type=str, default='fp16', help="element type") + parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel") + + args = parser.parse_args() + return args + +def main(): + args = parse_args() + output = args.o + keepTmp = args.keep + jobs = args.jobs + iters = args.iters + skipWarmup = args.no_warmup + + # Get element type + dtype = args.datatype + + fa_configs = [] + fa_config_file = args.fa_config_file + if fa_config_file == "" or not os.path.isfile(fa_config_file): + batch = args.batch + h = args.H + n_ctx = args.n_ctx + d_head = args.d_head + fa_configs = [(batch, h, n_ctx, d_head)] + else: + with open(fa_config_file) as file: + inputs = yaml.safe_load(file) + for item in inputs: + fa_configs.append((item['Batch'], item['H'], item['N_Ctx'], item['D_Head'])) + + f_results = open(output, 'w') + for config in fa_configs: + batch = config[0] + h = config[1] + n_ctx = config[2] + d_head = config[3] + minTime, bestConfig = tune_fa_config(batch, h, n_ctx, d_head, args.num_threads, args.verbose) + minTime = format_output(minTime) + print('best_config: ',str(bestConfig)) + f_results.write('- ' + str(bestConfig) + ' ') + f_results.write(f'# time(us): {minTime}\n') + f_results.close() + +if __name__ == '__main__': + sys.exit(main()) From d6056a0563646e175d7e0777559e4b768de24e19 Mon Sep 17 00:00:00 2001 From: Yi Qian Date: Wed, 26 Jun 2024 05:58:38 +0000 Subject: [PATCH 2/2] Add --benchmark --- .../generated_fa_kernel16_16_1024_128.py | 1000 ----------------- scripts/amd/gemm/tune_fa.py | 111 +- 2 files changed, 68 insertions(+), 1043 deletions(-) delete mode 100644 python/perf-kernels/generated_fa_kernel16_16_1024_128.py diff --git a/python/perf-kernels/generated_fa_kernel16_16_1024_128.py b/python/perf-kernels/generated_fa_kernel16_16_1024_128.py deleted file mode 100644 index 00c7aa66cb49..000000000000 --- a/python/perf-kernels/generated_fa_kernel16_16_1024_128.py +++ /dev/null @@ -1,1000 +0,0 @@ -import pytest -import torch -import sys - -import triton -import triton.language as tl -# Pick the fp8 data type - -# AMD E5M2B16 -# float8:tl.constexpr = torch.float8_e5m2fnuz - -# AMD E4M3B8 -# Note: When picking this f8 data type, scaling is required when using f8 -# for the second gemm -TORCH_HAS_FP8E4 = hasattr(torch, 'float8_e4m3fnuz') -float8:tl.constexpr = None if not TORCH_HAS_FP8E4 else torch.float8_e4m3fnuz - - -@triton.jit -def _attn_fwd_BLOCKM16_BLOCKN16_PreloadvTrue( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, 
BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM16_BLOCKN16_PreloadvFalse( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, 
BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM16_BLOCKN32_PreloadvTrue( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * 
qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM16_BLOCKN32_PreloadvFalse( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += 
tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM32_BLOCKN16_PreloadvTrue( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def 
_attn_fwd_BLOCKM32_BLOCKN16_PreloadvFalse( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM32_BLOCKN32_PreloadvTrue( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - 
offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -@triton.jit -def _attn_fwd_BLOCKM32_BLOCKN32_PreloadvFalse( - Q, K, V, sm_scale, M, Out, - stride_qz, stride_qh, stride_qm, stride_qk, - stride_kz, stride_kh, stride_kn, stride_kk, - stride_vz, stride_vh, stride_vn, stride_vk, - stride_oz, stride_oh, stride_om, stride_on, - Z, H, - N_CTX, - BLOCK_DMODEL: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, - pre_load_v: tl.constexpr, -): - start_m = tl.program_id(0) - off_hz = tl.program_id(1) - qkv_offset = off_hz * stride_qh - Q_block_ptr = tl.make_block_ptr( - base=Q + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_qm, stride_qk), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - K_block_ptr = tl.make_block_ptr( - base=K + qkv_offset, - shape=(BLOCK_DMODEL, N_CTX), - strides=(stride_kk, stride_kn), - offsets=(0, 0), - block_shape=(BLOCK_DMODEL, BLOCK_N), - order=(0, 1) - ) - V_block_ptr = tl.make_block_ptr( - base=V + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_vk, stride_vn), - offsets=(0, 0), - block_shape=(BLOCK_N, BLOCK_DMODEL), - order=(0, 1) - ) - # initialize offsets - offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) - offs_n = tl.arange(0, BLOCK_N) - # 
initialize pointer to m and l - m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") - l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0 - acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) - # scale sm_scale by log_2(e) and use - # 2^x instead of exp in the loop because CSE and LICM - # don't work as expected with `exp` in the loop - qk_scale = sm_scale * 1.44269504 - q = tl.load(Q_block_ptr) - # it's even better to multiply the qk_scale and convert to f16 - # than doing it inside the loop - # So conversion is quite cheap - q = (q * qk_scale).to(q.dtype) - lo, hi = 0, N_CTX - # loop over k, v and update accumulator - for start_n in range(lo, hi, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - # -- compute qk ---- - k = tl.load(K_block_ptr) - if pre_load_v: - v = tl.load(V_block_ptr) - qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) - #qk = (qk * qk_scale) - m_ij = tl.maximum(m_i, tl.max(qk, 1)) - qk = qk - m_ij[:, None] - p = tl.math.exp2(qk) - # -- update output accumulator -- - alpha = tl.math.exp2(m_i - m_ij) - acc = acc * alpha[:, None] - if not pre_load_v: - v = tl.load(V_block_ptr) - acc += tl.dot(p.to(v.dtype), v) - # -- update m_i and l_i - l_ij = tl.sum(p, 1) - l_i = l_i * alpha + l_ij - # update m_i and l_i - m_i = m_ij - V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) - acc = acc / l_i[:, None] - # write back O - O_block_ptr = tl.make_block_ptr( - base=Out + qkv_offset, - shape=(N_CTX, BLOCK_DMODEL), - strides=(stride_om, stride_on), - offsets=(start_m * BLOCK_M, 0), - block_shape=(BLOCK_M, BLOCK_DMODEL), - order=(1, 0) - ) - tl.store(O_block_ptr, acc.to(Out.type.element_ty)) - - -name_to_torch_types = { - 'fp16': torch.float16, - 'bf16': torch.bfloat16, - 'fp8': float8 -} - -def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal=False, dtype='fp16'): - if dtype == 'fp8' and not TORCH_HAS_FP8E4: - sys.exit("fp8 is not available") - init_dtype = torch.float16 if dtype != 'bf16' else torch.bfloat16 - q = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) - k = torch.randn((BATCH, H, N_CTX, D_HEAD), dtype=init_dtype, device="cuda", requires_grad=True) - v = torch.randn((BATCH, H, D_HEAD, N_CTX), dtype=init_dtype, device="cuda", requires_grad=True) - sm_scale = 1.3 - # q,k casting for partial fp8 - q = q.to(name_to_torch_types[dtype]) - k = k.to(name_to_torch_types[dtype]) - - Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-2] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128} - o = torch.empty_like(q, dtype=v.dtype) - waves_per_eu = 2 - num_warps = 4 - num_stages = 1 - slice_k_tile = 32 - kpack = 1 - - M = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) - - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM16_BLOCKN16_PreloadvTrue[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 16, - BLOCK_N = 16, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = True, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) - 
_attn_fwd_BLOCKM16_BLOCKN16_PreloadvFalse[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 16, - BLOCK_N = 16, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = False, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM16_BLOCKN32_PreloadvTrue[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 16, - BLOCK_N = 32, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = True, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 16), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM16_BLOCKN32_PreloadvFalse[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 16, - BLOCK_N = 32, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = False, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM32_BLOCKN16_PreloadvTrue[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 32, - BLOCK_N = 16, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = True, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM32_BLOCKN16_PreloadvFalse[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 32, - BLOCK_N = 16, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = False, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM32_BLOCKN32_PreloadvTrue[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - 
BLOCK_DMODEL=Lk, - BLOCK_M = 32, - BLOCK_N = 32, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = True, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - for i in range(100): - grid = ( triton.cdiv(q.shape[2], 32), q.shape[0] * q.shape[1], 1) - _attn_fwd_BLOCKM32_BLOCKN32_PreloadvFalse[grid]( - q, k, v, sm_scale, M, o, - q.stride(0), q.stride(1), q.stride(2), q.stride(3), - k.stride(0), k.stride(1), k.stride(2), k.stride(3), - v.stride(0), v.stride(1), v.stride(2), v.stride(3), - o.stride(0), o.stride(1), o.stride(2), o.stride(3), - q.shape[0], q.shape[1], - N_CTX=q.shape[2], - BLOCK_DMODEL=Lk, - BLOCK_M = 32, - BLOCK_N = 32, - waves_per_eu = waves_per_eu, - num_warps = num_warps, - num_stages = num_stages, - pre_load_v = False, - slice_k_tile = slice_k_tile, - kpack = kpack, - ) - - -def main(): - bench_flash_attention(16, 16, 1024, 128) - -if __name__ == '__main__': - sys.exit(main()) - diff --git a/scripts/amd/gemm/tune_fa.py b/scripts/amd/gemm/tune_fa.py index 4b622fd3cade..93cfc3991aee 100644 --- a/scripts/amd/gemm/tune_fa.py +++ b/scripts/amd/gemm/tune_fa.py @@ -1,14 +1,8 @@ """ -Fused Attention -=============== - -This is a Triton implementation of the Flash Attention v2 algorithm from Tri Dao (https://tridao.me/publications/flash2/flash2.pdf) - -Extra Credits: -- Original flash attention paper (https://arxiv.org/abs/2205.14135) -- Rabe and Staats (https://arxiv.org/pdf/2112.05682v2.pdf) -- Adam P. Goucher for simplified vector math - +Usage: + python3 ./tune_fa.py -batch 16 -H 16 -n_ctx 1024 -d_head 128 + python3 ./tune_fa.py --fa_config_file flash_att_configs.yaml --o tuning_fa.yaml + python3 ./tune_fa.py --fa_config_file tuning_fa.yaml --benchmark --o bench_results.cdv """ import argparse @@ -62,9 +56,33 @@ def format_output(unformatted): formatted = "{:.2f}".format(unformatted) return formatted -@triton.jit -def max_fn(x, y): - return tl.math.max(x, y) +def get_full_tuning_space(): + configs = [] + block_m_range = [32, 64, 128, 256] + block_n_range = [64, 128] + waves_per_eu_range = [0, 2, 3] + num_warps_range = [1, 2, 4, 8] + num_stages = [1] + pre_load_v_range = [True, False] + kpack_range = [1, 2] + matrix_instr_nonkdim_range = [16, 32] + + for block_m in block_m_range: + for block_n in block_n_range: + for pre_load_v in pre_load_v_range: + configs.append({'Block_M': block_m, 'Block_N': block_n, 'Pre_load_v': pre_load_v}) + return configs + +def process_item(item): + batch = item['Batch'] + h = item['H'] + n_ctx = item['N_Ctx'] + d_head = item['D_Head'] + del item['Batch'] + del item['H'] + del item['N_Ctx'] + del item['D_Head'] + return batch, h, n_ctx, d_head, item def generate_wrapper(tuning_parms): dri_str = """ @@ -100,9 +118,9 @@ def bench_flash_attention(BATCH, H, N_CTX, D_HEAD, causal=False, dtype='fp16'): """ dri_str += '\n' for tp in tuning_parms: - block_m = tp[0] - block_n = tp[1] - pre_load_v = tp[2] + block_m = tp['Block_M'] + block_n = tp['Block_N'] + pre_load_v = bool(tp['Pre_load_v']) dri_str += f""" for i in range(100): grid = ( triton.cdiv(q.shape[2], {block_m}), q.shape[0] * q.shape[1], 1) @@ -137,7 +155,7 @@ def main(): sys.exit(main()) """ -def generate_fa_kernel(Batch, H, N_Ctx, D_Head): +def generate_fa_kernel(Batch, H, N_Ctx, D_Head, tuning_parms): # create the kernel file file_name = f"{Batch}_{H}_{N_Ctx}_{D_Head}.py" f_kernel = open("./generated_fa_kernel_"+file_name, 'w') @@ -164,21 +182,18 @@ def generate_fa_kernel(Batch, H, N_Ctx, D_Head): f_kernel.write(import_str + '\n') # generate kernels 
with tuning parameters
-    tuning_parms = []
-    block_m_range = [16, 32]
-    block_n_range = [16, 32]
-    pre_load_v_range = [True, False]
+
     with open(os.path.dirname(os.path.abspath(__file__))+"/flash_attention_fwd_kernel.py") as file:
         fa_kernel_code = file.read()
 
-    for block_m in block_m_range:
-        for block_n in block_n_range:
-            for pre_load_v in pre_load_v_range:
-                tuning_parms.append((block_m, block_n, pre_load_v))
-                fa_kernel_str = fa_kernel_code.replace("attn_fwd_kernel", f"attn_fwd_BLOCKM_{block_m}_BLOCKN_{block_n}_Preloadv_{pre_load_v}")
-                fa_kernel_str = fa_kernel_str.replace("import triton.language as tl", "")
-                fa_kernel_str = fa_kernel_str.replace("import triton", "")
-                f_kernel.write(fa_kernel_str + "\n")
+    for config in tuning_parms:
+        block_m = config['Block_M']
+        block_n = config['Block_N']
+        pre_load_v = bool(config['Pre_load_v'])
+        fa_kernel_str = fa_kernel_code.replace("attn_fwd_kernel", f"attn_fwd_BLOCKM_{block_m}_BLOCKN_{block_n}_Preloadv_{pre_load_v}")
+        fa_kernel_str = fa_kernel_str.replace("import triton.language as tl", "")
+        fa_kernel_str = fa_kernel_str.replace("import triton", "")
+        f_kernel.write(fa_kernel_str + "\n")
 
     # generate the driver
     dri_str = generate_wrapper(tuning_parms)
@@ -194,9 +209,9 @@ def main():
     f_kernel.write(main_str+'\n')
     f_kernel.close()
 
-def tune_fa_config(Batch, H, N_Ctx, D_Head, num_threads, verbose):
+def tune_fa_config(Batch, H, N_Ctx, D_Head, tuning_parms, num_threads, verbose):
     # create the kernel file
-    generate_fa_kernel(Batch, H, N_Ctx, D_Head)
+    generate_fa_kernel(Batch, H, N_Ctx, D_Head, tuning_parms)
 
     run_bash_command("rm -rf ~/.triton/cache")
     start_time = datetime.now()
@@ -214,7 +229,7 @@ def tune_fa_config(Batch, H, N_Ctx, D_Head, num_threads, verbose):
 
     splitted_config = min_row["Name"].split('_')
     best_config = {'Batch':Batch, 'H':H, 'N_Ctx':N_Ctx, 'D_Head':D_Head}
-    best_config.update({'Block_M':splitted_config[4], 'Block_N':splitted_config[6], 'Preloadv':splitted_config[8]})
+    best_config.update({'Block_M':int(splitted_config[4]), 'Block_N':int(splitted_config[6]), 'Pre_load_v':splitted_config[8]})
     return min_row['AverageNs'], best_config
 
 def parse_args():
@@ -229,6 +244,7 @@
     parser.add_argument("-d_head", type=int, default=128)
     parser.add_argument("--o", type=str, default='tuning_fa.yaml', help='yaml file to store tuning results')
     parser.add_argument("--fa_config_file", type=str, default="", help='yaml file to indicate flash attention configs')
+    parser.add_argument("--benchmark", action='store_true', default=False, help="Benchmark the given config")
     parser.add_argument("--keep", action='store_true', default=False, help='keep generated files')
     parser.add_argument("--verbose", action='store_true', default=False, help="enables time_breakdown and additional logging messages")
     parser.add_argument("--num_threads", type=int, default=16, help="number of threads to use for kernel compilation and post processing")
@@ -238,6 +254,8 @@ def parse_args():
     parser.add_argument("--no_warmup", action='store_true', default=False, help="Do not call the warmup kernel")
 
     args = parser.parse_args()
+    if args.benchmark:
+        args.o = 'flash_attention_benchmark.csv'
     return args
 
 def main():
@@ -247,6 +265,7 @@ def main():
     jobs = args.jobs
     iters = args.iters
     skipWarmup = args.no_warmup
+    run_bench = args.benchmark
 
     # Get element type
     dtype = args.datatype
@@ -258,24 +277,30 @@ def main():
         h = args.H
         n_ctx = args.n_ctx
         d_head = args.d_head
-        fa_configs = [(batch, h, n_ctx, d_head)]
+        fa_configs = [(batch, h, n_ctx, d_head, None)]
     else:
         with open(fa_config_file) as file:
             inputs = yaml.safe_load(file)
             for item in inputs:
-                fa_configs.append((item['Batch'], item['H'], item['N_Ctx'], item['D_Head']))
+                batch, h, n_ctx, d_head, item = process_item(item)
+                fa_configs.append((batch, h, n_ctx, d_head, item))
+
+    full_space = get_full_tuning_space()
 
     f_results = open(output, 'w')
-    for config in fa_configs:
-        batch = config[0]
-        h = config[1]
-        n_ctx = config[2]
-        d_head = config[3]
-        minTime, bestConfig = tune_fa_config(batch, h, n_ctx, d_head, args.num_threads, args.verbose)
+    if run_bench:
+        f_results.write("Batch,H,N_Ctx,D_Head,us\n")
+    for (batch, h, n_ctx, d_head, config) in fa_configs:
+        tuning_parms = [config] if run_bench else full_space
+        minTime, bestConfig = tune_fa_config(batch, h, n_ctx, d_head, tuning_parms, args.num_threads, args.verbose)
         minTime = format_output(minTime)
-        print('best_config: ',str(bestConfig))
-        f_results.write('- ' + str(bestConfig) + ' ')
-        f_results.write(f'# time(us): {minTime}\n')
+        if run_bench:
+            print(f'{batch},{h},{n_ctx},{d_head},{minTime}')
+            f_results.write(f'{batch},{h},{n_ctx},{d_head},{minTime}\n')
+        else:
+            print('best_config: ',str(bestConfig))
+            f_results.write('- ' + str(bestConfig) + ' ')
+            f_results.write(f'# time(us): {minTime}\n')
    f_results.close()
 
 if __name__ == '__main__':
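For reference, tune_fa.py reads its --fa_config_file as a YAML list of flash-attention shapes keyed by Batch, H, N_Ctx and D_Head; with --benchmark, each entry can additionally carry the Block_M/Block_N/Pre_load_v fields that the tuner itself writes out, and process_item() splits the shape keys from the remaining tuning config. Below is a minimal sketch of writing and reading such a file with PyYAML; the file name matches the usage example above, while the concrete shape values are illustrative assumptions, not values taken from the patch.

# Sketch only: the keys mirror what process_item() and the tuning loop in
# tune_fa.py read; the shapes listed here are made-up example values.
import yaml

configs = [
    {'Batch': 16, 'H': 16, 'N_Ctx': 1024, 'D_Head': 128},
    {'Batch': 8, 'H': 32, 'N_Ctx': 2048, 'D_Head': 64},
]

# Write an example --fa_config_file.
with open('flash_att_configs.yaml', 'w') as f:
    yaml.safe_dump(configs, f, sort_keys=False)

# Read it back the same way tune_fa.py does.
with open('flash_att_configs.yaml') as f:
    for item in yaml.safe_load(f):
        print(item['Batch'], item['H'], item['N_Ctx'], item['D_Head'])

In tuning mode the script appends one "- {...} # time(us): ..." YAML entry per shape to the --o file; with --benchmark it writes "Batch,H,N_Ctx,D_Head,us" CSV rows instead, so the tuning output can be fed back in as the benchmark input.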