Fix Sm80 tile_count_semaphore, adjust test tolerance
tridao committed Jan 10, 2025
1 parent 07bddf9 commit 2ac6c98
Showing 2 changed files with 25 additions and 15 deletions.
11 changes: 9 additions & 2 deletions hopper/flash_api.cpp
@@ -498,7 +498,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

 auto q_type = q.scalar_type();
 TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
-            "FlashAttention only support fp16, bf16, and fp8_e4m3 data type");
+            "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
+if (dprops->major < 9) {
+    TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
+                "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
+}
 TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
 TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

@@ -788,7 +792,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

 at::Tensor tile_count_semaphore;
 // We don't use the persistent scheduler if Split and not Varlen
-if (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) {
+bool const persistent_scheduler = params.arch >= 90
+    ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
+    : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
+if (persistent_scheduler) {
     tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
     params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();
 } else {
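For readers tracing the fix: the tile_count_semaphore is now allocated only when the persistent scheduler will actually run, and that decision branches on architecture. A minimal Python sketch of the equivalent logic, with every parameter name assumed from the C++ above (this is not the repo's API):

def uses_persistent_scheduler(arch: int, is_causal: bool, is_local: bool,
                              is_varlen: bool, num_splits: int) -> bool:
    if arch >= 90:
        # Hopper: persistent for causal/local single-split, or any varlen run
        return ((is_causal or is_local) and num_splits == 1) or is_varlen
    # Sm80 path fixed by this commit: causal non-varlen, or varlen with splits
    return (is_causal and not is_varlen) or (is_varlen and num_splits > 1)
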
29 changes: 16 additions & 13 deletions hopper/test_flash_attn.py
@@ -27,7 +27,7 @@
 DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
 DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
 DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
-DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9
 DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE"
 DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE"
 DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE"
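
The extra clause on DISABLE_FP8 skips the FP8 tests on pre-Hopper GPUs, where the fp8_e4m3 kernels are unavailable, matching the dprops->major < 9 check in flash_api.cpp. A standalone illustration of the capability probe, assuming a CUDA device is visible:

import torch

# get_device_capability returns (major, minor), e.g. (9, 0) for H100
major, _minor = torch.cuda.get_device_capability("cuda")
print("fp8 kernels available:", major >= 9)
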
@@ -165,7 +165,9 @@ def test_flash_attn_output(
 # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
 # lse_ref = torch.logsumexp(qk, dim=-1)

-abs_tol = 1e-4 if softcap == 0.0 else 5e-4
+# Numerical error if we just do any arithmetic on out_ref
+fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
+rtol = 2 if softcap == 0.0 else 3

 print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
 print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
@@ -192,8 +194,7 @@ def test_flash_attn_output(

 # Check that FlashAttention's numerical error is at most twice the numerical error
 # of a Pytorch implementation.
-multiple = 2 if dtype != torch.float8_e4m3fn else 3
-assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item() + abs_tol
+assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

 if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
     g = torch.randn_like(out)
@@ -248,10 +249,12 @@ def test_flash_attn_output(


 if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
-    multiple = 2
-    assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item() + abs_tol
-    assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item() + abs_tol
-    assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item() + abs_tol
+    dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
+    dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
+    dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+    assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
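
Every assertion touched by this commit now follows one pattern: the kernel's max error against a high-precision reference must stay within rtol times the error of a plain PyTorch baseline, plus the noise-floor atol. A hypothetical helper (not part of the repo) capturing that check:

import torch

def assert_close_to_baseline(out: torch.Tensor, out_pt: torch.Tensor,
                             out_ref: torch.Tensor, rtol: float, atol: float = 0.0) -> None:
    # out: kernel output; out_pt: PyTorch baseline; out_ref: high-precision reference
    impl_err = (out - out_ref).abs().max().item()
    base_err = (out_pt - out_ref).abs().max().item()
    assert impl_err <= rtol * base_err + atol, (impl_err, base_err)
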
@@ -411,7 +414,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

 # Numerical error if we just do any arithmetic on out_ref
 fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
-rel_tol = 2 if softcap == 0.0 else 3
+rtol = 2 if softcap == 0.0 else 3

 pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
 num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
@@ -442,7 +445,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

 # Check that FlashAttention's numerical error is at most 3x the numerical error
 # of a Pytorch implementation.
-assert (out - out_ref).abs().max().item() <= rel_tol * (out_pt - out_ref).abs().max().item() + fwd_atol
+assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol


 if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
@@ -515,11 +518,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

 if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
     dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-    assert (dq - dq_ref).abs().max().item() <= rel_tol * (dq_pt - dq_ref).abs().max().item() + dq_atol
+    assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
     dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-    assert (dk - dk_ref).abs().max().item() <= rel_tol * (dk_pt - dk_ref).abs().max().item() + dk_atol
+    assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
     dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-    assert (dv - dv_ref).abs().max().item() <= rel_tol * (dv_pt - dv_ref).abs().max().item() + dv_atol
+    assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


# @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
