Skip to content

Commit

Permalink
feat/1001: sglang integration (#1122)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jayon02 authored Feb 27, 2025
1 parent a2a91a8 commit ab612a6
Show file tree
Hide file tree
Showing 10 changed files with 1,238 additions and 2 deletions.
12 changes: 11 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ argilla = ["argilla >= 2.0.0", "ipython"]
cohere = ["cohere >= 5.2.0"]
groq = ["groq >= 0.4.1"]
hf-inference-endpoints = ["huggingface_hub >= 0.22.0"]
hf-transformers = ["transformers >= 4.34.1", "torch >= 2.0.0"]
hf-transformers = ["transformers == 4.48.3", "torch >= 2.0.0"]
instructor = ["instructor >= 1.2.3"]
litellm = ["litellm >= 1.30.0"]
llama-cpp = ["llama-cpp-python >= 0.2.0"]
Expand Down Expand Up @@ -107,6 +107,16 @@ vision = ["Pillow >= 10.3.0"] # To work with images.
# minhash
minhash = ["datasketch >= 1.6.5", "nltk>3.8.1"]

sglang = ["sglang[all]>=0.4.3.post2", "transformers == 4.48.3"]

[tool.hatch.envs.default]
dependencies = [
"sglang[all]>=0.4.3.post2",
"transformers == 4.48.3",
]
installer = "pip"
pip-args = ["--find-links", "https://flashinfer.ai/whl/cu124/torch2.5/flashinfer-python"]

[project.urls]
Documentation = "https://distilabel.argilla.io/"
Issues = "https://github.com/argilla/distilabel/issues"
Expand Down
3 changes: 2 additions & 1 deletion src/distilabel/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,11 @@
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.vllm import vLLMEmbeddings
# `SGLangEmbeddings` is defined in the `sglang` embeddings module, not in `vllm`;
# importing it from `vllm` would raise ImportError on `import distilabel.embeddings`.
from distilabel.models.embeddings.sglang import SGLangEmbeddings
from distilabel.models.embeddings.vllm import vLLMEmbeddings

__all__ = [
"Embeddings",
"SGLangEmbeddings",
"SentenceTransformerEmbeddings",
"vLLMEmbeddings",
]
3 changes: 3 additions & 0 deletions src/distilabel/llms.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.sglang import ClientSGLang, SGLang
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
Expand All @@ -49,6 +50,7 @@
"AnyscaleLLM",
"AsyncLLM",
"AzureOpenAILLM",
"ClientSGLang",
"ClientvLLM",
"CohereLLM",
"CudaDevicePlacementMixin",
Expand All @@ -63,6 +65,7 @@
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"SGLang",
"TogetherLLM",
"TransformersLLM",
"VertexAILLM",
Expand Down
5 changes: 5 additions & 0 deletions src/distilabel/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.sglang import SGLangEmbeddings
from distilabel.models.embeddings.vllm import vLLMEmbeddings
from distilabel.models.image_generation.base import (
AsyncImageGenerationModel,
Expand All @@ -41,6 +42,7 @@
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.sglang import ClientSGLang, SGLang
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
Expand All @@ -54,6 +56,7 @@
"AsyncImageGenerationModel",
"AsyncLLM",
"AzureOpenAILLM",
"ClientSGLang",
"ClientvLLM",
"CohereLLM",
"CudaDevicePlacementMixin",
Expand All @@ -73,6 +76,8 @@
"OllamaLLM",
"OpenAIImageGeneration",
"OpenAILLM",
"SGLang",
"SGLangEmbeddings",
"SentenceTransformerEmbeddings",
"TogetherLLM",
"TransformersLLM",
Expand Down
2 changes: 2 additions & 0 deletions src/distilabel/models/embeddings/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,13 @@
from distilabel.models.embeddings.sentence_transformers import (
SentenceTransformerEmbeddings,
)
from distilabel.models.embeddings.sglang import SGLangEmbeddings
from distilabel.models.embeddings.vllm import vLLMEmbeddings

__all__ = [
"Embeddings",
"LlamaCppEmbeddings",
"SGLangEmbeddings",
"SentenceTransformerEmbeddings",
"vLLMEmbeddings",
]
125 changes: 125 additions & 0 deletions src/distilabel/models/embeddings/sglang.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
# Copyright 2023-present, Argilla, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from pydantic import Field, PrivateAttr

from distilabel.mixins.runtime_parameters import RuntimeParameter
from distilabel.models.embeddings.base import Embeddings
from distilabel.models.mixins.cuda_device_placement import CudaDevicePlacementMixin

if TYPE_CHECKING:
from sglang import Engine


class SGLangEmbeddings(Embeddings, CudaDevicePlacementMixin):
    """`sglang` library implementation for embedding generation.

    Attributes:
        model: the model Hugging Face Hub repo id or a path to a directory containing the
            model weights and configuration files.
        dtype: the data type to use for the model. Defaults to `auto`.
        trust_remote_code: whether to trust the remote code when loading the model. Defaults
            to `False`.
        quantization: the quantization mode to use for the model. Defaults to `None`.
        revision: the revision of the model to load. Defaults to `None`.
        seed: the seed to use for the random number generator. Defaults to `0`.
        extra_kwargs: additional dictionary of keyword arguments that will be passed to the
            `Engine` class of `sglang` library. Defaults to `{}`.
        _model: the `SGLang` model instance. This attribute is meant to be used internally
            and should not be accessed directly. It will be set in the `load` method.

    References:
        - https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/server_args.py

    Examples:
        Generating sentence embeddings:

        ```python
        if __name__ == "__main__":
            from distilabel.models import SGLangEmbeddings

            embeddings = SGLangEmbeddings(model="intfloat/e5-mistral-7b-instruct")
            embeddings.load()

            results = embeddings.encode(inputs=["distilabel is awesome!", "and Argilla!"])
            print(results)
            # [
            #   [0.0203704833984375, -0.0060882568359375, ...],
            #   [0.02398681640625, 0.0177001953125 ...],
            # ]
        ```
    """

    model: str
    dtype: str = "auto"
    trust_remote_code: bool = False
    quantization: Optional[str] = None
    revision: Optional[str] = None

    seed: int = 0

    extra_kwargs: Optional[RuntimeParameter[Dict[str, Any]]] = Field(
        default_factory=dict,
        description="Additional dictionary of keyword arguments that will be passed to the"
        " `Engine` class of `sglang` library. See all the supported arguments at: "
        "https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/entrypoints/engine.py",
    )

    _model: "Engine" = PrivateAttr(None)

    def load(self) -> None:
        """Loads the `sglang` model using either the path or the Hugging Face Hub repository id.

        Raises:
            ImportError: if the `sglang` package is not installed.
        """
        super().load()

        CudaDevicePlacementMixin.load(self)

        # Imported lazily so that `sglang` is only required when this class is used.
        try:
            from sglang import Engine
        except ImportError as err:
            raise ImportError(
                "sglang is not installed. Please install it with sglang document https://docs.sglang.ai/start/install.html."
            ) from err

        # `extra_kwargs` is declared `Optional`, so an explicit `None` must not
        # be unpacked (`**None` raises `TypeError`).
        engine_kwargs: Dict[str, Any] = dict(self.extra_kwargs or {})

        self._model = Engine(
            model_path=self.model,
            dtype=self.dtype,
            trust_remote_code=self.trust_remote_code,
            quantization=self.quantization,
            revision=self.revision,
            random_seed=self.seed,
            **engine_kwargs,
        )

    def unload(self) -> None:
        """Unloads the `SGLang` model and releases the CUDA device placement.

        The engine reference is dropped and, when the engine exposes a `shutdown`
        method, it is invoked so the engine can release its resources.
        """
        engine, self._model = self._model, None
        try:
            # NOTE(review): `Engine.shutdown` is expected to stop the worker
            # processes spawned by the engine; simply dropping the reference
            # would leave them running. Confirm against the pinned sglang version.
            shutdown = getattr(engine, "shutdown", None)
            if callable(shutdown):
                shutdown()
        finally:
            # Always run the remaining teardown, even if the shutdown failed.
            CudaDevicePlacementMixin.unload(self)
            super().unload()

    @property
    def model_name(self) -> str:
        """Returns the name of the model."""
        return self.model

    def encode(self, inputs: List[str]) -> List[List[Union[int, float]]]:
        """Generates embeddings for the provided inputs.

        Args:
            inputs: a list of texts for which an embedding has to be generated.

        Returns:
            The generated embeddings, one per input and in the same order. An
            empty list is returned for empty `inputs`.
        """
        # Nothing to embed: avoid sending an empty batch to the engine.
        if not inputs:
            return []
        return [output["embedding"] for output in self._model.encode(inputs)]
3 changes: 3 additions & 0 deletions src/distilabel/models/llms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from distilabel.models.llms.moa import MixtureOfAgentsLLM
from distilabel.models.llms.ollama import OllamaLLM
from distilabel.models.llms.openai import OpenAILLM
from distilabel.models.llms.sglang import ClientSGLang, SGLang
from distilabel.models.llms.together import TogetherLLM
from distilabel.models.llms.vertexai import VertexAILLM
from distilabel.models.llms.vllm import ClientvLLM, vLLM
Expand All @@ -38,6 +39,7 @@
"AnyscaleLLM",
"AsyncLLM",
"AzureOpenAILLM",
"ClientSGLang",
"ClientvLLM",
"CohereLLM",
"CudaDevicePlacementMixin",
Expand All @@ -52,6 +54,7 @@
"MlxLLM",
"OllamaLLM",
"OpenAILLM",
"SGLang",
"TogetherLLM",
"TransformersLLM",
"VertexAILLM",
Expand Down
Loading

0 comments on commit ab612a6

Please sign in to comment.