From bec24a23ba4eb8e274910d43d85e88a8acc28c08 Mon Sep 17 00:00:00 2001
From: Roger Wang <136131678+ywang96@users.noreply.github.com>
Date: Sun, 26 Jan 2025 03:56:34 -0800
Subject: [PATCH] [Misc] Revert FA on ViT #12355 and #12435 (#12445)

---
 vllm/attention/layer.py | 41 ++++-------------------------------------
 1 file changed, 4 insertions(+), 37 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index db682b4ac63b0..da663d894aeb3 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -210,9 +210,6 @@ def __init__(
         self.scale = scale
         self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
 
-        assert self.num_heads % self.num_kv_heads == 0
-        self.num_queries_per_kv = self.num_heads // self.num_kv_heads
-
         dtype = torch.get_default_dtype()
         attn_backend = get_attn_backend(head_size,
                                         dtype,
@@ -220,12 +217,12 @@ def __init__(
                                         block_size=16,
                                         is_attention_free=False)
         backend = backend_name_to_enum(attn_backend.get_name())
+        if backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            backend = _Backend.XFORMERS
 
         self.attn_backend = backend if backend in {
             _Backend.TORCH_SDPA,
             _Backend.XFORMERS,
-            _Backend.FLASH_ATTN,
-            _Backend.FLASH_ATTN_VLLM_V1,
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -235,6 +232,7 @@ def forward(
         value: torch.Tensor,
     ) -> torch.Tensor:
         """Input shape: batch_size x seq_len x hidden_size"""
+        # TODO(Isotr0py): Use existing backend implementations and support FA3
         bsz, q_len, _ = query.size()
         kv_len = key.size(1)
 
@@ -242,38 +240,7 @@ def forward(
         key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
         value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
 
-        if (num_repeat := self.num_queries_per_kv) > 1:
-            # Handle MQA and GQA
-            key = torch.repeat_interleave(key, num_repeat, dim=2)
-            value = torch.repeat_interleave(value, num_repeat, dim=2)
-
-        if self.attn_backend in {
-                _Backend.FLASH_ATTN,
-                _Backend.FLASH_ATTN_VLLM_V1,
-        }:
-            from vllm.vllm_flash_attn import flash_attn_varlen_func
-
-            cu_seqlens_q = torch.arange(0, (bsz + 1) * q_len,
-                                        step=q_len,
-                                        dtype=torch.int32,
-                                        device=query.device)
-            cu_seqlens_k = torch.arange(0, (bsz + 1) * kv_len,
-                                        step=kv_len,
-                                        dtype=torch.int32,
-                                        device=key.device)
-
-            out = flash_attn_varlen_func(
-                query.flatten(0, 1),
-                key.flatten(0, 1),
-                value.flatten(0, 1),
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_k=cu_seqlens_k,
-                max_seqlen_q=q_len,
-                max_seqlen_k=kv_len,
-                softmax_scale=self.scale,
-            )
-            out = out.reshape(bsz, q_len, -1)
-        elif self.attn_backend == _Backend.XFORMERS:
+        if self.attn_backend == _Backend.XFORMERS:
             from xformers import ops as xops
 
             out = xops.memory_efficient_attention_forward(query,