diff --git a/tests/multimodal/test_mapper.py b/tests/multimodal/test_mapper.py
index 13c1d1e342c7b..bb327b975476f 100644
--- a/tests/multimodal/test_mapper.py
+++ b/tests/multimodal/test_mapper.py
@@ -23,15 +23,14 @@ def test_clip_image_processor(hf_images, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
-    )
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ))
 
     for image in hf_images:
         hf_result = hf_processor.preprocess(
@@ -39,9 +38,8 @@ def test_clip_image_processor(hf_images, dtype):
             return_tensors="np",
         )
         vllm_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert hf_result.keys() == vllm_result.keys()
@@ -65,26 +63,23 @@ def test_image_pixel_types(hf_images, vllm_image_tensors, dtype):
         seed=0,
         dtype=dtype,
         revision=None,
-    )
-    vlm_config = VisionLanguageConfig(
-        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
-        image_token_id=32000,
-        image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
-        image_feature_size=576,
-        image_processor=MODEL_NAME,
-        image_processor_revision=None,
-    )
+        multimodal_config=VisionLanguageConfig(
+            image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
+            image_token_id=32000,
+            image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
+            image_feature_size=576,
+            image_processor=MODEL_NAME,
+            image_processor_revision=None,
+        ))
 
     for image, tensor in zip(hf_images, vllm_image_tensors):
         image_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(image),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
         tensor_result = MULTIMODAL_REGISTRY.map_input(
+            model_config,
             ImagePixelData(tensor),
-            model_config=model_config,
-            vlm_config=vlm_config,
         )
 
         assert image_result.keys() == tensor_result.keys()
diff --git a/vllm/inputs/registry.py b/vllm/inputs/registry.py
index edb4ee1823f08..6efb0d4d21181 100644
--- a/vllm/inputs/registry.py
+++ b/vllm/inputs/registry.py
@@ -280,7 +280,7 @@ def process_input(self, model_config: "ModelConfig",
 
         return processor(model_config, inputs)
 
-    def create_input_processor(self, model_config: ModelConfig):
+    def create_input_processor(self, model_config: "ModelConfig"):
         """
         Create an input processor (see :meth:`process_input`) for a
         specific model.
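
Note: for orientation, the snippet below sketches the calling convention the updated test exercises. It is only a sketch: the model name, image size, and dtype are placeholder values (the real test takes them from fixtures and module constants), and the import locations are assumed from the files touched in this patch rather than taken from it.

from PIL import Image

from vllm.config import ModelConfig, VisionLanguageConfig
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.multimodal.image import ImagePixelData

MODEL_NAME = "llava-hf/llava-1.5-7b-hf"  # placeholder model name

# The vision-language settings now travel inside ModelConfig instead of being
# passed around as a separate VisionLanguageConfig argument.
model_config = ModelConfig(
    model=MODEL_NAME,
    tokenizer=MODEL_NAME,
    tokenizer_mode="auto",
    trust_remote_code=False,
    seed=0,
    dtype="half",
    revision=None,
    multimodal_config=VisionLanguageConfig(
        image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
        image_token_id=32000,
        image_input_shape=(1, 3, 336, 336),
        image_feature_size=576,
        image_processor=MODEL_NAME,
        image_processor_revision=None,
    ))

# map_input now takes the ModelConfig as its first positional argument; the
# plugin looks up model_config.multimodal_config itself, so there is no
# separate vlm_config to thread through the call.
image = Image.new("RGB", (336, 336))
mm_kwargs = MULTIMODAL_REGISTRY.map_input(model_config, ImagePixelData(image))
print(mm_kwargs.keys())  # e.g. dict_keys(['pixel_values']) for a CLIP image processor
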
diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py
index 9f252af13d363..49f5ad67907e7 100644
--- a/vllm/multimodal/base.py
+++ b/vllm/multimodal/base.py
@@ -2,7 +2,7 @@
 from typing import (TYPE_CHECKING, Callable, Dict, Generic, Optional, Type,
                     TypeVar)
 
-from vllm.config import ModelConfig, VisionLanguageConfig
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 if TYPE_CHECKING:
@@ -32,8 +32,7 @@ class MultiModalData:
 D = TypeVar("D", bound=MultiModalData)
 N = TypeVar("N", bound=Type["nn.Module"])
 
-MultiModalInputMapper = Callable[[D, ModelConfig, VisionLanguageConfig],
-                                 Dict[str, "torch.Tensor"]]
+MultiModalInputMapper = Callable[[ModelConfig, D], Dict[str, "torch.Tensor"]]
 """Return a dictionary to be passed as keyword arguments to
 :meth:`torch.nn.Module.forward`. This is similar in concept to tokenizers
 and processors in HuggingFace Transformers."""
@@ -63,9 +62,8 @@ def get_data_type(self) -> Type[D]:
         raise NotImplementedError
 
     @abstractmethod
-    def _default_input_mapper(
-            self, data: D, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
+    def _default_input_mapper(self, model_config: ModelConfig,
+                              data: D) -> Dict[str, "torch.Tensor"]:
         """Return a dictionary to be passed as keyword arguments to
         :meth:`torch.nn.Module.forward`. This is similar in concept to
         tokenizers and processors in HuggingFace Transformers.
@@ -99,16 +97,11 @@ def wrapper(model_cls: N) -> N:
 
         return wrapper
 
-    def map_input(
-            self, data: D, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, "torch.Tensor"]:
+    def map_input(self, model_config: ModelConfig,
+                  data: D) -> Dict[str, "torch.Tensor"]:
         """
         Apply an input mapper to a :class:`~MultiModalData` instance passed
         to the model, transforming the data into a dictionary of model inputs.
-
-        The model is identified by ``model_config``. ``vlm_config`` is
-        for compatibility purposes and may be merged into ``model_config``
-        in the near future.
         """
         # Avoid circular import
         from vllm.model_executor.model_loader import get_model_architecture
@@ -120,4 +113,4 @@ def map_input(
             raise KeyError(f"No input mapper in {self} is registered for "
                            f"model class {model_cls.__name__}.")
 
-        return mapper(data, model_config, vlm_config)
+        return mapper(model_config, data)
diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py
index e7e1e5bbe93c8..606afd412b06d 100644
--- a/vllm/multimodal/image.py
+++ b/vllm/multimodal/image.py
@@ -227,8 +227,8 @@ class ImagePixelPlugin(MultiModalPlugin[ImagePixelData]):
     def get_data_type(self) -> Type[ImagePixelData]:
         return ImagePixelData
 
-    def _get_hf_image_processor(self, model_config: ModelConfig,
-                                vlm_config: VisionLanguageConfig):
+    def _get_hf_image_processor(self, model_config: ModelConfig):
+        vlm_config = model_config.multimodal_config
         if vlm_config is None or vlm_config.image_processor is None:
             return None
 
@@ -238,12 +238,10 @@ def _get_hf_image_processor(self, model_config: ModelConfig,
             revision=vlm_config.image_processor_revision,
         )
 
-    def _default_input_mapper(
-            self, data: ImagePixelData, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
+    def _default_input_mapper(self, model_config: ModelConfig,
+                              data: ImagePixelData) -> Dict[str, torch.Tensor]:
         image = data.image
-        image_processor = self._get_hf_image_processor(model_config,
-                                                       vlm_config)
+        image_processor = self._get_hf_image_processor(model_config)
 
         if isinstance(image, Image.Image):
             if image_processor is None:
@@ -280,8 +278,8 @@ def get_data_type(self) -> Type[ImageFeatureData]:
         return ImageFeatureData
 
     def _default_input_mapper(
-            self, data: ImageFeatureData, model_config: ModelConfig,
-            vlm_config: VisionLanguageConfig) -> Dict[str, torch.Tensor]:
+            self, model_config: ModelConfig,
+            data: ImageFeatureData) -> Dict[str, torch.Tensor]:
         image_features = data.image_features.to(model_config.dtype)
 
         return {"image_features": image_features}
diff --git a/vllm/multimodal/registry.py b/vllm/multimodal/registry.py
index 13d8059c279a1..9b8e3e7d3b891 100644
--- a/vllm/multimodal/registry.py
+++ b/vllm/multimodal/registry.py
@@ -3,7 +3,7 @@
 
 from torch import nn
 
-from vllm.config import ModelConfig, VisionLanguageConfig
+from vllm.config import ModelConfig
 from vllm.logger import init_logger
 
 from .base import MultiModalData, MultiModalInputMapper, MultiModalPlugin
@@ -86,8 +86,7 @@ def register_image_feature_input_mapper(
         """
         return self.register_input_mapper(ImageFeatureData, mapper)
 
-    def map_input(self, data: MultiModalData, model_config: ModelConfig,
-                  vlm_config: VisionLanguageConfig):
+    def map_input(self, model_config: ModelConfig, data: MultiModalData):
         """
         Apply an input mapper to a :class:`~MultiModalData` instance passed
         to the model.
@@ -95,13 +94,10 @@ def map_input(self, data: MultiModalData, model_config: ModelConfig,
         See :meth:`MultiModalPlugin.map_input` for more details.
         """
         return self._get_plugin_for_data_type(type(data)) \
-            .map_input(data, model_config, vlm_config)
+            .map_input(model_config, data)
 
-    def create_input_mapper(self, model_config: ModelConfig,
-                            vlm_config: VisionLanguageConfig):
+    def create_input_mapper(self, model_config: ModelConfig):
         """
         Create an input mapper (see :meth:`map_input`) for a specific model.
         """
-        return functools.partial(self.map_input,
-                                 model_config=model_config,
-                                 vlm_config=vlm_config)
+        return functools.partial(self.map_input, model_config=model_config)
diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py
index e8c79ce9d9d55..95d8e44f51119 100644
--- a/vllm/worker/cpu_model_runner.py
+++ b/vllm/worker/cpu_model_runner.py
@@ -66,14 +66,8 @@ def __init__(
         )
 
         # Create processor for multi-modal data
-        if self.vision_language_config is not None:
-            self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
-                .create_input_mapper(
-                    self.model_config,
-                    self.vision_language_config,
-                )
-        else:
-            self.multi_modal_input_mapper = None
+        self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
+            .create_input_mapper(self.model_config)
 
         # Lazy initialization.
         self.model: nn.Module  # Set after init_Model
@@ -123,12 +117,6 @@ def _prepare_prompt(
 
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None:
-                # Process multi-modal data
-                if self.multi_modal_input_mapper is None:
-                    raise ValueError(
-                        "Multi-modal inputs are only supported by "
-                        "vision language models.")
-
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
                 for k, v in mm_kwargs.items():
                     multi_modal_kwargs_list[k].append(v)
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index 94480c9f90953..995751884e9d3 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -124,14 +124,8 @@ def __init__(
         )
 
         # Create processor for multi-modal data
-        if self.vision_language_config is not None:
-            self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
-                .create_input_mapper(
-                    self.model_config,
-                    self.vision_language_config,
-                )
-        else:
-            self.multi_modal_input_mapper = None
+        self.multi_modal_input_mapper = INPUT_REGISTRY.MULTIMODAL \
+            .create_input_mapper(self.model_config)
 
         # Lazy initialization
         self.model: nn.Module  # Set after load_model
@@ -432,11 +426,6 @@ def _prepare_model_input(
             mm_data = seq_group_metadata.multi_modal_data
             if mm_data is not None:
                 # Process multi-modal data
-                if self.multi_modal_input_mapper is None:
-                    raise ValueError(
-                        "Multi-modal inputs are only supported by "
-                        "vision language models.")
-
                 mm_kwargs = self.multi_modal_input_mapper(mm_data)
                 for k, v in mm_kwargs.items():
                     multi_modal_kwargs_list[k].append(v)