vllm inference plugin (#2967)
* vllm inference plugin

Signed-off-by: Daniel Sola <[email protected]>

* fixed default value

Signed-off-by: Daniel Sola <[email protected]>

---------

Signed-off-by: Daniel Sola <[email protected]>
dansola authored Jan 3, 2025
1 parent f3996f6 commit 0ad84f3
Showing 6 changed files with 210 additions and 0 deletions.
63 changes: 63 additions & 0 deletions plugins/flytekit-inference/README.md
@@ -126,3 +126,66 @@ def model_serving(questions: list[str], gguf: FlyteFile) -> list[str]:

return responses
```

## vLLM

The vLLM plugin allows you to serve an LLM hosted on HuggingFace.

```python
import flytekit as fl
from flytekitplugins.inference import HFSecret, VLLM
from openai import OpenAI

model_name = "google/gemma-2b-it"
hf_token_key = "vllm_hf_token"
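# Key of the Flyte secret holding your HuggingFace token; the task below requests it via secret_requests.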

vllm_args = {
"model": model_name,
"dtype": "half",
"max-model-len": 2000,
}
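# These keys map to vLLM engine arguments:
# https://docs.vllm.ai/en/stable/models/engine_args.html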

hf_secrets = HFSecret(
secrets_prefix="_FSEC_",
hf_token_key=hf_token_key
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args
)

image = fl.ImageSpec(
name="vllm_serve",
registry="...",
packages=["flytekitplugins-inference"],
)


@fl.task(
pod_template=vllm_instance.pod_template,
container_image=image,
secret_requests=[
fl.Secret(
key=hf_token_key, mount_requirement=fl.Secret.MountType.ENV_VAR # must be mounted as an env var
)
],
)
def model_serving() -> str:
client = OpenAI(
base_url=f"{vllm_instance.base_url}/v1", api_key="vllm" # api key required but ignored
)

completion = client.chat.completions.create(
model=model_name,
messages=[
{
"role": "user",
"content": "Compose a haiku about the power of AI.",
}
],
temperature=0.5,
top_p=1,
max_tokens=1024,
)
return completion.choices[0].message.content
```
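
Under the hood, the entries in `arg_dict` become command-line flags on the vLLM server container. A minimal sketch of inspecting the generated pod template (the printed args in the comment are illustrative for the arguments above):

```python
# The vLLM server runs in an init container of the task pod; the plugin builds
# its args from arg_dict.
server_container = vllm_instance.pod_template.pod_spec.init_containers[0]
print(server_container.args)
# ['--model', 'google/gemma-2b-it', '--dtype', 'half', '--max-model-len', '2000']
```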
1 change: 1 addition & 0 deletions plugins/flytekit-inference/flytekitplugins/inference/__init__.py
@@ -14,3 +14,4 @@

from .nim.serve import NIM, NIMSecrets
from .ollama.serve import Model, Ollama
from .vllm.serve import VLLM, HFSecret
plugins/flytekit-inference/flytekitplugins/inference/vllm/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions plugins/flytekit-inference/flytekitplugins/inference/vllm/serve.py
@@ -0,0 +1,85 @@
from dataclasses import dataclass
from typing import Optional

from ..sidecar_template import ModelInferenceTemplate


@dataclass
class HFSecret:
"""
    :param secrets_prefix: The prefix that Flyte prepends to the environment variable names of all mounted secrets.
:param hf_token_group: The group name for the HuggingFace token.
:param hf_token_key: The key name for the HuggingFace token.
"""

secrets_prefix: str # _UNION_ or _FSEC_
hf_token_key: str
hf_token_group: Optional[str] = None


class VLLM(ModelInferenceTemplate):
def __init__(
self,
hf_secret: HFSecret,
arg_dict: Optional[dict] = None,
image: str = "vllm/vllm-openai",
health_endpoint: str = "/health",
port: int = 8000,
cpu: int = 2,
gpu: int = 1,
mem: str = "10Gi",
):
"""
        Initialize the VLLM class for managing a Kubernetes pod template.
        :param hf_secret: Instance of HFSecret for managing HuggingFace secrets.
:param arg_dict: A dictionary of arguments for the VLLM model server (https://docs.vllm.ai/en/stable/models/engine_args.html).
:param image: The Docker image to be used for the model server container. Default is "vllm/vllm-openai".
:param health_endpoint: The health endpoint for the model server container. Default is "/health".
:param port: The port number for the model server container. Default is 8000.
:param cpu: The number of CPU cores requested for the model server container. Default is 2.
        :param gpu: The number of GPUs requested for the model server container. Default is 1.
:param mem: The amount of memory requested for the model server container. Default is "10Gi".
"""
if hf_secret.hf_token_key is None:
raise ValueError("HuggingFace token key must be provided.")
if hf_secret.secrets_prefix is None:
raise ValueError("Secrets prefix must be provided.")

self._hf_secret = hf_secret
self._arg_dict = arg_dict

super().__init__(
image=image,
health_endpoint=health_endpoint,
port=port,
cpu=cpu,
gpu=gpu,
mem=mem,
)

self.setup_vllm_pod_template()

def setup_vllm_pod_template(self):
from kubernetes.client.models import V1EnvVar

model_server_container = self.pod_template.pod_spec.init_containers[0]

if self._hf_secret.hf_token_group:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_group}_{self._hf_secret.hf_token_key})".upper()
else:
hf_key = f"$({self._hf_secret.secrets_prefix}{self._hf_secret.hf_token_key})".upper()
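        # Flyte resolves this env-var reference at runtime, e.g. prefix "_FSEC_" plus key
        # "vllm_hf_token" (no group) becomes "$(_FSEC_VLLM_HF_TOKEN)".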

model_server_container.env = [
V1EnvVar(name="HUGGING_FACE_HUB_TOKEN", value=hf_key),
]
model_server_container.args = self.build_vllm_args()

def build_vllm_args(self) -> list:
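        # Convert the user-supplied arg dict into CLI flags for the vLLM server, e.g.
        # {"model": "google/gemma-2b-it", "dtype": "half"} -> ["--model", "google/gemma-2b-it", "--dtype", "half"].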
args = []
if self._arg_dict:
for key, value in self._arg_dict.items():
args.append(f"--{key}")
if value is not None:
args.append(str(value))
return args
1 change: 1 addition & 0 deletions plugins/flytekit-inference/setup.py
@@ -19,6 +19,7 @@
f"flytekitplugins.{PLUGIN_NAME}",
f"flytekitplugins.{PLUGIN_NAME}.nim",
f"flytekitplugins.{PLUGIN_NAME}.ollama",
f"flytekitplugins.{PLUGIN_NAME}.vllm",
],
install_requires=plugin_requires,
license="apache2",
60 changes: 60 additions & 0 deletions plugins/flytekit-inference/tests/test_vllm.py
@@ -0,0 +1,60 @@
from flytekitplugins.inference import VLLM, HFSecret


def test_vllm_init_valid_params():
vllm_args = {
"model": "google/gemma-2b-it",
"dtype": "half",
"max-model-len": 2000,
}

hf_secrets = HFSecret(
secrets_prefix="_UNION_",
hf_token_key="vllm_hf_token"
)

vllm_instance = VLLM(
hf_secret=hf_secrets,
arg_dict=vllm_args,
image='vllm/vllm-openai:my-tag',
cpu='10',
gpu='2',
mem='50Gi',
port=8080,
)

assert len(vllm_instance.pod_template.pod_spec.init_containers) == 1
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].image
== 'vllm/vllm-openai:my-tag'
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].resources.requests[
"memory"
]
== "50Gi"
)
assert (
vllm_instance.pod_template.pod_spec.init_containers[0].ports[0].container_port
== 8080
)
assert vllm_instance.pod_template.pod_spec.init_containers[0].args == ['--model', 'google/gemma-2b-it', '--dtype', 'half', '--max-model-len', '2000']
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].name == 'HUGGING_FACE_HUB_TOKEN'
assert vllm_instance.pod_template.pod_spec.init_containers[0].env[0].value == '$(_UNION_VLLM_HF_TOKEN)'



def test_vllm_default_params():
vllm_instance = VLLM(hf_secret=HFSecret(secrets_prefix="_FSEC_", hf_token_key="test_token"))

assert vllm_instance.base_url == "http://localhost:8000"
assert vllm_instance._image == 'vllm/vllm-openai'
assert vllm_instance._port == 8000
assert vllm_instance._cpu == 2
assert vllm_instance._gpu == 1
assert vllm_instance._health_endpoint == "/health"
assert vllm_instance._mem == "10Gi"
    assert vllm_instance._arg_dict is None
assert vllm_instance._hf_secret.secrets_prefix == '_FSEC_'
assert vllm_instance._hf_secret.hf_token_key == 'test_token'
    assert vllm_instance._hf_secret.hf_token_group is None
