From ab612a6110cd4f639c65d5e427b92699a74d089f Mon Sep 17 00:00:00 2001 From: Jayon02 Date: Thu, 27 Feb 2025 20:31:17 +0800 Subject: [PATCH] feat/1001: sglang integration (#1122) --- pyproject.toml | 12 +- src/distilabel/embeddings.py | 3 +- src/distilabel/llms.py | 3 + src/distilabel/models/__init__.py | 5 + src/distilabel/models/embeddings/__init__.py | 2 + src/distilabel/models/embeddings/sglang.py | 125 ++++ src/distilabel/models/llms/__init__.py | 3 + src/distilabel/models/llms/sglang.py | 738 +++++++++++++++++++ tests/unit/models/embeddings/test_sglang.py | 50 ++ tests/unit/models/llms/test_sglang.py | 299 ++++++++ 10 files changed, 1238 insertions(+), 2 deletions(-) create mode 100644 src/distilabel/models/embeddings/sglang.py create mode 100644 src/distilabel/models/llms/sglang.py create mode 100644 tests/unit/models/embeddings/test_sglang.py create mode 100644 tests/unit/models/llms/test_sglang.py diff --git a/pyproject.toml b/pyproject.toml index 30bd06262d..ac4af7db15 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -79,7 +79,7 @@ argilla = ["argilla >= 2.0.0", "ipython"] cohere = ["cohere >= 5.2.0"] groq = ["groq >= 0.4.1"] hf-inference-endpoints = ["huggingface_hub >= 0.22.0"] -hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"] +hf-transformers = ["transformers == 4.48.3", "torch >= 2.0.0"] instructor = ["instructor >= 1.2.3"] litellm = ["litellm >= 1.30.0"] llama-cpp = ["llama-cpp-python >= 0.2.0"] @@ -107,6 +107,16 @@ vision = ["Pillow >= 10.3.0"] # To work with images. # minhash minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"] +sglang = ["sglang[all]>=0.4.3.post2", "transformers == 4.48.3"] + +[tool.hatch.envs.default] +dependencies = [ + "sglang[all]>=0.4.3.post2", + "transformers == 4.48.3", +] +installer = "pip" +pip-args = ["--find-links", "https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python"] + [project.urls] Documentation = "https://distilabel.argilla.io/" Issues = "https://github.com/argilla/distilabel/issues" diff --git a/src/distilabel/embeddings.py b/src/distilabel/embeddings.py index aa470e5b4d..6a41ae8d47 100644 --- a/src/distilabel/embeddings.py +++ b/src/distilabel/embeddings.py @@ -27,10 +27,12 @@ from distilabel.models.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) -from distilabel.models.embeddings.vllm import vLLMEmbeddings +from distilabel.models.embeddings.sglang import SGLangEmbeddings +from distilabel.models.embeddings.vllm import vLLMEmbeddings __all__ = [ "Embeddings", + "SGLangEmbeddings", "SentenceTransformerEmbeddings", "vLLMEmbeddings", ] diff --git a/src/distilabel/llms.py b/src/distilabel/llms.py index 730950a109..754889e4df 100644 --- a/src/distilabel/llms.py +++ b/src/distilabel/llms.py @@ -37,6 +37,7 @@ from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM +from distilabel.models.llms.sglang import ClientSGLang, SGLang from distilabel.models.llms.together import TogetherLLM from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM @@ -49,6 +50,7 @@ "AnyscaleLLM", "AsyncLLM", "AzureOpenAILLM", + "ClientSGLang", "ClientvLLM", "CohereLLM", "CudaDevicePlacementMixin", @@ -63,6 +65,7 @@ "MlxLLM", "OllamaLLM", "OpenAILLM", + "SGLang", "TogetherLLM", "TransformersLLM", "VertexAILLM", diff --git a/src/distilabel/models/__init__.py b/src/distilabel/models/__init__.py index 1c96f5ab0b..7247b3f098 100644 --- a/src/distilabel/models/__init__.py +++ b/src/distilabel/models/__init__.py @@ -18,6
+18,7 @@ from distilabel.models.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) +from distilabel.models.embeddings.sglang import SGLangEmbeddings from distilabel.models.embeddings.vllm import vLLMEmbeddings from distilabel.models.image_generation.base import ( AsyncImageGenerationModel, @@ -41,6 +42,7 @@ from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM +from distilabel.models.llms.sglang import ClientSGLang, SGLang from distilabel.models.llms.together import TogetherLLM from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM @@ -54,6 +56,7 @@ "AsyncImageGenerationModel", "AsyncLLM", "AzureOpenAILLM", + "ClientSGLang", "ClientvLLM", "CohereLLM", "CudaDevicePlacementMixin", @@ -73,6 +76,8 @@ "OllamaLLM", "OpenAIImageGeneration", "OpenAILLM", + "SGLang", + "SGLangEmbeddings", "SentenceTransformerEmbeddings", "TogetherLLM", "TransformersLLM", diff --git a/src/distilabel/models/embeddings/__init__.py b/src/distilabel/models/embeddings/__init__.py index 65eb00c469..6da7b8edc7 100644 --- a/src/distilabel/models/embeddings/__init__.py +++ b/src/distilabel/models/embeddings/__init__.py @@ -17,11 +17,13 @@ from distilabel.models.embeddings.sentence_transformers import ( SentenceTransformerEmbeddings, ) +from distilabel.models.embeddings.sglang import SGLangEmbeddings from distilabel.models.embeddings.vllm import vLLMEmbeddings __all__ = [ "Embeddings", "LlamaCppEmbeddings", + "SGLangEmbeddings", "SentenceTransformerEmbeddings", "vLLMEmbeddings", ] diff --git a/src/distilabel/models/embeddings/sglang.py b/src/distilabel/models/embeddings/sglang.py new file mode 100644 index 0000000000..d31e6c09da --- /dev/null +++ b/src/distilabel/models/embeddings/sglang.py @@ -0,0 +1,125 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union + +from pydantic import Field, PrivateAttr + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.embeddings.base import Embeddings +from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin + +if TYPE_CHECKING: + from sglang import Engine + + +class SGLangEmbeddings(Embeddings, CudaDevicePlacementMixin): + """`sglang` library implementation for embedding generation. + + Attributes: + model: the model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + dtype: the data type to use for the model. Defaults to `auto`. + trust_remote_code: whether to trust the remote code when loading the model. Defaults + to `False`. + quantization: the quantization mode to use for the model. Defaults to `None`. + revision: the revision of the model to load. Defaults to `None`. + seed: the seed to use for the random number generator. Defaults to `0`. 
+ extra_kwargs: additional dictionary of keyword arguments that will be passed to the + `Engine` class of `sglang` library. Defaults to `{}`. + _model: the `SGLang` model instance. This attribute is meant to be used internally + and should not be accessed directly. It will be set in the `load` method. + + References: + - https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py + + Examples: + Generating sentence embeddings: + + ```python + if __name__ == "__main__": + + from distilabel.models import SGLangEmbeddings + embeddings = SGLangEmbeddings(model="intfloat/e5-mistral-7b-instruct") + embeddings.load() + results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"]) + print(results) + # [ + # [0.0203704833984375, -0.0060882568359375, ...], + # [0.02398681640625, 0.0177001953125 ...], + # ] + ``` + """ + + model: str + dtype: str = "auto" + trust_remote_code: bool = False + quantization: Optional[str] = None + revision: Optional[str] = None + + seed: int = 0 + + extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field( + default_factory=dict, + description="Additional dictionary of keyword arguments that will be passed to the" + " `Engine` class of `sglang` library. See all the supported arguments at: " + "https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/engine.py", + ) + + _model: "Engine" = PrivateAttr(None) + + def load(self) -> None: + """Loads the `sglang` model using either the path or the Hugging Face Hub repository id.""" + super().load() + + CudaDevicePlacementMixin.load(self) + + try: + from sglang import Engine + except ImportError as err: + raise ImportError( + "sglang is not installed. Please install it with sglang document https://docs.sglang.ai/start/install.html." + ) from err + + self._model = Engine( + model_path=self.model, + dtype=self.dtype, + trust_remote_code=self.trust_remote_code, + quantization=self.quantization, + revision=self.revision, + random_seed=self.seed, + **self.extra_kwargs, # type: ignore + ) + + def unload(self) -> None: + """Unloads the `SGLang` model.""" + self._model = None + CudaDevicePlacementMixin.unload(self) + super().unload() + + @property + def model_name(self) -> str: + """Returns the name of the model.""" + return self.model + + def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]: + """Generates embeddings for the provided inputs. + + Args: + inputs: a list of texts for which an embedding has to be generated. + + Returns: + The generated embeddings. 
+ """ + return [output["embedding"] for output in self._model.encode(inputs)] diff --git a/src/distilabel/models/llms/__init__.py b/src/distilabel/models/llms/__init__.py index 3469c1e2bc..ac9a229681 100644 --- a/src/distilabel/models/llms/__init__.py +++ b/src/distilabel/models/llms/__init__.py @@ -26,6 +26,7 @@ from distilabel.models.llms.moa import MixtureOfAgentsLLM from distilabel.models.llms.ollama import OllamaLLM from distilabel.models.llms.openai import OpenAILLM +from distilabel.models.llms.sglang import ClientSGLang, SGLang from distilabel.models.llms.together import TogetherLLM from distilabel.models.llms.vertexai import VertexAILLM from distilabel.models.llms.vllm import ClientvLLM, vLLM @@ -38,6 +39,7 @@ "AnyscaleLLM", "AsyncLLM", "AzureOpenAILLM", + "ClientSGLang", "ClientvLLM", "CohereLLM", "CudaDevicePlacementMixin", @@ -52,6 +54,7 @@ "MlxLLM", "OllamaLLM", "OpenAILLM", + "SGLang", "TogetherLLM", "TransformersLLM", "VertexAILLM", diff --git a/src/distilabel/models/llms/sglang.py b/src/distilabel/models/llms/sglang.py new file mode 100644 index 0000000000..c48def713c --- /dev/null +++ b/src/distilabel/models/llms/sglang.py @@ -0,0 +1,738 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import contextlib +import gc +import inspect +import json +from functools import cached_property +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + List, + Literal, + Optional, + Tuple, + Union, +) + +from pydantic import Field, PrivateAttr, SecretStr, validate_call + +from distilabel.mixins.runtime_parameters import RuntimeParameter +from distilabel.models.llms.base import LLM +from distilabel.models.llms.openai import OpenAILLM +from distilabel.models.llms.utils import compute_tokens, prepare_output +from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin +from distilabel.models.mixins.magpie import MagpieChatTemplateMixin +from distilabel.steps.tasks.structured_outputs.utils import schema_as_dict +from distilabel.typing import ( + FormattedInput, + GenerateOutput, + Logprob, + OutlinesStructuredOutputType, +) + +if TYPE_CHECKING: + from openai import OpenAI # noqa + from transformers import PreTrainedTokenizer + from sglang import Engine + + from distilabel.typing import ( + StandardInput, + StructuredInput, + LLMStatistics, + LLMLogprobs, + LLMOutput, + ) + + +LogitsProcessorFn = Union[ + Callable[[List[int], Any], Any], + Callable[[List[int], List[int], Any], Any], +] + +LogitsProcessors = List[LogitsProcessorFn] + + +class SGLang(LLM, MagpieChatTemplateMixin, CudaDevicePlacementMixin): + """`SGLang` library LLM implementation. + + Attributes: + model: the model Hugging Face Hub repo id or a path to a directory containing the + model weights and configuration files. + dtype: the data type to use for the model. Defaults to `auto`. + trust_remote_code: whether to trust the remote code when loading the model. Defaults + to `False`. + quantization: the quantization mode to use for the model. 
Defaults to `None`. + revision: the revision of the model to load. Defaults to `None`. + tokenizer: the tokenizer Hugging Face Hub repo id or a path to a directory containing + the tokenizer files. Defaults to `model`. + tokenizer_mode: the mode to use for the tokenizer. Defaults to `auto`. + skip_tokenizer_init: whether to skip the initialization of the tokenizer. Defaults + to `False`. + chat_template: a chat template that will be used to build the prompts before + sending them to the model. If not provided, the chat template defined in the + tokenizer config will be used. If not provided and the tokenizer doesn't have + a chat template, then ChatML template will be used. Defaults to `None`. + structured_output: a dictionary containing the structured output configuration or if more + fine-grained control is needed, an instance of `OutlinesStructuredOutput`. Defaults to None. + seed: the seed to use for the random number generator. Defaults to `0`. + extra_kwargs: additional dictionary of keyword arguments that will be passed to the + `Engine` class of `sglang` library. Defaults to `{}`. + _model: the `SGLang` model instance. This attribute is meant to be used internally + and should not be accessed directly. It will be set in the `load` method. + _tokenizer: the tokenizer instance used to format the prompt before passing it to + the `LLM`. It will be set in the `load` method. + use_magpie_template: a flag used to enable/disable applying the Magpie pre-query + template. Defaults to `False`. + magpie_pre_query_template: the pre-query template to be applied to the prompt or + sent to the LLM to generate an instruction or a follow up user message. Valid + values are "llama3", "qwen2" or another pre-query template provided. Defaults + to `None`. + + References: + - https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py + + Runtime parameters: + - `extra_kwargs`: additional dictionary of keyword arguments that will be passed to + the `LLM` class of `SGLang` library. + + Examples: + Generate text: + + ```python + from distilabel.models.llms import SGLang + if __name__ == "__main__": + llm = SGLang( + model="Qwen/Qwen2.5-Coder-3B-Instruct", + chat_template="[INST] {{ messages[0]['content']}} [/INST]" + ) + + llm.load() + # Call the model + output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Hello world!"}]]) + ``` + + Generate structured data: + + ```python + from distilabel.models.llms import SGLang + from pydantic import BaseModel + + if __name__ == "__main__": + + class User(BaseModel): + name: str + last_name: str + id: int + + llm = SGLang( + model="Qwen/Qwen2.5-Coder-3B-Instruct", + structured_output={"format": "json", "schema": User}, + ) + + llm.load() + # Call the model + output = llm.generate_outputs(inputs=[[{"role": "user", "content": "Create a user profile for the following marathon"}]]) + ``` + """ + + model: str + dtype: str = "auto" + trust_remote_code: bool = False + quantization: Optional[str] = None + revision: Optional[str] = None + + tokenizer: Optional[str] = None + tokenizer_mode: Literal["auto", "slow"] = "auto" + skip_tokenizer_init: bool = False + chat_template: Optional[str] = None + + seed: int = 0 + + extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field( + default_factory=dict, + description="Additional dictionary of keyword arguments that will be passed to the" + " `Engine` class of `sglang` library. 
See all the supported arguments at: " + "https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/engine.py", + ) + structured_output: Optional[RuntimeParameter[OutlinesStructuredOutputType]] = Field( + default=None, + description="The structured output format to use across all the generations.", + ) + + _model: "Engine" = PrivateAttr(None) + _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) + + def load(self) -> None: + """Loads the `SGLang` model using either the path or the Hugging Face Hub repository id. + Additionally, this method also sets the `chat_template` for the tokenizer, so as to properly + parse the list of OpenAI formatted inputs using the expected format by the model, otherwise, the + default value is ChatML format, unless explicitly provided. + """ + super().load() + + CudaDevicePlacementMixin.load(self) + + try: + from sglang import Engine + except ImportError as err: + raise ImportError( + "sglang is not installed. Please install it with sglang document https://docs.sglang.ai/start/install.html." + ) from err + + self._model = Engine( + model_path=self.model, + dtype=self.dtype, + trust_remote_code=self.trust_remote_code, + quantization=self.quantization, + revision=self.revision, + tokenizer_path=self.tokenizer, + tokenizer_mode=self.tokenizer_mode, + skip_tokenizer_init=self.skip_tokenizer_init, + random_seed=self.seed, + **self.extra_kwargs, # type: ignore + ) + from sglang.srt.hf_transformers_utils import get_tokenizer + + self._tokenizer = get_tokenizer( + self.model, + tokenizer_mode=self.tokenizer_mode, + trust_remote_code=self.trust_remote_code, + tokenizer_revision="main", + ) # type: ignore + if self.chat_template is not None: + self._tokenizer.chat_template = self.chat_template # type: ignore + + def unload(self) -> None: + """Unloads the `SGLang` model.""" + self._cleanup_sglang_model() + self._model = None # type: ignore + self._tokenizer = None # type: ignore + CudaDevicePlacementMixin.unload(self) + super().unload() + + def _cleanup_sglang_model(self) -> None: + if self._model is None: + return + + import torch # noqa + + self._model.shutdown() + with contextlib.suppress(AssertionError): + torch.distributed.destroy_process_group() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + + @property + def model_name(self) -> str: + """Returns the model name used for the LLM.""" + return self.model + + def prepare_input(self, input: Union["StandardInput", str]) -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + if isinstance(input, str): + return input + + prompt: str = ( + self._tokenizer.apply_chat_template( + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + def _prepare_batches( + self, inputs: List["StructuredInput"] + ) -> Tuple[List[Tuple[List[str], "OutlinesStructuredOutputType"]], List[int]]: + """Prepares the inputs by grouping them by the structured output. + + When we generate structured outputs with schemas obtained from a dataset, we need to + prepare the data to try to send batches of inputs instead of single inputs to the model + to take advante of the engine. So we group the inputs by the structured output to be + passed in the `generate` method. 
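+ For example, two inputs sharing the same JSON schema end up in the same sub-batch, while a regex-constrained input goes into a separate one; the indices returned alongside the batches are used later to restore the original input order.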
+ + Args: + inputs: The batch of inputs passed to the generate method. As we expect to be generating + structured outputs, each element will be a tuple containing the instruction and the + structured output. + + Returns: + The prepared batches (sub-batches) to be passed to the `generate` method. + Each new tuple will contain, instead of a single instruction, a list of instructions. + """ + instruction_order = {} + batches: Dict[str, List[str]] = {} + for i, (instruction, structured_output) in enumerate(inputs): + instruction = self.prepare_input(instruction) + instruction_order[instruction] = i + + structured_output = json.dumps(structured_output) + if structured_output not in batches: + batches[structured_output] = [instruction] + else: + batches[structured_output].append(instruction) + + # Build a list with instructions sorted by structured output + flat_instructions = [ + instruction for _, group in batches.items() for instruction in group + ] + + # Generate the list of indices based on the original order + sorted_indices = [ + instruction_order[instruction] for instruction in flat_instructions + ] + + return [ + (batch, json.loads(schema)) for schema, batch in batches.items() + ], sorted_indices + + @validate_call + def generate( # noqa: C901 # type: ignore + self, + inputs: List[FormattedInput], + num_generations: int = 1, + max_new_tokens: int = 128, + stop: Optional[Union[str, List[str]]] = None, + stop_token_ids: Optional[List[int]] = None, + temperature: float = 1.0, + top_p: float = 1.0, + top_k: int = -1, + min_p: float = 0.0, + frequency_penalty: float = 0.0, + presence_penalty: float = 0.0, + repetition_penalty: float = 1.0, + min_new_tokens: int = 0, + spaces_between_special_tokens: bool = True, + json_schema: Optional[str] = None, + regex: Optional[str] = None, + no_stop_trim: bool = False, + ignore_eos: bool = False, + skip_special_tokens: bool = True, + custom_params: Optional[Dict[str, Any]] = None, + return_logprob: bool = False, + top_logprobs_num: int = 0, + echo: bool = False, + ) -> List[GenerateOutput]: + """Generates `num_generations` responses for each input. + + Args: + inputs: a list of inputs in chat format to generate responses for. + num_generations: the number of generations to create per input. Defaults to + `1`. + max_new_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + stop: a list of strings that will be used to stop the generation when found. + Defaults to `None`. + stop_token_ids: a list of token ids that will be used to stop the generation + when found. Defaults to `None`. + temperature: the temperature to use for the generation. Defaults to `1.0`. + top_p: the top-p value to use for the generation. Defaults to `1.0`. + top_k: the top-k value to use for the generation. Defaults to `-1`. + min_p: the minimum probability to use for the generation. Defaults to `0.0`. + frequency_penalty: the frequency penalty to use for the generation. Defaults + to `0.0`. + presence_penalty: the presence penalty to use for the generation. Defaults to + `0.0`. + repetition_penalty: the repetition penalty to use for the generation. Defaults to + `1.0`. + min_new_tokens: Forces the model to generate at least `min_new_tokens` until + a stop word or EOS token is sampled. Defaults to `0`. + spaces_between_special_tokens: Whether or not to add spaces between special + tokens during detokenization. Defaults to `True`. + json_schema: a JSON schema used to constrain the generated output. Defaults to `None`. + regex: a regular expression used to constrain the generated output. Defaults to `None`. + no_stop_trim: Don't trim stop words or EOS token from the generated text. + Defaults to `False`. + ignore_eos: Don't stop generation when EOS token is sampled. Defaults to `False`. + skip_special_tokens: Whether to exclude special tokens from the output. Defaults + to `True`. + custom_params: Used when employing `CustomLogitProcessor`. Defaults to `None`. + return_logprob: Whether to return log probabilities for tokens. Defaults to `False`. + top_logprobs_num: If returning log probabilities, specifies the number of + top logprobs to return at each position. Defaults to `0`. + echo: Whether to include the prompt in the response or not. Defaults + to `False`. + + Returns: + A list of lists of strings containing the generated responses for each input. + """ + + if isinstance(inputs[0], tuple): + # Prepare the batches for structured generation + prepared_batches, sorted_indices = self._prepare_batches(inputs) # type: ignore + else: + # Simulate a batch without the structured output content + prepared_batches = [([self.prepare_input(input) for input in inputs], None)] # type: ignore + sorted_indices = None + + batched_outputs: List["LLMOutput"] = [] + generations = [] + + for prepared_inputs, structured_output in prepared_batches: + if self.structured_output is not None and structured_output is not None: + self._logger.warning( + "A `structured_output` was provided in the model configuration, but" + " one was also provided in the input. The input structured output will" + " be used." + ) + + temp_structure = None + if structured_output is not None: + temp_structure = structured_output + elif self.structured_output is not None: + temp_structure = self.structured_output + + if temp_structure is not None: + format = temp_structure.get("format") + schema = temp_structure.get("schema") + if not format: + if isinstance(schema, dict) or inspect.isclass(schema): + format = "json" + elif isinstance(schema, str): + format = "regex" + + if format == "json": + json_schema = json.dumps(schema_as_dict(schema)) + elif format == "regex": + regex = schema + + sampling_params = { + "max_new_tokens": max_new_tokens, + "stop": stop, + "stop_token_ids": stop_token_ids, + "temperature": temperature, + "top_p": top_p, + "top_k": top_k, + "min_p": min_p, + "frequency_penalty": frequency_penalty, + "presence_penalty": presence_penalty, + "repetition_penalty": repetition_penalty, + "min_new_tokens": min_new_tokens, + "spaces_between_special_tokens": spaces_between_special_tokens, + "n": num_generations, + "json_schema": json_schema, + "regex": regex, + "no_stop_trim": no_stop_trim, + "ignore_eos": ignore_eos, + "skip_special_tokens": skip_special_tokens, + "custom_params": custom_params, + } +
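+            # the engine returns a flat list with `num_generations` outputs per prompt; the loop below regroups them into per-prompt slices before building the texts, statistics and logprobs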
+            batch_outputs = self._model.generate( + prompt=prepared_inputs, + sampling_params=sampling_params, + return_logprob=return_logprob, + logprob_start_len=0 if echo else -1, + top_logprobs_num=top_logprobs_num if return_logprob else 0, + ) + + group = int(len(batch_outputs) / len(prepared_inputs)) + for num in range(len(prepared_inputs)): + input = prepared_inputs[num] + outputs = batch_outputs[num * group : (num + 1) * group] + processed_prompt_logprobs = [] + meta_info = outputs[0]["meta_info"] + if "input_top_logprobs" in meta_info: + processed_prompt_logprobs = self._get_llm_logprobs( + top_logprob=meta_info["input_top_logprobs"], + choose_logprob=meta_info["input_token_logprobs"], + ) + texts, statistics, outputs_logprobs = self._process_outputs( + input=input, + outputs=outputs, + echo=echo, + prompt_logprobs=processed_prompt_logprobs, + ) + batched_outputs.append(texts) + generation = prepare_output( + generations=texts, + input_tokens=statistics["input_tokens"], + output_tokens=statistics["output_tokens"], + logprobs=outputs_logprobs, + ) + + generations.append(generation) + + if sorted_indices is not None: + pairs = list(enumerate(sorted_indices)) + pairs.sort(key=lambda x: x[1]) + generations = [generations[original_idx] for original_idx, _ in pairs] + + return generations + + def _process_outputs( + self, + input: str, + outputs, + prompt_logprobs: List[List["Logprob"]], + echo: bool = False, + ) -> Tuple["LLMOutput", "LLMStatistics", "LLMLogprobs"]: + texts = [] + outputs_logprobs = [] + lens = 1 + if isinstance(outputs, list): + lens = len(outputs) + statistics = { + "input_tokens": [compute_tokens(input, self._tokenizer.encode)] * lens, + "output_tokens": [], + } + + for output in outputs: + text = output["text"] + if echo: + text = input + text + texts.append(text) + statistics["output_tokens"].append(output["meta_info"]["completion_tokens"]) + if "output_top_logprobs" in output["meta_info"]: + processed_output_logprobs = self._get_llm_logprobs( + output["meta_info"]["output_top_logprobs"] + ) + outputs_logprobs.append(prompt_logprobs + processed_output_logprobs) + + return texts, statistics, outputs_logprobs + + def _get_llm_logprobs( + self, + top_logprob, + choose_logprob=None, + ) -> List[List["Logprob"]]: + processed_logprobs = [] + if choose_logprob is not None: + token_logprobs = [] + for num in range(len(choose_logprob)): + if choose_logprob[num][0] is None: + processed_logprobs.append(None) + continue + else: + token_logprobs.append( + { + "token": self._tokenizer.decode(choose_logprob[num][1]), + "logprob": choose_logprob[num][0], + } + ) + for top_num in range(len(top_logprob[num]) - 1): + token_logprobs.append( + { + "token": self._tokenizer.decode( + top_logprob[num][top_num][1] + ), + "logprob": top_logprob[num][top_num][0], + } + ) + processed_logprobs.append(token_logprobs) + else: + for probs in top_logprob: + token_logprobs = [] + for item in probs: + token_logprobs.append( + {"token": self._tokenizer.decode(item[1]), "logprob": item[0]} + ) + processed_logprobs.append(token_logprobs) + return processed_logprobs + + +class ClientSGLang(OpenAILLM, MagpieChatTemplateMixin): + """A client for the `SGLang` server implementing the OpenAI API specification. + + Attributes: + base_url: the base URL of the `SGLang` server. Defaults to `"http://localhost:30000"`. + max_retries: the maximum number of times to retry the request to the API before + failing. Defaults to `6`. + timeout: the maximum time in seconds to wait for a response from the API. Defaults + to `120`. + httpx_client_kwargs: extra kwargs that will be passed to the `httpx.AsyncClient` + created to communicate with the `SGLang` server. Defaults to `None`. + tokenizer: the Hugging Face Hub repo id or path of the tokenizer that will be used + to apply the chat template and tokenize the inputs before sending it to the + server. Defaults to `None`. + tokenizer_revision: the revision of the tokenizer to load. Defaults to `None`. + _aclient: the `AsyncOpenAI` client used to communicate with the `SGLang` server. Defaults + to `None`. + + Runtime parameters: + - `base_url`: the base URL of the `SGLang` server. Defaults to `"http://localhost:30000"`. + - `max_retries`: the maximum number of times to retry the request to the API before + failing. Defaults to `6`. + - `timeout`: the maximum time in seconds to wait for a response from the API. Defaults + to `120`. + - `httpx_client_kwargs`: extra kwargs that will be passed to the `httpx.AsyncClient` + created to communicate with the `SGLang` server. Defaults to `None`. + + Examples: + Generate text: + + ```python + from distilabel.models.llms import ClientSGLang + + llm = ClientSGLang( + base_url="http://localhost:30000/v1", + tokenizer="Qwen/Qwen2-7B-Instruct" + ) + + llm.load() + + results = llm.generate_outputs( + inputs=[[{"role": "user", "content": "Hello, how are you?"}]], + temperature=0.7, + top_p=1.0, + max_new_tokens=256, + ) + ``` + """ + + model: str = "" # Default value so `ClientSGLang(model="...")` is not required + tokenizer: Optional[str] = None + tokenizer_revision: Optional[str] = None + + # We need the sync client to get the list of models + _client: "OpenAI" = PrivateAttr(None) + _tokenizer: "PreTrainedTokenizer" = PrivateAttr(None) + + def load(self) -> None: + """Creates the OpenAI clients used to connect to the SGLang server and, optionally, + a tokenizer.""" + + self.api_key = SecretStr("EMPTY") + + # We need to first create the sync client to get the model name that will be used + # in the `super().load()` when creating the logger. + try: + from openai import OpenAI + except ImportError as ie: + raise ImportError( + "OpenAI Python client is not installed. Please install it using" + " `pip install 'distilabel[openai]'`." + ) from ie + + self._client = OpenAI( + base_url=self.base_url, + api_key=self.api_key.get_secret_value(), # type: ignore + max_retries=self.max_retries, # type: ignore + timeout=self.timeout, + ) + + super().load() + + try: + from transformers import AutoTokenizer + except ImportError as ie: + raise ImportError( + "To use `ClientSGLang` you need to install `transformers`." + " Please install it using `pip install 'distilabel[hf-transformers]'`." + ) from ie + + self._tokenizer = AutoTokenizer.from_pretrained( + self.tokenizer, revision=self.tokenizer_revision + ) + + @cached_property + def model_name(self) -> str: # type: ignore + """Returns the name of the model served by the SGLang server.""" + models = self._client.models.list() + return models.data[0].id + + def _prepare_input(self, input: "StandardInput") -> str: + """Prepares the input (applying the chat template and tokenization) for the provided + input. + + Args: + input: the input list containing chat items. + + Returns: + The prompt to send to the LLM. + """ + prompt: str = ( + self._tokenizer.apply_chat_template( # type: ignore + input, # type: ignore + tokenize=False, + add_generation_prompt=True, # type: ignore + ) + if input + else "" + ) + return super().apply_magpie_pre_query_template(prompt, input) + + @validate_call + async def agenerate( # type: ignore + self, + input: FormattedInput, + num_generations: int = 1, + max_new_tokens: int = 128, + frequency_penalty: float = 0.0, + logit_bias: Optional[Dict[str, int]] = None, + presence_penalty: float = 0.0, + temperature: float = 1.0, + top_p: float = 1.0, + ) -> GenerateOutput: + """Generates `num_generations` responses for each input. + + Args: + input: a single input in chat format to generate responses for. + num_generations: the number of generations to create per input. Defaults to + `1`. + max_new_tokens: the maximum number of new tokens that the model will generate. + Defaults to `128`. + frequency_penalty: the frequency penalty to use for the generation. Defaults + to `0.0`. + logit_bias: modify the likelihood of specified tokens appearing in the completion. + Defaults to `None`. + presence_penalty: the presence penalty to use for the generation. Defaults to + `0.0`. + temperature: the temperature to use for the generation. Defaults to `1.0`. + top_p: nucleus sampling. The value refers to the top-p tokens that should be + considered for sampling. Defaults to `1.0`. + + Returns: + A `GenerateOutput` containing the generated responses for the input together + with their statistics. + """ +
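+        # the SGLang server exposes an OpenAI-compatible completions endpoint, so the request goes through the async OpenAI client set up by the parent `OpenAILLM` class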
+        completion = await self._aclient.completions.create( + model=self.model_name, + prompt=self._prepare_input(input), # type: ignore + n=num_generations, + max_tokens=max_new_tokens, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + presence_penalty=presence_penalty, + temperature=temperature, + top_p=top_p, + ) + + generations = [] + for choice in completion.choices: + text = choice.text + if text == "": + self._logger.warning( # type: ignore + f"Received no response from SGLang server (model: '{self.model_name}')." + f" Finish reason was: {choice.finish_reason}" + ) + generations.append(text) + + return prepare_output(generations, **self._get_llm_statistics(completion)) diff --git a/tests/unit/models/embeddings/test_sglang.py b/tests/unit/models/embeddings/test_sglang.py new file mode 100644 index 0000000000..6dd4162e59 --- /dev/null +++ b/tests/unit/models/embeddings/test_sglang.py @@ -0,0 +1,50 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import MagicMock, Mock + +from distilabel.models.embeddings.sglang import SGLangEmbeddings + + +class TestSGLangEmbeddings: + model_name = "group/model-name" + + def test_model_name(self) -> None: + embeddings = SGLangEmbeddings(model=self.model_name) + + assert embeddings.model_name == self.model_name + + def test_encode(self) -> None: + embeddings = SGLangEmbeddings(model=self.model_name) + + # loading would normally happen here; the model is mocked instead + # embeddings.load() + embeddings._model = MagicMock() + + mocked_response = {"embedding": [0.1] * 10} + + embeddings._model.encode = Mock( + side_effect=lambda x: [mocked_response for _ in range(len(x))] + ) + + results = embeddings.encode( + inputs=[ + "Hello, how are you?", + "What a nice day!", + "I hear that llamas are very popular now.", + ] + ) + + for result in results: + assert len(result) == 10 diff --git a/tests/unit/models/llms/test_sglang.py b/tests/unit/models/llms/test_sglang.py new file mode 100644 index 0000000000..3c45b96e8f --- /dev/null +++ b/tests/unit/models/llms/test_sglang.py @@ -0,0 +1,299 @@ +# Copyright 2023-present, Argilla, Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Any, Dict, List +from unittest import mock + +import pytest +from openai.pagination import SyncPage +from openai.types import Model +from openai.types.completion import Completion +from openai.types.completion_choice import CompletionChoice +from openai.types.completion_usage import CompletionUsage +from pydantic import BaseModel +from transformers import AutoTokenizer + +from distilabel.models.llms import SGLang +from distilabel.models.llms.sglang import ClientSGLang + + +class Character(BaseModel): + name: str + description: str + role: str + weapon: str + + +class Animal(BaseModel): + name: str + species: str + habitat: str + diet: str + + +SAMPLE_DATA = [ + [ + { + "instruction": [ + {"role": "user", "content": "Generate a character from a RPG game."} + ], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [ + { + "role": "user", + "content": "Generate an animal from a zoo.", + } + ], + "structured_output": { + "format": "json", + "schema": Animal.model_json_schema(), + }, + }, + { + "instruction": [{"role": "user", "content": "Repeated character"}], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [ + { + "role": "user", + "content": "What's the weather like today in Seattle in Celsius degrees?", + } + ], + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + { + "instruction": [{"role": "user", "content": "Other character"}], + "structured_output": { + "format": "json", + "schema": Character.model_json_schema(), + }, + }, + { + "instruction": [{"role": "user", "content": "repeated regex"}], + "structured_output": { + "format": "regex", + "schema": "(\\d{1,2})°C", + }, + }, + ] +] + + +class TestSGLang: + @pytest.mark.parametrize( + "multi_structured_output", + # TODO: uncomment once with update our code to work with `outlines>0.1.0` + (True, False), + # (False,), + ) + @pytest.mark.parametrize( + "num_generations, expected_result", + [ + ( + 1, + [ + { + "generations": ["I'm fine thank you"], + "statistics": {"input_tokens": [21], "output_tokens": [6]}, + "logprobs": [ + [ + [ + {"token": "thank", "logprob": -1}, + {"token": "you", "logprob": -3}, + ], + [ + {"token": "thank", "logprob": -1}, + {"token": "you", "logprob": -3}, + ], + ] + ], + } + ], + ), + ( + 2, + [ + { + "generations": ["I'm fine thank you"] * 2, + "statistics": { + "input_tokens": [21, 21], + "output_tokens": [6, 6], + }, + "logprobs": [ + [ + [ + {"token": "thank", "logprob": -1}, + {"token": "you", "logprob": -3}, + ], + [ + {"token": "thank", "logprob": -1}, + {"token": "you", "logprob": -3}, + ], + ] + ] + * 2, + } + ], + ), + ], + ) + def test_generate( + self, + multi_structured_output: bool, + num_generations: int, + expected_result: List[Dict[str, Any]], + ) -> None: + llm = SGLang(model="dummy") + tokenizer = AutoTokenizer.from_pretrained( + "distilabel-internal-testing/tiny-random-mistral" + ) + llm._tokenizer = tokenizer + sglang_mock = mock.MagicMock() + sglang_mock.get_tokenizer = 
mock.MagicMock(return_value=tokenizer) + # mock the import by hacking sys.modules + # https://stackoverflow.com/questions/60919705/how-to-mock-in-a-python-unittest-a-library-not-installed-locally + import sys + + if "sglang" not in sys.modules: + sys.modules["sglang"] = sglang_mock + llm._model = sglang_mock + + mocked_requests_output = [ + { + "text": "I'm fine thank you", + "meta_info": { + "completion_tokens": 6, + "output_top_logprobs": [ + [(-1, 6979, "thank"), (-3, 368, "you")], + [(-1, 6979, "thank"), (-3, 368, "you")], + ], + }, + } + for _ in range(num_generations) + ] + + llm._model.generate = mock.MagicMock(return_value=mocked_requests_output) + if not multi_structured_output: + formatted_inputs = [ + [ + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ] + ] + else: + formatted_inputs = [ + ( + [ + {"role": "system", "content": "sysprompt"}, + { + "role": "user", + "content": "I'm fine thank you", + }, + ], + { + # "format": "json", + "format": "regex", + "schema": r".*", + # "schema": Character.model_json_schema(), + }, + ) + ] + result = llm.generate(inputs=formatted_inputs, num_generations=num_generations) + assert result == expected_result + + +@mock.patch("openai.OpenAI") +@mock.patch("openai.AsyncOpenAI") +class TestClientSGLang: + def test_clientsglang_model_name( + self, _: mock.MagicMock, openai_mock: mock.MagicMock + ) -> None: + llm = ClientSGLang( + base_url="http://localhost:8000/v1", + tokenizer="google-bert/bert-base-uncased", + ) + + llm._client = mock.MagicMock() + llm._client.models.list.return_value = SyncPage[Model]( # type: ignore + data=[Model(id="llama", created=1234, object="model", owned_by="")], + object="model", + ) + + assert llm.model_name == "llama" + + @pytest.mark.asyncio + async def test_agenerate( + self, _openai_mock: mock.MagicMock, _async_openai_mock: mock.MagicMock + ) -> None: + llm = ClientSGLang( + base_url="http://localhost:8000/v1", + tokenizer="distilabel-internal-testing/tiny-random-mistral", + ) + + llm.load() + + llm._aclient.completions.create = mock.AsyncMock( + return_value=Completion( + id="1234", + created=1234, + model="llama", + object="text_completion", + choices=[ + CompletionChoice( + finish_reason="stop", + index=0, + logprobs=None, + text="I'm fine thank you", + ), + CompletionChoice( + finish_reason="stop", + index=0, + logprobs=None, + text="I'm fine thank you sir", + ), + ], + usage=CompletionUsage( + completion_tokens=10, + prompt_tokens=10, + total_tokens=20, + ), + ) + ) + + generations = await llm.agenerate( + input=[{"role": "user", "content": "Hi, how are you?"}] + ) + + assert generations == { + "generations": ["I'm fine thank you", "I'm fine thank you sir"], + "statistics": { + "input_tokens": [10], + "output_tokens": [10], + }, + }
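Below is a minimal usage sketch, not part of the patch, showing how the classes added here are meant to be combined. The model ids and the server URL are illustrative placeholders, and `ClientSGLang` assumes an SGLang server is already running at that address.

```python
from distilabel.models import ClientSGLang, SGLang, SGLangEmbeddings

if __name__ == "__main__":
    # In-process engine: two generations per prompt, with token logprobs.
    llm = SGLang(model="Qwen/Qwen2.5-Coder-3B-Instruct")  # placeholder model id
    llm.load()
    outputs = llm.generate_outputs(
        inputs=[[{"role": "user", "content": "Hello!"}]],
        num_generations=2,
        return_logprob=True,
        top_logprobs_num=2,
    )
    llm.unload()

    # Client mode: talk to an already running SGLang server (OpenAI-compatible API).
    client = ClientSGLang(
        base_url="http://localhost:30000/v1",  # placeholder URL
        tokenizer="Qwen/Qwen2.5-Coder-3B-Instruct",
    )
    client.load()
    outputs = client.generate_outputs(
        inputs=[[{"role": "user", "content": "Hello!"}]], max_new_tokens=64
    )

    # Embeddings backed by the same engine wrapper.
    embeddings = SGLangEmbeddings(model="intfloat/e5-mistral-7b-instruct")
    embeddings.load()
    vectors = embeddings.encode(inputs=["distilabel is awesome!"])
    embeddings.unload()
```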