[Core] Registry for processing model inputs (vllm-project#5214)
Co-authored-by: ywang96 <[email protected]>
2 people authored and prashantgupta24 committed Jun 28, 2024
1 parent bba1cc6 commit 4a5916d
Showing 25 changed files with 768 additions and 354 deletions.
20 changes: 20 additions & 0 deletions docs/source/dev/input_processing/input_processing_pipeline.rst
@@ -0,0 +1,20 @@
.. _input_processing_pipeline:

Input Processing Pipeline
=========================

1. Input data is passed to :class:`~vllm.LLMEngine` (or :class:`~vllm.AsyncLLMEngine`).

2. Tokenize the data if necessary.

3. Process the inputs using :meth:`INPUT_REGISTRY.process_input <vllm.inputs.registry.InputRegistry.process_input>`.

- For example, add placeholder tokens to reserve KV cache for multi-modal embeddings.

4. Send the processed inputs to :class:`~vllm.executor.executor_base.ExecutorBase`.

5. Distribute the inputs via :class:`~vllm.worker.worker_base.WorkerBase` to :class:`~vllm.worker.model_runner_base.ModelRunnerBase`.

6. If the inputs contain multi-modal data, convert it into keyword arguments using :meth:`MULTIMODAL_REGISTRY.map_input <vllm.multimodal.MultiModalRegistry.map_input>`.

- For example, convert a :class:`PIL.Image.Image` input to its pixel values for a vision language model.
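
As a rough end-to-end sketch of steps 3 and 6 (normally performed inside
:class:`~vllm.LLMEngine` and the model runner, not by user code): here
``model_config`` is assumed to be a :class:`~vllm.config.ModelConfig` built for
a vision language model, the import path of ``ImagePixelData`` and the argument
order of ``process_input`` are assumptions, and the token IDs are illustrative.

.. code-block:: python

    from PIL import Image

    from vllm.inputs import INPUT_REGISTRY, LLMInputs
    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.image import ImagePixelData

    image = Image.open("images/cherry_blossom.jpg")
    llm_inputs = LLMInputs(
        prompt_token_ids=[1, 32044, 3529],  # illustrative token IDs
        prompt="<image> Describe the picture.",
        multi_modal_data=ImagePixelData(image),
    )

    # Step 3: run the model-specific input processor, e.g. to insert
    # placeholder tokens that reserve KV cache for the image embeddings.
    processed_inputs = INPUT_REGISTRY.process_input(llm_inputs, model_config)

    # Step 6: map the multi-modal data to model keyword arguments,
    # e.g. the pixel values consumed by a vision language model.
    mm_kwargs = MULTIMODAL_REGISTRY.map_input(model_config,
                                              ImagePixelData(image))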
39 changes: 39 additions & 0 deletions docs/source/dev/input_processing/model_inputs_index.rst
@@ -0,0 +1,39 @@
.. _input_processing:

Input Processing
================

.. currentmodule:: vllm.inputs

vLLM provides a mechanism for defining input processors for each model so that the inputs are processed
in :class:`~vllm.LLMEngine` before they are passed to model executors.

Currently, this mechanism is only utilized in **multi-modal models** for preprocessing multi-modal input
data in addition to the input prompt, but it can be extended to text-only language models when needed.
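
As a hedged sketch (not a definitive API) of how a model might opt into this
mechanism, assuming the registry exposes a ``register_input_processor``
decorator and that processors receive an :class:`InputContext` together with
the :class:`LLMInputs` to transform; all names prefixed with ``my_`` or ``My``
are placeholders:

.. code-block:: python

    import torch.nn as nn

    from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs


    def my_input_processor(ctx: InputContext,
                           llm_inputs: LLMInputs) -> LLMInputs:
        # `ctx` is assumed to expose the target model's configuration so the
        # processor can, e.g., size a placeholder-token region appropriately.
        ...
        return llm_inputs


    @INPUT_REGISTRY.register_input_processor(my_input_processor)
    class MyVisionLanguageModel(nn.Module):
        ...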

Guides
++++++

.. toctree::
:maxdepth: 1

input_processing_pipeline

Module Contents
+++++++++++++++

LLM Engine Inputs
-----------------

.. autoclass:: vllm.inputs.LLMInputs
:members:
:show-inheritance:

Registry
--------

.. autodata:: vllm.inputs.INPUT_REGISTRY

.. automodule:: vllm.inputs.registry
:members:
:show-inheritance:
8 changes: 1 addition & 7 deletions docs/source/dev/multimodal/multimodal_index.rst
@@ -12,10 +12,6 @@ By default, vLLM models do not support multi-modal inputs. To enable multi-modal
you must decorate the model class with :meth:`MULTIMODAL_REGISTRY.register_dummy_data <MultiModalRegistry.register_dummy_data>`,
as well as :meth:`MULTIMODAL_REGISTRY.register_input <MultiModalRegistry.register_input>` for each modality type to support.
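
As an illustrative sketch only: the decorator arguments, the
``dummy_image_data_for_my_model`` helper, and the ``ImagePixelData`` import
path below are hypothetical placeholders; consult :class:`MultiModalRegistry`
for the exact signatures.

.. code-block:: python

    import torch.nn as nn

    from vllm.multimodal import MULTIMODAL_REGISTRY
    from vllm.multimodal.image import ImagePixelData


    def dummy_image_data_for_my_model(*args, **kwargs):
        ...  # return dummy data used when profiling memory usage


    @MULTIMODAL_REGISTRY.register_dummy_data(dummy_image_data_for_my_model)
    @MULTIMODAL_REGISTRY.register_input(ImagePixelData)
    class MyVisionLanguageModel(nn.Module):
        ...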

.. contents::
:local:
:backlinks: none

Module Contents
+++++++++++++++

@@ -24,9 +20,7 @@ Module Contents
Registry
--------

.. data:: vllm.multimodal.MULTIMODAL_REGISTRY

The global :class:`MultiModalRegistry` which is used by model runners.
.. autodata:: vllm.multimodal.MULTIMODAL_REGISTRY

.. autoclass:: vllm.multimodal.MultiModalRegistry
:members:
1 change: 1 addition & 0 deletions docs/source/index.rst
@@ -120,6 +120,7 @@ Documentation
dev/offline_inference/offline_index
dev/engine/engine_index
dev/kernel/paged_attention
dev/input_processing/model_inputs_index
dev/multimodal/multimodal_index
dev/dockerfile/dockerfile

4 changes: 2 additions & 2 deletions docs/source/models/adding_model.rst
@@ -37,7 +37,7 @@ For instance, vLLM's `OPT model <https://github.com/vllm-project/vllm/blob/main/
2. Rewrite the :code:`forward` methods
--------------------------------------

Next, you need to rewrite the :code:`forward` methods of your model by following these steps:
Next, you need to rewrite the :meth:`~torch.nn.Module.forward` method of your model by following these steps:

1. Remove any unnecessary code, such as the code only used for training.
2. Change the input parameters:
@@ -75,7 +75,7 @@ Next, you need to rewrite the :code:`forward` methods of your model by following

If your model is too large to fit into a single GPU, you can use tensor parallelism to manage it.
To do this, substitute your model's linear and embedding layers with their tensor-parallel versions.
For the embedding layer, you can simply replace :code:`nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
For the embedding layer, you can simply replace :class:`torch.nn.Embedding` with :code:`VocabParallelEmbedding`. For the output LM head, you can use :code:`ParallelLMHead`.
When it comes to the linear layers, we provide the following options to parallelize them:

* :code:`ReplicatedLinear`: Replicates the inputs and weights across multiple GPUs. No memory saving.
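
A minimal sketch of the embedding and LM-head substitution described above, as
it might look in a model's ``__init__``; the import path shown is an assumption
and may differ between vLLM versions.

.. code-block:: python

    import torch.nn as nn

    from vllm.model_executor.layers.vocab_parallel_embedding import (
        ParallelLMHead, VocabParallelEmbedding)


    class MyModel(nn.Module):

        def __init__(self, config):
            super().__init__()
            # Was: nn.Embedding(config.vocab_size, config.hidden_size)
            self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
                                                       config.hidden_size)
            # Was: nn.Linear(config.hidden_size, config.vocab_size, bias=False)
            self.lm_head = ParallelLMHead(config.vocab_size,
                                          config.hidden_size)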
3 changes: 2 additions & 1 deletion examples/phi3v_example.py
@@ -11,14 +11,15 @@ def run_phi3v():
model_path = "microsoft/Phi-3-vision-128k-instruct"

# Note: The model has 128k context length by default which may cause OOM
# If that's the case, override `max_model_len` with a smaller value via args
# In this example, we override max_model_len to 2048.
llm = LLM(
model=model_path,
trust_remote_code=True,
image_input_type="pixel_values",
image_token_id=32044,
image_input_shape="1,3,1008,1344",
image_feature_size=1921,
max_model_len=2048,
)

image = Image.open("images/cherry_blossom.jpg")
@@ -25,25 +25,24 @@ def test_clip_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)

for asset in image_assets:
hf_result = hf_processor.preprocess(
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)

assert hf_result.keys() == vllm_result.keys()
@@ -74,25 +73,24 @@ def test_llava_next_image_processor(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=64000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=2928,
image_processor=MODEL_NAME,
image_processor_revision=None,
),
)

for asset in image_assets:
hf_result = hf_processor.preprocess(
asset.pil_image,
return_tensors="pt",
).to(dtype=_STR_DTYPE_TO_TORCH_DTYPE[dtype])
vllm_result = MULTIMODAL_REGISTRY.process_input(
vllm_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)

assert hf_result.keys() == vllm_result.keys()
@@ -119,26 +117,23 @@ def test_image_pixel_types(image_assets, dtype):
seed=0,
dtype=dtype,
revision=None,
)
vlm_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
)
multimodal_config=VisionLanguageConfig(
image_input_type=VisionLanguageConfig.ImageInputType.PIXEL_VALUES,
image_token_id=32000,
image_input_shape=(1, 3, IMAGE_HEIGHT, IMAGE_WIDTH),
image_feature_size=576,
image_processor=MODEL_NAME,
image_processor_revision=None,
))

for asset in image_assets:
image_result = MULTIMODAL_REGISTRY.process_input(
image_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pil_image),
model_config=model_config,
vlm_config=vlm_config,
)
tensor_result = MULTIMODAL_REGISTRY.process_input(
tensor_result = MULTIMODAL_REGISTRY.map_input(
model_config,
ImagePixelData(asset.pixel_values),
model_config=model_config,
vlm_config=vlm_config,
)

assert image_result.keys() == tensor_result.keys()
3 changes: 3 additions & 0 deletions vllm/config.py
@@ -109,6 +109,7 @@ def __init__(
disable_sliding_window: bool = False,
skip_tokenizer_init: bool = False,
served_model_name: Optional[Union[str, List[str]]] = None,
multimodal_config: Optional["VisionLanguageConfig"] = None,
) -> None:
self.model = model
self.tokenizer = tokenizer
@@ -159,6 +160,8 @@ def __init__(
sliding_window_len=self.get_hf_config_sliding_window())
self.served_model_name = get_served_model_name(model,
served_model_name)
self.multimodal_config = multimodal_config

if not self.skip_tokenizer_init:
self._verify_tokenizer_mode()
self._verify_embedding_mode()
33 changes: 32 additions & 1 deletion vllm/engine/arg_utils.py
@@ -659,6 +659,36 @@ def create_engine_config(self, ) -> EngineConfig:
raise ValueError(
"BitsAndBytes load format and QLoRA adapter only support "
f"'bitsandbytes' quantization, but got {self.quantization}")
if self.image_input_type:
if (not self.image_token_id or not self.image_input_shape
or not self.image_feature_size):
raise ValueError(
'Specify `image_token_id`, `image_input_shape` and '
'`image_feature_size` together with `image_input_type`.')

if self.image_processor is None:
self.image_processor = self.model
if self.disable_image_processor:
if self.image_processor != self.model:
warnings.warn(
"You've specified an image processor "
f"({self.image_processor}) but also disabled "
"it via `--disable-image-processor`.",
stacklevel=2)

self.image_processor = None

vision_language_config = VisionLanguageConfig(
image_input_type=VisionLanguageConfig.
get_image_input_enum_type(self.image_input_type),
image_token_id=self.image_token_id,
image_input_shape=str_to_int_tuple(self.image_input_shape),
image_feature_size=self.image_feature_size,
image_processor=self.image_processor,
image_processor_revision=self.image_processor_revision,
)
else:
vision_language_config = None

device_config = DeviceConfig(device=self.device)
model_config = ModelConfig(
@@ -682,7 +712,8 @@ def create_engine_config(self, ) -> EngineConfig:
max_logprobs=self.max_logprobs,
disable_sliding_window=self.disable_sliding_window,
skip_tokenizer_init=self.skip_tokenizer_init,
served_model_name=self.served_model_name)
served_model_name=self.served_model_name,
multimodal_config=vision_language_config)
cache_config = CacheConfig(
block_size=self.block_size,
gpu_memory_utilization=self.gpu_memory_utilization,
5 changes: 4 additions & 1 deletion vllm/engine/llm_engine.py
@@ -20,7 +20,7 @@
from vllm.engine.output_processor.util import create_output_by_sequence_group
from vllm.executor.executor_base import ExecutorBase
from vllm.executor.ray_utils import initialize_ray_cluster
from vllm.inputs import LLMInputs, PromptInputs
from vllm.inputs import INPUT_REGISTRY, LLMInputs, PromptInputs
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.outputs import (EmbeddingRequestOutput, RequestOutput,
@@ -230,6 +230,9 @@ def __init__(
self.generation_config_fields = _load_generation_config_dict(
model_config)

self.input_processor = INPUT_REGISTRY.create_input_processor(
self.model_config)

self.model_executor = executor_class(
model_config=model_config,
cache_config=cache_config,
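
For context, a hedged sketch (not part of this diff) of how the processor
created above is expected to be used: `create_input_processor` is assumed to
return a callable mapping the already-tokenized `LLMInputs` to the processed
`LLMInputs` of pipeline step 3.

from vllm.inputs import INPUT_REGISTRY, LLMInputs

def apply_input_processor(model_config, llm_inputs: LLMInputs) -> LLMInputs:
    # Dispatch to the processor registered for this model type; for
    # multi-modal models this may, e.g., expand image placeholder tokens.
    input_processor = INPUT_REGISTRY.create_input_processor(model_config)
    return input_processor(llm_inputs)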
19 changes: 19 additions & 0 deletions vllm/inputs/__init__.py
@@ -0,0 +1,19 @@
from .data import (LLMInputs, ParsedText, ParsedTokens, PromptInputs,
PromptStrictInputs, TextPrompt, TextTokensPrompt,
TokensPrompt, parse_and_batch_prompt)
from .registry import InputContext, InputRegistry

INPUT_REGISTRY = InputRegistry()
"""
The global :class:`~InputRegistry` which is used by :class:`~vllm.LLMEngine`
to dispatch data processing according to the target model.
See also:
:ref:`input_processing_pipeline`
"""

__all__ = [
"ParsedText", "ParsedTokens", "parse_and_batch_prompt", "TextPrompt",
"TokensPrompt", "TextTokensPrompt", "PromptStrictInputs", "PromptInputs",
"LLMInputs", "INPUT_REGISTRY", "InputContext", "InputRegistry"
]
18 changes: 16 additions & 2 deletions vllm/inputs.py → vllm/inputs/data.py
@@ -101,8 +101,7 @@ class TextTokensPrompt(TypedDict):
"""The prompt text."""

prompt_token_ids: List[int]
"""The token IDs of the prompt. If None, we use the
tokenizer to convert the prompts to token IDs."""
"""The token IDs of the prompt."""

multi_modal_data: NotRequired["MultiModalData"]
"""
Expand All @@ -125,6 +124,21 @@ class TextTokensPrompt(TypedDict):


class LLMInputs(TypedDict):
"""
The inputs in :class:`~vllm.LLMEngine` before they are
passed to the model executor.
"""

prompt_token_ids: List[int]
"""The token IDs of the prompt."""

prompt: NotRequired[Optional[str]]
"""
The original prompt text corresponding to the token IDs, if available.
"""

multi_modal_data: NotRequired[Optional["MultiModalData"]]
"""
Optional multi-modal data to pass to the model,
if the model supports it.
"""