[Feature]: completion support for transformers
#8229
Labels: enhancement
Comments
Here is a basic implementation:

```python
from typing import Self

import litellm
from transformers import Pipeline, pipeline


class ChatTransformersLLMHandler(litellm.CustomLLM):
    @staticmethod
    def make_text_gen_pipeline(model: str) -> Pipeline:
        return pipeline(task="text-generation", model=model)

    @classmethod
    def with_models(cls, *models: str) -> Self:
        instance = cls()
        for model in models:
            instance.add(model, pipeline=cls.make_text_gen_pipeline(model))
        return instance

    def __init__(self):
        super().__init__()
        self._pipelines: dict[str, Pipeline | None] = {}

    def add(self, model: str, pipeline: Pipeline | None = None) -> Pipeline | None:
        """Add the model to this handler.

        Args:
            model: Model name to pass to `transformers.pipeline`.
            pipeline: Optional pipeline to use. Passing `None` (default) means the
                pipeline will be lazily constructed, which also lazily pulls the
                model down.
        """
        self._pipelines[model] = pipeline
        return pipeline

    def completion(
        self, *args, model: str, messages: list, **kwargs
    ) -> litellm.ModelResponse:
        if model not in self._pipelines:
            raise ValueError(
                f"Input model {model} should match a stored model in {set(self._pipelines)}."
            )
        generator = self._pipelines[model]
        if generator is None:
            # NOTE: pipeline will pull the model the first time
            generator = self.add(
                model, pipeline=self.make_text_gen_pipeline(model=model)
            )
        generation = generator(messages)
        if len(generation) > 1:
            raise NotImplementedError("Batching is not yet supported.")
        # The pipeline echoes the input conversation; keep only the new message.
        (new_generation,) = generation[0]["generated_text"][len(messages):]
        if new_generation.get("role") != "assistant" or not new_generation.get(
            "content"
        ):
            raise NotImplementedError(
                f"Expected assistant to respond, got {new_generation.get('role')} with"
                f" content {new_generation.get('content')}."
            )
        return litellm.ModelResponse(
            choices=[
                {"message": {"content": new_generation["content"], "role": "assistant"}}
            ]
        )


CUSTOM_TRANSFORMERS_PROVIDER = "transformers"
if CUSTOM_TRANSFORMERS_PROVIDER in litellm.types.utils.LlmProvidersSet:
    raise ValueError("The provider 'transformers' is already present.")
litellm.custom_provider_map = [
    {
        "provider": CUSTOM_TRANSFORMERS_PROVIDER,
        "custom_handler": ChatTransformersLLMHandler.with_models(
            "meta-llama/Llama-3.2-1B-Instruct"
        ),
    }
]

response = litellm.completion(
    model=f"{CUSTOM_TRANSFORMERS_PROVIDER}/meta-llama/Llama-3.2-1B-Instruct",
    messages=[{"role": "user", "content": "Hello world!"}],
)
```
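For completeness, a minimal sketch of reading the generated text back out of the returned `litellm.ModelResponse` (assuming the registration snippet above has already run):

```python
# Inspect the assistant reply produced by the local transformers pipeline.
# `choices[0].message.content` is the standard litellm response accessor.
print(response.choices[0].message.content)
```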
The Feature

#831 asked for transformers support, though it was closed in favor of allowing clients to define their own implementation.

Motivation, pitch

transformers is a very common library and is here to stay. In my opinion, LiteLLM should provide first-party support for transformers.

Are you a ML Ops Team?
No
Twitter / LinkedIn details
No response