[Platform][Refactor] Extract func get_default_attn_backend to Platform #10358

Open · wants to merge 1 commit into base: main

56 changes: 6 additions & 50 deletions vllm/attention/selector.py
@@ -1,4 +1,3 @@
import enum
import os
from contextlib import contextmanager
from functools import lru_cache
@@ -9,26 +8,12 @@
import vllm.envs as envs
from vllm.attention.backends.abstract import AttentionBackend
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.utils import STR_BACKEND_ENV_VAR

logger = init_logger(__name__)


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()


def backend_name_to_enum(backend_name: str) -> _Backend:
assert backend_name is not None

@@ -216,40 +201,11 @@ def which_attn_to_use(head_size: int,
if backend_by_env_var is not None:
selected_backend = backend_name_to_enum(backend_by_env_var)

if current_platform.is_cpu():
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

if current_platform.is_openvino():
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO

if current_platform.is_xpu():
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX

if current_platform.is_tpu():
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS

if current_platform.is_rocm():
# AMD GPUs.
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not current_platform.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

if current_platform.is_hpu():
return _Backend.HPU_ATTN
# get device-specific default attn_backend
default_backend = current_platform.get_default_attn_backend(
selected_backend)
if default_backend is not None:
return default_backend

if use_v1:
return _Backend.FLASH_ATTN_VLLM_V1
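With this change, the chain of per-platform checks in which_attn_to_use collapses into a single call through current_platform. A rough sketch of the resulting control flow (simplified; the CUDA-specific fallbacks further down the function are omitted, and the standalone function name is only for illustration):

from vllm.platforms import _Backend, current_platform

def select_backend_sketch(selected_backend: _Backend, use_v1: bool) -> _Backend:
    # Each non-CUDA Platform subclass (CPU, OpenVINO, XPU, TPU, ROCm, HPU)
    # now owns its default; CUDA-like platforms return None and fall through.
    default_backend = current_platform.get_default_attn_backend(selected_backend)
    if default_backend is not None:
        return default_backend
    if use_v1:
        return _Backend.FLASH_ATTN_VLLM_V1
    # ... remaining CUDA backend selection (FlashAttention / xFormers / etc.)
    return selected_backend
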
2 changes: 1 addition & 1 deletion vllm/model_executor/models/molmo.py
@@ -13,7 +13,6 @@
from transformers import PretrainedConfig

from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
@@ -38,6 +37,7 @@
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors,
SequenceData)
from vllm.transformers_utils.processor import get_processor
2 changes: 1 addition & 1 deletion vllm/model_executor/models/qwen2_vl.py
@@ -39,7 +39,6 @@
make_batched_images, make_batched_videos, smart_resize)

from vllm.attention import AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.config import VllmConfig
from vllm.distributed import get_pp_group, parallel_state
from vllm.distributed import utils as dist_utils
@@ -65,6 +64,7 @@
from vllm.multimodal.inputs import (MultiModalData, MultiModalDataDict,
MultiModalKwargs)
from vllm.multimodal.utils import cached_get_tokenizer
from vllm.platforms import _Backend
from vllm.sequence import IntermediateTensors, PoolerOutput, SequenceData
from vllm.transformers_utils.config import uses_mrope
from vllm.transformers_utils.processor import cached_get_processor
4 changes: 2 additions & 2 deletions vllm/model_executor/models/utils.py
@@ -9,13 +9,13 @@
from transformers import PretrainedConfig

import vllm.envs as envs
from vllm.attention.selector import (_Backend, backend_name_to_enum,
from vllm.attention.selector import (backend_name_to_enum,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.multimodal import MultiModalPlaceholderMap, NestedTensors
from vllm.platforms import current_platform
from vllm.platforms import _Backend, current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils import is_pin_memory_available

1 change: 1 addition & 0 deletions vllm/platforms/__init__.py
@@ -1,3 +1,4 @@
from .interface import _Backend # noqa: F401
from .interface import Platform, PlatformEnum, UnspecifiedPlatform

current_platform: Platform
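Re-exporting _Backend from vllm.platforms gives call sites a new import path; the model and worker files below switch to it from vllm.attention.selector. A minimal usage sketch of the new import path and the hook it feeds:

# After this PR, _Backend is imported from vllm.platforms rather than
# vllm.attention.selector, so platform and model code no longer needs to
# touch the attention selector module.
from vllm.platforms import _Backend, current_platform

backend = current_platform.get_default_attn_backend(_Backend.FLASH_ATTN)
# None on CUDA-like platforms (the base-class default); a concrete _Backend
# such as TORCH_SDPA or PALLAS on CPU, TPU, XPU, OpenVINO, HPU, or ROCm.
print(backend)
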
12 changes: 11 additions & 1 deletion vllm/platforms/cpu.py
@@ -1,7 +1,11 @@
import psutil
import torch

from .interface import Platform, PlatformEnum
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)


class CpuPlatform(Platform):
@@ -11,6 +15,12 @@ class CpuPlatform(Platform):
def get_device_name(cls, device_id: int = 0) -> str:
return "cpu"

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
return _Backend.TORCH_SDPA

@classmethod
def get_device_total_memory(cls, device_id: int = 0) -> int:
return psutil.virtual_memory().total
6 changes: 5 additions & 1 deletion vllm/platforms/hpu.py
@@ -1,11 +1,15 @@
import torch

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend


class HpuPlatform(Platform):
_enum = PlatformEnum.HPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
return _Backend.HPU_ATTN

@staticmethod
def inference_mode():
return torch.no_grad()
19 changes: 19 additions & 0 deletions vllm/platforms/interface.py
@@ -6,6 +6,20 @@
import torch


class _Backend(enum.Enum):
FLASH_ATTN = enum.auto()
FLASH_ATTN_VLLM_V1 = enum.auto()
XFORMERS = enum.auto()
ROCM_FLASH = enum.auto()
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
NO_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
CUDA = enum.auto()
ROCM = enum.auto()
@@ -66,6 +80,11 @@ def is_cuda_alike(self) -> bool:
"""Stateless version of :func:`torch.cuda.is_available`."""
return self._enum in (PlatformEnum.CUDA, PlatformEnum.ROCM)

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend):
"""Get the default attention backend of a device."""
return None

@classmethod
def get_device_capability(
cls,
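The base-class implementation returning None is the fall-through path: CUDA, and any platform that does not override the hook, proceeds to the generic selection logic in selector.py. A minimal sketch of how an additional platform could plug in; MyAccelPlatform is hypothetical and the returned backend is only a stand-in:

from vllm.platforms.interface import Platform, PlatformEnum, _Backend

class MyAccelPlatform(Platform):
    # Hypothetical out-of-tree platform; a real port would add its own
    # PlatformEnum member rather than reusing UNSPECIFIED.
    _enum = PlatformEnum.UNSPECIFIED

    @classmethod
    def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
        # Mirror the CPU/TPU/XPU implementations: ignore unsupported requests
        # and pin the device's backend (TORCH_SDPA is just a placeholder here).
        return _Backend.TORCH_SDPA
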
8 changes: 7 additions & 1 deletion vllm/platforms/openvino.py
@@ -3,14 +3,20 @@
import vllm.envs as envs
from vllm.logger import init_logger

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend

logger = init_logger(__name__)


class OpenVinoPlatform(Platform):
_enum = PlatformEnum.OPENVINO

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
return _Backend.OPENVINO

@classmethod
def get_device_name(self, device_id: int = 0) -> str:
return "openvino"
14 changes: 13 additions & 1 deletion vllm/platforms/rocm.py
@@ -5,7 +5,7 @@

from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum
from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)

@@ -19,6 +19,18 @@
class RocmPlatform(Platform):
_enum = PlatformEnum.ROCM

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
if not cls.has_device_capability(90):
# not Instinct series GPUs.
logger.info("flash_attn is not supported on NAVI GPUs.")
else:
logger.info("%s is not supported in AMD GPUs.", selected_backend)
return _Backend.ROCM_FLASH

@classmethod
@lru_cache(maxsize=8)
def get_device_capability(cls, device_id: int = 0) -> DeviceCapability:
11 changes: 10 additions & 1 deletion vllm/platforms/tpu.py
@@ -4,9 +4,10 @@

import vllm.envs as envs
from vllm.compilation.levels import CompilationLevel
from vllm.logger import init_logger
from vllm.plugins import set_torch_compile_backend

from .interface import Platform, PlatformEnum
from .interface import Platform, PlatformEnum, _Backend

if "VLLM_TORCH_COMPILE_LEVEL" not in os.environ:
os.environ["VLLM_TORCH_COMPILE_LEVEL"] = str(CompilationLevel.DYNAMO_ONCE)
@@ -16,10 +17,18 @@

set_torch_compile_backend("openxla")

logger = init_logger(__name__)


class TpuPlatform(Platform):
_enum = PlatformEnum.TPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
return _Backend.PALLAS

@classmethod
def get_device_name(cls, device_id: int = 0) -> str:
raise NotImplementedError
12 changes: 11 additions & 1 deletion vllm/platforms/xpu.py
@@ -1,11 +1,21 @@
import torch

from .interface import DeviceCapability, Platform, PlatformEnum
from vllm.logger import init_logger

from .interface import DeviceCapability, Platform, PlatformEnum, _Backend

logger = init_logger(__name__)


class XPUPlatform(Platform):
_enum = PlatformEnum.XPU

@classmethod
def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
return _Backend.IPEX

@staticmethod
def get_device_capability(device_id: int = 0) -> DeviceCapability:
major, minor, *_ = torch.xpu.get_device_capability(
3 changes: 2 additions & 1 deletion vllm/worker/enc_dec_model_runner.py
@@ -8,7 +8,7 @@
from vllm.attention.backends.abstract import (AttentionBackend,
AttentionMetadata)
from vllm.attention.backends.utils import PAD_SLOT_ID
from vllm.attention.selector import (_Backend, get_env_variable_attn_backend,
from vllm.attention.selector import (get_env_variable_attn_backend,
get_global_forced_attn_backend)
from vllm.config import VllmConfig
from vllm.forward_context import set_forward_context
@@ -18,6 +18,7 @@
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalKwargs,
MultiModalRegistry)
from vllm.platforms import _Backend
from vllm.sampling_params import SamplingParams
from vllm.sequence import (IntermediateTensors, PoolerOutput,
SequenceGroupMetadata)