Skip to content

Commit

Permalink
🚧 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
Dmytro Parfeniuk committed Sep 2, 2024
1 parent 6e30870 commit b0c0acb
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 13 deletions.
8 changes: 4 additions & 4 deletions src/guidellm/backend/vllm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@

check_python_version(min_version="3.8", max_version="3.12")

module_is_available(
module="vllm",
helper=("`vllm` package is not available. Try run: `pip install -e '.[vllm]'`"),
)
# module_is_available(
# module="vllm",
# helper=("`vllm` package is not available. Try run: `pip install -e '.[vllm]'`"),
# )

from .backend import VllmBackend # noqa: E402

Expand Down
2 changes: 1 addition & 1 deletion tests/dummy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
test.dummy.data.openai_completion_factory - openai.types.Completion test factory
"""

from . import data, services # noqa: F401
from . import data, services, vllm # noqa: F401
3 changes: 2 additions & 1 deletion tests/dummy/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import vllm
from .openai import openai_completion_factory, openai_model_factory

__all__ = ["openai_completion_factory", "openai_model_factory"]
__all__ = ["openai_completion_factory", "openai_model_factory", "vllm"]
65 changes: 65 additions & 0 deletions tests/dummy/vllm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
"""
This module includes data model factories for the `vllm` third-party package
"""

import random
from functools import partial
from typing import List, Optional

from pydantic import BaseModel, ConfigDict, Field, PrivateAttr

from guidellm.utils import random_strings

__all__ = ["TestLLM", "CompletionOutput"]


class CompletionOutput(BaseModel):
    """Test interface of `vllm.CompletionOutput`.

    Minimal stand-in: only the `text` attribute is modeled, since that is
    all the code under test reads from a completion output.
    """

    # Generated completion text.
    text: str


class SamplingParams(BaseModel):
    """Test interface of `vllm.SamplingParams`.

    Minimal stand-in: only `max_tokens` is modeled, since that is the only
    sampling parameter the dummy LLM consumes.
    """

    # Upper bound on generated tokens; used as `max_chars` when the dummy
    # LLM fabricates completion text.
    max_tokens: int


class TestLLM(BaseModel):
    """Test double for `vllm.LLM`.

    Mimics the subset of the `vllm.LLM` interface the tests use:
    keyword construction plus a `generate` method.

    Fields:
        model: model identifier forwarded by the backend under test.
        max_num_batched_tokens: batching limit forwarded by the backend.

    Private attributes:
        _outputs_number: the number of `CompletionOutput` items produced per
            prompt (NOT tokens per output). Should be used only for testing
            purposes. Re-randomized (10..20) on every generation call, so the
            initial default only matters before the first call.
    """

    model_config = ConfigDict(
        extra="allow",
        validate_assignment=True,
        arbitrary_types_allowed=True,
        from_attributes=True,
    )

    model: str
    max_num_batched_tokens: int

    # Leading-underscore names are pydantic *private* attributes and must be
    # declared with PrivateAttr, not Field — Field on an underscore name is
    # rejected by pydantic at class-definition time.
    _outputs_number: int = PrivateAttr(default_factory=partial(random.randint, 10, 20))

    def _generate_completion_outputs(self, max_tokens: int) -> List[CompletionOutput]:
        """Fabricate a random batch of completion outputs for one prompt."""

        # Re-randomize the batch size for each generation request.
        self._outputs_number = random.randint(10, 20)

        return [
            CompletionOutput(text=text)
            for text in random_strings(
                min_chars=0, max_chars=max_tokens, n=self._outputs_number
            )
        ]

    def generate(
        self, inputs: List[str], sampling_params: SamplingParams
    ) -> Optional[List[List[CompletionOutput]]]:
        """Return one list of completion outputs per input prompt.

        Mirrors `vllm.LLM.generate`, which yields one result per prompt.
        (The previous implementation ignored `inputs` and always returned a
        single-element list regardless of how many prompts were supplied.)
        """

        return [
            self._generate_completion_outputs(max_tokens=sampling_params.max_tokens)
            for _ in inputs
        ]
9 changes: 2 additions & 7 deletions tests/unit/backend/test_vllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,14 @@
the runtime platform is not Linux / WSL, according to the vllm documentation.
"""

import importlib
import sys
from typing import Dict, List

import pytest

from guidellm.backend import Backend
from guidellm.config import reload_settings
from guidellm.core import TextGenerationRequest
from tests import dummy

# pytestmark = pytest.mark.skipif(
# sys.platform != "linux",
Expand All @@ -30,13 +29,9 @@ def backend_class():

@pytest.fixture(autouse=True)
def mock_vllm_llm(mocker):
module = importlib.import_module("vllm")
llm = getattr(module, "LLM")(
llm = dummy.vllm.TestLLM(
model="facebook/opt-125m",
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True,
)

return mocker.patch("vllm.LLM", return_value=llm)
Expand Down

0 comments on commit b0c0acb

Please sign in to comment.