diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py
index c7e7a4d52e5a7..665e37c31887b 100644
--- a/vllm/attention/layer.py
+++ b/vllm/attention/layer.py
@@ -199,7 +199,7 @@ def __init__(
             attn_backend = _Backend.XFORMERS

         self.attn_backend = attn_backend if attn_backend in {
-            _Backend.TORCH_SDPA, _Backend.XFORMERS
+            _Backend.TORCH_SDPA, _Backend.XFORMERS, _Backend.HPU_ATTN
         } else _Backend.TORCH_SDPA

     def forward(
@@ -232,6 +232,28 @@ def forward(
                                                  value,
                                                  scale=self.scale)
             out = out.transpose(1, 2)
+        elif self.attn_backend == _Backend.HPU_ATTN:
+            from habana_frameworks.torch.hpex.kernels import FusedSDPA
+            from vllm_hpu_extension.utils import ModuleFusedSDPA
+
+            fsdpa_op = ModuleFusedSDPA(FusedSDPA)
+
+            query, key, value = (x.transpose(1, 2)
+                                 for x in (query, key, value))
+
+            out = fsdpa_op(query,
+                           key,
+                           value,
+                           None,
+                           dropout_p=0.0,
+                           is_causal=False,
+                           scale=self.scale,
+                           softmax_mode="fast",
+                           recompute_mode=True,
+                           valid_sequence_lengths=None)
+
+            out = out.transpose(1, 2)
+
         return out.reshape(bsz, q_len, -1)
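
Below is a minimal sketch (not part of the patch) of the tensor-layout contract the new `_Backend.HPU_ATTN` branch follows: inputs arrive as `(batch, seq_len, num_heads, head_dim)`, are transposed to `(batch, num_heads, seq_len, head_dim)` for the fused kernel, and transposed back before the final reshape. Since `FusedSDPA` / `ModuleFusedSDPA` are only available with the Gaudi software stack, the sketch substitutes `torch.nn.functional.scaled_dot_product_attention` (PyTorch 2.1+ for the `scale` argument) as a stand-in; the concrete shapes and names (`bsz`, `q_len`, `num_heads`, `head_dim`) are illustrative assumptions mirroring the surrounding `forward()`.

```python
# Sketch only: stand-in for the HPU FusedSDPA path using torch SDPA.
import torch
import torch.nn.functional as F

bsz, q_len, num_heads, head_dim = 2, 16, 8, 64  # assumed example sizes
scale = head_dim ** -0.5

query = torch.randn(bsz, q_len, num_heads, head_dim)
key = torch.randn(bsz, q_len, num_heads, head_dim)
value = torch.randn(bsz, q_len, num_heads, head_dim)

# (B, L, H, D) -> (B, H, L, D), as in the patched forward()
q, k, v = (x.transpose(1, 2) for x in (query, key, value))

out = F.scaled_dot_product_attention(q, k, v, scale=scale)

# (B, H, L, D) -> (B, L, H, D) -> (B, L, H * D), matching the final reshape
out = out.transpose(1, 2).reshape(bsz, q_len, -1)
assert out.shape == (bsz, q_len, num_heads * head_dim)
```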