From 6650e6a930dbdf1cd4def9b58e952376400ccfcf Mon Sep 17 00:00:00 2001 From: kakao-kevin-us Date: Sun, 27 Oct 2024 02:53:35 +0900 Subject: [PATCH] [Model] Add classification Task with Qwen2ForSequenceClassification (#9704) Signed-off-by: Kevin-Yang Co-authored-by: Kevin-Yang --- docs/source/models/supported_models.rst | 22 ++++ tests/conftest.py | 19 ++++ .../embedding/language/test_cls_models.py | 53 +++++++++ vllm/model_executor/layers/pooler.py | 9 +- vllm/model_executor/models/qwen2_cls.py | 107 ++++++++++++++++++ vllm/model_executor/models/registry.py | 2 + 6 files changed, 211 insertions(+), 1 deletion(-) create mode 100644 tests/models/embedding/language/test_cls_models.py create mode 100644 vllm/model_executor/models/qwen2_cls.py diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index 98d804052b575..ff893b613f150 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -361,6 +361,28 @@ Reward Modeling .. note:: As an interim measure, these models are supported via Embeddings API. See `this RFC `_ for upcoming changes. +Classification +--------------- + +.. list-table:: + :widths: 25 25 50 5 5 + :header-rows: 1 + + * - Architecture + - Models + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`Qwen2ForSequenceClassification` + - Qwen2-based + - :code:`jason9693/Qwen2.5-1.5B-apeach`, etc. + - + - ✅︎ + +.. note:: + As an interim measure, these models are supported via Embeddings API. It will be supported via Classification API in the future (no reference APIs exist now). + + Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/conftest.py b/tests/conftest.py index 6adff5e2328c4..2fce2d772c6ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -343,6 +343,17 @@ def get_inputs( return all_inputs + def classify(self, prompts: List[str]) -> List[str]: + # output is final logits + all_inputs = self.get_inputs(prompts) + outputs = [] + for inputs in all_inputs: + output = self.model(**self.wrap_device(inputs)) + logits = output.logits.softmax(dim=-1)[0].tolist() + outputs.append(logits) + + return outputs + def generate( self, prompts: List[str], @@ -688,6 +699,14 @@ def get_inputs( return inputs + def classify(self, prompts: List[str]) -> List[str]: + req_outputs = self.model.encode(prompts) + outputs = [] + for req_output in req_outputs: + embedding = req_output.outputs.embedding + outputs.append(embedding) + return outputs + def generate( self, prompts: List[str], diff --git a/tests/models/embedding/language/test_cls_models.py b/tests/models/embedding/language/test_cls_models.py new file mode 100644 index 0000000000000..d8ca6d361f0e3 --- /dev/null +++ b/tests/models/embedding/language/test_cls_models.py @@ -0,0 +1,53 @@ +"""Compare the outputs of HF and vLLM when using greedy sampling. + +This test only tests small models. Big models such as 7B should be tested from +test_big_models.py because it could use a larger instance to run tests. + +Run `pytest tests/models/test_cls_models.py`. +""" +import pytest +import torch +from transformers import AutoModelForSequenceClassification + +CLASSIFICATION_MODELS = ["jason9693/Qwen2.5-1.5B-apeach"] + + +@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_classification_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, +) -> None: + with hf_runner(model, + dtype=dtype, + auto_cls=AutoModelForSequenceClassification) as hf_model: + hf_outputs = hf_model.classify(example_prompts) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.classify(example_prompts) + + print(hf_outputs, vllm_outputs) + + # check logits difference + for hf_output, vllm_output in zip(hf_outputs, vllm_outputs): + hf_output = torch.tensor(hf_output) + vllm_output = torch.tensor(vllm_output) + + assert torch.allclose(hf_output, vllm_output, 1e-3) + + +@pytest.mark.parametrize("model", CLASSIFICATION_MODELS) +@pytest.mark.parametrize("dtype", ["float"]) +def test_classification_model_print( + vllm_runner, + model: str, + dtype: str, +) -> None: + with vllm_runner(model, dtype=dtype) as vllm_model: + # This test is for verifying whether the model's extra_repr + # can be printed correctly. + print(vllm_model.model.llm_engine.model_executor.driver_worker. + model_runner.model) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index 3455a4ccf282f..0a1df9cb699ae 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -28,11 +28,15 @@ class Pooler(nn.Module): normalize: Whether to normalize the pooled data. """ - def __init__(self, pooling_type: PoolingType, normalize: bool): + def __init__(self, + pooling_type: PoolingType, + normalize: bool, + softmax: bool = False): super().__init__() self.pooling_type = pooling_type self.normalize = normalize + self.softmax = softmax def forward( self, @@ -64,6 +68,9 @@ def forward( if self.normalize: pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if self.softmax: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) + pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data ] diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py new file mode 100644 index 0000000000000..e10c6dbbb6472 --- /dev/null +++ b/vllm/model_executor/models/qwen2_cls.py @@ -0,0 +1,107 @@ +# coding=utf-8 +# Adapted from +# https://huggingface.co/Qwen/Qwen2.5-Math-RM-72B/blob/main/modeling_qwen2_rm.py +# Copyright 2024 Kakao Corp. (Kanana-X Team) +# Copyright 2024 The Qwen team. +# Copyright 2023 The vLLM team. +"""Inference-only Qwen2-Classification model compatible with HF weights.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers import Qwen2Config + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.model_executor.layers.linear import RowParallelLinear +from vllm.model_executor.layers.pooler import Pooler, PoolingType +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.models.qwen2 import Qwen2Model +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput + +from .utils import AutoWeightsLoader + + +class Qwen2ForSequenceClassification(nn.Module): + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + "gate_up_proj": [ + "gate_proj", + "up_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "gate_up_proj", + "down_proj", + ] + embedding_modules = {} + embedding_padding_modules = [] + + def __init__( + self, + config: Qwen2Config, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + # TODO (@robertgshaw2): see if this can be moved out + if (cache_config.sliding_window is not None + and hasattr(config, "max_window_layers")): + raise ValueError("Sliding window for some but all layers is not " + "supported. This model uses sliding window " + "but `max_window_layers` = %s is less than " + "`num_hidden_layers` = %s. Please open an issue " + "to discuss this feature." % ( + config.max_window_layers, + config.num_hidden_layers, + )) + + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.quant_config = quant_config + self.model = Qwen2Model(config, cache_config, quant_config) + + self.score = RowParallelLinear(config.hidden_size, + config.num_labels, + quant_config=quant_config) + self._pooler = Pooler(pooling_type=PoolingType.LAST, + normalize=False, + softmax=True) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + logits, _ = self.score(hidden_states) + return logits + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + loader = AutoWeightsLoader(self, + ignore_unexpected_prefixes=["lm_head."]) + loader.load_weights(weights) diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 717615988a907..f6713ab0898f0 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -96,6 +96,8 @@ "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), "MistralModel": ("llama", "LlamaEmbeddingModel"), "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), + "Qwen2ForSequenceClassification": ( + "qwen2_cls", "Qwen2ForSequenceClassification"), # [Multimodal] "LlavaNextForConditionalGeneration": ("llava_next", "LlavaNextForConditionalGeneration"), # noqa: E501 "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"),