Commit

Update rocm_flash_attn.py
MErkinSag authored Oct 21, 2024
1 parent 3cf76ff commit e107fa4
Showing 1 changed file with 5 additions and 9 deletions.
14 changes: 5 additions & 9 deletions vllm/attention/backends/rocm_flash_attn.py
@@ -21,8 +21,9 @@
 logger = init_logger(__name__)

 _PARTITION_SIZE_ROCM = 512
-_ON_NAVI = "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName
-
+_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
+_ON_NAVI = "gfx1" in _GPU_ARCH
+_ON_MI250_MI300 = any(arch in _GPU_ARCH for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])

Check failure on line 26 in vllm/attention/backends/rocm_flash_attn.py
GitHub Actions / ruff (3.10): E501 Line too long (93 > 80) at rocm_flash_attn.py:26:81.
(A line-length-compliant reflow of this line is sketched just after this hunk.)

 class ROCmFlashAttentionBackend(AttentionBackend):
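
Note: the E501 failure above is the new _ON_MI250_MI300 line, which runs to 93 characters. One possible reflow that keeps the same names and gfx list while staying under 80 columns (a sketch only, not part of this commit):

import torch  # gcnArchName is exposed by ROCm builds of PyTorch

_GPU_ARCH = torch.cuda.get_device_properties("cuda").gcnArchName
_ON_NAVI = "gfx1" in _GPU_ARCH
# Same MI250/MI300 arch list as in the hunk above, wrapped to fit 80 columns.
_ON_MI250_MI300 = any(arch in _GPU_ARCH
                      for arch in ["gfx90a", "gfx940", "gfx941", "gfx942"])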

@@ -661,16 +662,11 @@ def _sdpa_attention(
 def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                      block_size: int, gqa_ratio: int,
                                      max_seq_len: int) -> bool:
-    # Custom paged attention is only supported on MI250/MI300 GPUs
-    gpu_arch = torch.cuda.get_device_properties("cuda").gcnArchName.split(":")[0]
-
-    ON_MI250_MI300 = any(s in gpu_arch for s in ["gfx90a", "gfx940", "gfx941", "gfx942"])
-
-    if not ON_MI250_MI300:
+    if not _ON_MI250_MI300:
         logger.warning(f"Custom Paged Attention is not currently supported on {gpu_arch}.")

Check failures on line 666 in vllm/attention/backends/rocm_flash_attn.py
GitHub Actions / mypy (3.10): Name "gpu_arch" is not defined [name-defined]
GitHub Actions / ruff (3.10): G004 Logging statement uses f-string at rocm_flash_attn.py:666:24
GitHub Actions / ruff (3.10): F821 Undefined name `gpu_arch` at rocm_flash_attn.py:666:80
GitHub Actions / ruff (3.10): E501 Line too long (91 > 80) at rocm_flash_attn.py:666:81
(The gpu_arch local was deleted by this change, but the unchanged logger.warning line still interpolates it; a possible lint-clean rewrite is sketched after this hunk.)

     # rocm custom page attention not support on navi (gfx1*)
-    return (ON_MI250_MI300 and not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)
+    return (_ON_MI250_MI300 and not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16)

Check failure on line 669 in vllm/attention/backends/rocm_flash_attn.py
GitHub Actions / ruff (3.10): E501 Line too long (97 > 80) at rocm_flash_attn.py:669:81
             and (head_size == 64 or head_size == 128)
             and (block_size == 16 or block_size == 32)
             and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 32768)
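
Note: the failures on line 666 trace back to the same change: the gpu_arch local no longer exists, but the unchanged logger.warning call still interpolates it, and that line also trips G004 (f-string in logging) and E501. A possible lint-clean follow-up, assuming the module-level _GPU_ARCH, _ON_NAVI and _ON_MI250_MI300 from the first hunk plus the existing logger (a sketch, not what this commit ships):

def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int,
                                     block_size: int, gqa_ratio: int,
                                     max_seq_len: int) -> bool:
    if not _ON_MI250_MI300:
        # Reuse the module-level arch string; %-style args avoid ruff G004.
        logger.warning(
            "Custom Paged Attention is not currently supported on %s.",
            _GPU_ARCH)

    # ROCm custom paged attention is not supported on Navi (gfx1*).
    return (_ON_MI250_MI300 and not _ON_NAVI
            and (qtype == torch.half or qtype == torch.bfloat16)
            and (head_size == 64 or head_size == 128)
            and (block_size == 16 or block_size == 32)
            and (gqa_ratio >= 1 and gqa_ratio <= 16)
            and max_seq_len <= 32768)

For illustration, such a helper would return True on an MI300 for a half-precision model with head_size=128, block_size=16, gqa_ratio=4 and max_seq_len=4096, and False on any Navi (gfx1*) GPU regardless of the other arguments (hypothetical values, assuming a ROCm device is present):

use_custom = _use_rocm_custom_paged_attention(torch.float16, head_size=128,
                                              block_size=16, gqa_ratio=4,
                                              max_seq_len=4096)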
