diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py
index 0b0f93e28eba6..787fd5346ca5a 100644
--- a/vllm/attention/backends/rocm_flash_attn.py
+++ b/vllm/attention/backends/rocm_flash_attn.py
@@ -21,8 +21,9 @@
 logger = init_logger(__name__)
 
 _PARTITION_SIZE_ROCM = 512
-_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
-
+_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+_ON_NAVI = "gfx1" in _GPU_ARCH
+_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])
 
 
 class ROCmFlashAttentionBackend(AttentionBackend):
@@ -661,16 +662,11 @@ def _sdpa_attention(
 def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                      block_size: int, gqa_ratio: int,
                                      max_seq_len: int) -> bool:
-    # Custom paged attention is only supported on MI250/MI300 GPUs
-    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
-
-    ON_MI250_MI300 = any(s in gpu_arch for s in ["gfx90a", "gfx940", "gfx941", "gfx942"])
-
-    if not ON_MI250_MI300:
-        logger.warning(f"Custom Paged Attention is not currently supported on {gpu_arch}.")
+    if not _ON_MI250_MI300:
+        logger.warning("Custom Paged Attention is not currently supported on %s.", _GPU_ARCH)
 
     # rocm custom page attention not support on navi (gfx1*)
-    return (ON_MI250_MI300 and not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+    return (_ON_MI250_MI300 and not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)