From 2ac6c986bea5ffe46c606a7782aaf9e0401b599f Mon Sep 17 00:00:00 2001
From: Tri Dao
Date: Fri, 10 Jan 2025 23:44:18 +0700
Subject: [PATCH] Fix Sm80 tile_count_semaphore, adjust test tolerance

---
 hopper/flash_api.cpp      | 11 +++++++++--
 hopper/test_flash_attn.py | 29 ++++++++++++++++-------------
 2 files changed, 25 insertions(+), 15 deletions(-)

diff --git a/hopper/flash_api.cpp b/hopper/flash_api.cpp
index 690d957dd..59b39a872 100644
--- a/hopper/flash_api.cpp
+++ b/hopper/flash_api.cpp
@@ -498,7 +498,11 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

     auto q_type = q.scalar_type();
     TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16 || q_type == at::ScalarType::Float8_e4m3fn,
-                "FlashAttention only support fp16, bf16, and fp8_e4m3 data type");
+                "FlashAttention only supports fp16, bf16, and fp8_e4m3 data type");
+    if (dprops->major < 9) {
+        TORCH_CHECK(q_type == at::ScalarType::Half || q_type == at::ScalarType::BFloat16,
+                    "FlashAttention on Ampere/Ada cards only supports fp16 and bf16 data type");
+    }
     TORCH_CHECK(k.scalar_type() == q_type, "query and key must have the same dtype");
     TORCH_CHECK(v.scalar_type() == q_type, "query and value must have the same dtype");

@@ -788,7 +792,10 @@ mha_fwd(at::Tensor &q, // (b, s_q, h, d) or (total_q, h, d) if there is cu_seq

     at::Tensor tile_count_semaphore;
     // We don't use the persistent scheduler if Split and not Varlen
-    if (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen) {
+    bool const persistent_scheduler = params.arch >= 90
+        ? (((params.is_causal || params.is_local) && (params.num_splits == 1)) || is_varlen)
+        : ((params.is_causal && !is_varlen) || (is_varlen && params.num_splits > 1));
+    if (persistent_scheduler) {
         tile_count_semaphore = torch::zeros({1}, opts.dtype(torch::kInt32));
         params.tile_count_semaphore = tile_count_semaphore.data_ptr();
     } else {
diff --git a/hopper/test_flash_attn.py b/hopper/test_flash_attn.py
index e646d8af5..1fe43e21f 100644
--- a/hopper/test_flash_attn.py
+++ b/hopper/test_flash_attn.py
@@ -27,7 +27,7 @@
 DISABLE_SOFTCAP = os.getenv("FLASH_ATTENTION_DISABLE_SOFTCAP", "FALSE") == "TRUE"
 DISABLE_PACKGQA = os.getenv("FLASH_ATTENTION_DISABLE_PACKGQA", "FALSE") == "TRUE"
 DISABLE_FP16 = os.getenv("FLASH_ATTENTION_DISABLE_FP16", "FALSE") == "TRUE"
-DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE"
+DISABLE_FP8 = os.getenv("FLASH_ATTENTION_DISABLE_FP8", "FALSE") == "TRUE" or torch.cuda.get_device_capability("cuda")[0] < 9
 DISABLE_HDIM64 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM64", "FALSE") == "TRUE"
 DISABLE_HDIM96 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM96", "FALSE") == "TRUE"
 DISABLE_HDIM128 = os.getenv("FLASH_ATTENTION_DISABLE_HDIM128", "FALSE") == "TRUE"
@@ -165,7 +165,9 @@ def test_flash_attn_output(
     # qk = torch.einsum('bthd,bshd->bhts', q_ref.float() / math.sqrt(d), k_ref.float())
     # lse_ref = torch.logsumexp(qk, dim=-1)

-    abs_tol = 1e-4 if softcap == 0.0 else 5e-4
+    # Numerical error if we just do any arithmetic on out_ref
+    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
+    rtol = 2 if softcap == 0.0 else 3

     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
@@ -192,8 +194,7 @@ def test_flash_attn_output(

     # Check that FlashAttention's numerical error is at most twice the numerical error
     # of a Pytorch implementation.
-    multiple = 2 if dtype != torch.float8_e4m3fn else 3
-    assert (out - out_ref).abs().max().item() <= multiple * (out_pt - out_ref).abs().max().item() + abs_tol
+    assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

     if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
         g = torch.randn_like(out)
@@ -248,10 +249,12 @@ def test_flash_attn_output(


     if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn and not V_colmajor:
-        multiple = 2
-        assert (dq - dq_ref).abs().max().item() <= multiple * (dq_pt - dq_ref).abs().max().item() + abs_tol
-        assert (dk - dk_ref).abs().max().item() <= multiple * (dk_pt - dk_ref).abs().max().item() + abs_tol
-        assert (dv - dv_ref).abs().max().item() <= multiple * (dv_pt - dv_ref).abs().max().item() + abs_tol
+        dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+        assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
+        dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+        assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
+        dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
+        assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
@@ -411,7 +414,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

     # Numerical error if we just do any arithmetic on out_ref
     fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
-    rel_tol = 2 if softcap == 0.0 else 3
+    rtol = 2 if softcap == 0.0 else 3

     pack_gqa_vals = [False, True] if not DISABLE_PACKGQA else [False]
     num_splits_vals = [1, 3] if not DISABLE_SPLIT else [1]
@@ -442,7 +445,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

         # Check that FlashAttention's numerical error is at most 3x the numerical error
         # of a Pytorch implementation.
-        assert (out - out_ref).abs().max().item() <= rel_tol * (out_pt - out_ref).abs().max().item() + fwd_atol
+        assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol


     if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
@@ -515,11 +518,11 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):

     if not DISABLE_BACKWARD and dtype != torch.float8_e4m3fn:
         dq_atol = 2 * (dq_ref + 0.3 - 0.3 - dq_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-        assert (dq - dq_ref).abs().max().item() <= rel_tol * (dq_pt - dq_ref).abs().max().item() + dq_atol
+        assert (dq - dq_ref).abs().max().item() <= rtol * (dq_pt - dq_ref).abs().max().item() + dq_atol
         dk_atol = 2 * (dk_ref + 0.3 - 0.3 - dk_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-        assert (dk - dk_ref).abs().max().item() <= rel_tol * (dk_pt - dk_ref).abs().max().item() + dk_atol
+        assert (dk - dk_ref).abs().max().item() <= rtol * (dk_pt - dk_ref).abs().max().item() + dk_atol
         dv_atol = 2 * (dv_ref + 0.3 - 0.3 - dv_ref).abs().max().item() + (0 if softcap == 0 else 3e-4)
-        assert (dv - dv_ref).abs().max().item() <= rel_tol * (dv_pt - dv_ref).abs().max().item() + dv_atol
+        assert (dv - dv_ref).abs().max().item() <= rtol * (dv_pt - dv_ref).abs().max().item() + dv_atol


 # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn])
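Note on the adjusted tolerances (an illustration, not part of the patch): the hand-picked abs_tol is replaced by an absolute floor fwd_atol, derived from the round-off that any arithmetic on the reference output already incurs, plus a relative factor rtol (2, or 3 with softcap) against the error of the pure-PyTorch reference implementation. The sketch below reproduces that check in isolation; the tensor shape and the "kernel" output are made-up stand-ins, not the test's real FlashAttention or PyTorch outputs.

    # Standalone sketch of the tolerance check introduced above; tensors are stand-ins.
    import torch

    torch.manual_seed(0)
    out_ref = torch.randn(4, 128, 8, 64)          # fp32 reference output (stand-in)
    out_pt = out_ref.to(torch.bfloat16).float()   # reference impl run in bf16: pure round-off error
    out = out_ref + 0.5 * (out_pt - out_ref)      # fake "kernel" output with half that error

    softcap = 0.0
    # Absolute floor: twice the round-off that any arithmetic on out_ref already incurs.
    fwd_atol = 2 * (out_ref + 0.3 - 0.3 - out_ref).abs().max().item()
    # Relative factor: the kernel may be at most 2x (3x with softcap) the reference impl's error.
    rtol = 2 if softcap == 0.0 else 3

    assert (out - out_ref).abs().max().item() <= rtol * (out_pt - out_ref).abs().max().item() + fwd_atol

Unlike the old fixed 1e-4 / 5e-4 constant, this floor scales with the magnitude of out_ref (and of dq_ref/dk_ref/dv_ref in the backward checks).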