From 21513e9cc603be684ac1c5d9816cd022047092f8 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Mon, 10 Jun 2024 23:38:47 +0800 Subject: [PATCH] [Bugfix] Fix LLaVA-NeXT (#5380) --- vllm/model_executor/models/llava_next.py | 24 ++++++++++++++++++++++++ vllm/multimodal/utils.py | 2 +- 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index bb15dcb8ed917..57cbd1e4a6018 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -216,6 +216,30 @@ def _parse_and_validate_image_input( return None + def _select_image_features(self, image_features: torch.Tensor, *, + strategy: str) -> torch.Tensor: + # Copied from https://github.com/huggingface/transformers/blob/39c3c0a72af6fbda5614dde02ff236069bb79827/src/transformers/models/llava/modeling_llava.py#L421 # noqa + if strategy == "default": + return image_features[:, 1:] + elif strategy == "full": + return image_features + + raise ValueError(f"Unexpected select feature strategy: {strategy}") + + def _image_pixels_to_features(self, vision_tower: CLIPVisionModel, + pixel_values: torch.Tensor) -> torch.Tensor: + # TODO(xwjiang): Maybe port minimal CLIPVisionModel over. + image_outputs = vision_tower(pixel_values.to(vision_tower.device), + output_hidden_states=True) + + image_features = image_outputs.hidden_states[ + self.config.vision_feature_layer] + + return self._select_image_features( + image_features, + strategy=self.config.vision_feature_select_strategy, + ) + def _merge_image_patch_embeddings(self, image_size: torch.Tensor, patch_embeddings: torch.Tensor, *, strategy: str) -> torch.Tensor: diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index b8ad6f8f78e26..c6311d60e0bdd 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -77,7 +77,7 @@ def get_full_image_text_prompt(image_prompt: str, text_prompt: str, """Combine image and text prompts for vision language model depending on the model architecture.""" - if config.hf_config.model_type == "llava": + if config.hf_config.model_type in ("llava", "llava_next"): full_prompt = f"{image_prompt}\n{text_prompt}" else: raise ValueError(