From 32ec3e51117a4755aa0fb0d637aef976d6db06da Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Wed, 18 Dec 2024 17:00:15 +0200
Subject: [PATCH 1/7] Select correct backend for MultiHeadAttention

---
 vllm/attention/layer.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 05d997279893b..62cb0b76bd4d5 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -191,11 +191,14 @@ def __init__(
                                         kv_cache_dtype=None,
                                         block_size=16,
                                         is_attention_free=False)
-        if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
-            attn_backend = _Backend.XFORMERS
-        self.attn_backend = attn_backend if attn_backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+        attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
+
+        if attn_backend_enum in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+            attn_backend_enum = _Backend.XFORMERS
+
+        self.attn_backend = attn_backend_enum if attn_backend_enum in {
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
 
     def forward(
@@ -228,6 +231,15 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+            out = F.scaled_dot_product_attention(query,
+                                                 key,
+                                                 value,
+                                                 scale=self.scale)
+            out = out.transpose(1, 2).contiguous()
+
         return out.view(bsz, q_len, -1)

From 82528ff7e149818cb2ac2f1982f1599bc82d46d2 Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Wed, 18 Dec 2024 17:09:41 +0200
Subject: [PATCH 2/7] Formatting

---
 vllm/attention/layer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 62cb0b76bd4d5..c44082329746f 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -194,7 +194,9 @@ def __init__(
 
         attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
 
-        if attn_backend_enum in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
+        if attn_backend_enum in {
+                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
+        }:
             attn_backend_enum = _Backend.XFORMERS
 
         self.attn_backend = attn_backend_enum if attn_backend_enum in {
@@ -239,7 +241,7 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2).contiguous()
-        
+
         return out.view(bsz, q_len, -1)

From e3fee649dd241eefd07ba2ecc5babd553297fcea Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Mon, 30 Dec 2024 11:29:20 +0200
Subject: [PATCH 3/7] Change HPU backend to use fused SDPA

---
 vllm/attention/layer.py | 24 ++++++++++++++++++++----
 1 file changed, 20 insertions(+), 4 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index c44082329746f..6aaf9202e124d 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -234,12 +234,28 @@ def forward(
                                                  scale=self.scale)
             out = out.transpose(1, 2)
         elif self.attn_backend == _Backend.HPU_ATTN:
+            from vllm_hpu_extension.utils import ModuleFusedSDPA
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+
+            HPUFusedSDPA = FusedSDPA
+            fsdpa_op = None if HPUFusedSDPA is None \
+                else ModuleFusedSDPA(HPUFusedSDPA)
+
             query, key, value = (x.transpose(1, 2)
                                  for x in (query, key, value))
-            out = F.scaled_dot_product_attention(query,
-                                                 key,
-                                                 value,
-                                                 scale=self.scale)
+
+            out = fsdpa_op(query,
+                           key,
+                           value,
+                           attn_mask=None,
+                           dropout_p=0.0,
+                           is_causal=True,
+                           scale=self.scale,
+                           softmax_mode="fast",
+                           recompute_mode=True,
+                           valid_sequence_lengths=None,
+                           padding_side='right')
+
             out = out.transpose(1, 2).contiguous()
 
         return out.view(bsz, q_len, -1)

From b84d380f1342d197656b99cd0400a5f5577847a9 Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Mon, 30 Dec 2024 11:37:16 +0200
Subject: [PATCH 4/7] Format

---
 vllm/attention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 6aaf9202e124d..b3c75b93140d0 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -234,8 +234,8 @@ def forward(
                                                  scale=self.scale)
             out = out.transpose(1, 2)
         elif self.attn_backend == _Backend.HPU_ATTN:
-            from vllm_hpu_extension.utils import ModuleFusedSDPA
             from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            from vllm_hpu_extension.utils import ModuleFusedSDPA
 
             HPUFusedSDPA = FusedSDPA
             fsdpa_op = None if HPUFusedSDPA is None \

From b5f029533146d25277872b89be37eabf9a0980ad Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Mon, 30 Dec 2024 13:44:46 +0200
Subject: [PATCH 5/7] Format for mypy check

---
 vllm/attention/layer.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index b3c75b93140d0..e57c3b4aff1d5 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -237,9 +237,7 @@ def forward(
             from habana_frameworks.torch.hpex.kernels import FusedSDPA
             from vllm_hpu_extension.utils import ModuleFusedSDPA
 
-            HPUFusedSDPA = FusedSDPA
-            fsdpa_op = None if HPUFusedSDPA is None \
-                else ModuleFusedSDPA(HPUFusedSDPA)
+            fsdpa_op = ModuleFusedSDPA(FusedSDPA)
 
             query, key, value = (x.transpose(1, 2)
                                  for x in (query, key, value))
@@ -247,14 +245,13 @@ def forward(
             out = fsdpa_op(query,
                            key,
                            value,
-                           attn_mask=None,
+                           None,
                            dropout_p=0.0,
                            is_causal=True,
                            scale=self.scale,
                            softmax_mode="fast",
                            recompute_mode=True,
-                           valid_sequence_lengths=None,
-                           padding_side='right')
+                           valid_sequence_lengths=None)
 
             out = out.transpose(1, 2).contiguous()

From 060c70050bf8b6034292d01e43373c6c7bba5295 Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Thu, 2 Jan 2025 09:21:15 +0200
Subject: [PATCH 6/7] Set is_causal to False

---
 vllm/attention/layer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index e57c3b4aff1d5..424c5508f736c 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -247,7 +247,7 @@ def forward(
                            value,
                            None,
                            dropout_p=0.0,
-                           is_causal=True,
+                           is_causal=False,
                            scale=self.scale,
                            softmax_mode="fast",
                            recompute_mode=True,

From ecfbc5cbe64fb074b2624ce859e7377009ef223f Mon Sep 17 00:00:00 2001
From: Agata Dobrzyniewicz
Date: Thu, 16 Jan 2025 11:18:52 +0200
Subject: [PATCH 7/7] Rebase changes + format

---
 vllm/attention/layer.py | 11 ++---------
 1 file changed, 2 insertions(+), 9 deletions(-)

diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index 0cc19e2e7c661..665e37c31887b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -198,14 +198,7 @@ def __init__(
         if attn_backend in {_Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1}:
             attn_backend = _Backend.XFORMERS
 
-        attn_backend_enum = backend_name_to_enum(attn_backend.get_name())
-
-        if attn_backend_enum in {
-                _Backend.FLASH_ATTN, _Backend.FLASH_ATTN_VLLM_V1
-        }:
-            attn_backend_enum = _Backend.XFORMERS
-
-        self.attn_backend = attn_backend_enum if attn_backend_enum in {
+        self.attn_backend = attn_backend if attn_backend in {
             _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA
@@ -260,7 +253,7 @@ def forward(
                            valid_sequence_lengths=None)
 
             out = out.transpose(1, 2)
-        
+
         return out.reshape(bsz, q_len, -1)
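
Reviewer note (illustrative only, not part of the patches): after the series, MultiHeadAttention keeps the TORCH_SDPA/XFORMERS paths on other platforms and dispatches to Habana's fused SDPA kernel on HPU, calling it non-causally. The sketch below is a minimal mirror of that final dispatch. The mha_forward helper, the use_hpu_fused_sdpa flag, and the (batch, seq_len, num_heads, head_dim) shape convention are assumptions made for this example rather than code taken from vllm/attention/layer.py; the HPU branch runs only on a Gaudi host with vllm_hpu_extension installed, while the fallback branch runs anywhere.

    # Sketch only: mirrors the dispatch the series converges on. Names and shapes
    # are illustrative assumptions, not the actual layer.py implementation.
    import torch
    import torch.nn.functional as F

    def mha_forward(query, key, value, scale, use_hpu_fused_sdpa=False):
        # query/key/value: (batch, seq_len, num_heads, head_dim)
        bsz, q_len = query.shape[:2]
        # Move heads in front of the sequence dimension for SDPA.
        query, key, value = (x.transpose(1, 2) for x in (query, key, value))
        if use_hpu_fused_sdpa:
            # HPU path: Habana fused SDPA, non-causal as of PATCH 6/7.
            from habana_frameworks.torch.hpex.kernels import FusedSDPA
            from vllm_hpu_extension.utils import ModuleFusedSDPA
            fsdpa_op = ModuleFusedSDPA(FusedSDPA)
            out = fsdpa_op(query, key, value, None,
                           dropout_p=0.0, is_causal=False, scale=scale,
                           softmax_mode="fast", recompute_mode=True,
                           valid_sequence_lengths=None)
        else:
            # Fallback mirroring the TORCH_SDPA branch.
            out = F.scaled_dot_product_attention(query, key, value, scale=scale)
        # Back to (batch, seq_len, num_heads * head_dim).
        return out.transpose(1, 2).reshape(bsz, q_len, -1)

    # Example on the fallback path:
    q = k = v = torch.randn(2, 5, 4, 8)
    print(mha_forward(q, k, v, scale=8**-0.5).shape)  # torch.Size([2, 5, 32])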