From 6e3087027217bf66e405fbc98a8b819eb2fa46b1 Mon Sep 17 00:00:00 2001 From: Dmytro Parfeniuk Date: Mon, 2 Sep 2024 08:58:33 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=A7=20WIP?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Dockerfile | 20 +++++++++ tests/unit/backend/test_vllm.py | 78 ++++++++++++++++++++++++++++++--- 2 files changed, 92 insertions(+), 6 deletions(-) create mode 100644 Dockerfile diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..61aaac4 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,20 @@ +FROM --platform=linux/amd64 python:3.8-slim + +# Environment variables +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update \ + # dependencies for building Python packages && cleaning up unused files + && apt-get install -y build-essential \ + libcurl4-openssl-dev libssl-dev \ + && rm -rf /var/lib/apt/lists/* + + +# Python dependencies +RUN pip install --upgrade pip setuptools + +WORKDIR /app/ + +COPY ./ ./ + +RUN pip install -e '.[dev,deepsparse,vllm]' diff --git a/tests/unit/backend/test_vllm.py b/tests/unit/backend/test_vllm.py index de432f6..148c37c 100644 --- a/tests/unit/backend/test_vllm.py +++ b/tests/unit/backend/test_vllm.py @@ -7,16 +7,18 @@ import importlib import sys -from typing import Dict +from typing import Dict, List import pytest from guidellm.backend import Backend +from guidellm.config import reload_settings +from guidellm.core import TextGenerationRequest -pytestmark = pytest.mark.skipif( - sys.platform != "linux", - reason="Unsupported Platform. Try using Linux or WSL instead.", -) +# pytestmark = pytest.mark.skipif( +# sys.platform != "linux", +# reason="Unsupported Platform. Try using Linux or WSL instead.", +# ) @pytest.fixture(scope="module") @@ -29,7 +31,7 @@ def backend_class(): @pytest.fixture(autouse=True) def mock_vllm_llm(mocker): module = importlib.import_module("vllm") - llm = module.LLM( + llm = getattr(module, "LLM")( model="facebook/opt-125m", max_num_batched_tokens=4096, tensor_parallel_size=1, @@ -65,3 +67,67 @@ def test_backend_creation(create_payload: Dict, backend_class): if (custom_model := create_payload.get("model")) else backend.default_model ) + + +@pytest.mark.smoke() +def test_backend_model_from_env(mocker, backend_class): + mocker.patch.dict( + "os.environ", + {"GUIDELLM__LLM_MODEL": "test_backend_model_from_env"}, + ) + + reload_settings() + + backends = [Backend.create("vllm"), backend_class()] + + for backend in backends: + assert backend.model == "test_backend_model_from_env" + + +@pytest.mark.smoke() +@pytest.mark.parametrize( + "text_generation_request_create_payload", + [ + {"prompt": "Test prompt"}, + {"prompt": "Test prompt", "output_token_count": 20}, + ], +) +@pytest.mark.asyncio() +async def test_make_request( + text_generation_request_create_payload: Dict, backend_class +): + backend = backend_class() + + output_tokens: List[str] = [] + async for response in backend.make_request( + request=TextGenerationRequest(**text_generation_request_create_payload) + ): + if response.add_token: + output_tokens.append(response.add_token) + assert "".join(output_tokens) == "".join( + generation.text for generation in backend.pipeline._generations + ) + + if max_tokens := text_generation_request_create_payload.get("output_token_count"): + assert len(backend.pipeline._generations) == max_tokens + + +@pytest.mark.smoke() +@pytest.mark.parametrize( + ("text_generation_request_create_payload", "error"), + [ + ({"prompt": "Test prompt"}, ValueError), + ], +) +@pytest.mark.asyncio() +async def test_make_request_invalid_request_payload( + text_generation_request_create_payload: Dict, error, backend_class +): + backend = backend_class() + with pytest.raises(error): + [ + respnose + async for respnose in backend.make_request( + request=TextGenerationRequest(**text_generation_request_create_payload) + ) + ]