-
Notifications
You must be signed in to change notification settings - Fork 12
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Dmytro Parfeniuk
committed
Sep 2, 2024
1 parent
6e30870
commit b0c0acb
Showing
5 changed files
with
74 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
from . import vllm | ||
from .openai import openai_completion_factory, openai_model_factory | ||
|
||
__all__ = ["openai_completion_factory", "openai_model_factory"] | ||
__all__ = ["openai_completion_factory", "openai_model_factory", "vllm"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
""" | ||
This module includes data models factories for the `vllm` 3-rd party package | ||
""" | ||
|
||
import random | ||
from functools import partial | ||
from typing import List, Optional | ||
|
||
from pydantic import BaseModel, ConfigDict, Field | ||
|
||
from guidellm.utils import random_strings | ||
|
||
__all__ = ["TestLLM", "CompletionOutput"] | ||
|
||
|
||
class CompletionOutput(BaseModel): | ||
"""Test interface of `vllm.CompletionOutput`.""" | ||
|
||
text: str | ||
|
||
|
||
class SamplingParams(BaseModel): | ||
"""Test interface of `vllm.SamplingParams`.""" | ||
|
||
max_tokens: int | ||
|
||
|
||
class TestLLM(BaseModel): | ||
"""Test interface of `vllm.LLM`. | ||
Args: | ||
_outputs_number(int | None): the number of generated tokens per output. | ||
Should be used only for testing purposes. | ||
Default: randint(10..20) | ||
""" | ||
|
||
model_config = ConfigDict( | ||
extra="allow", | ||
validate_assignment=True, | ||
arbitrary_types_allowed=True, | ||
from_attributes=True, | ||
) | ||
|
||
model: str | ||
max_num_batched_tokens: int | ||
|
||
_outputs_number: int = Field(default_factory=partial(random.randint, 10, 20)) | ||
|
||
def _generate_completion_outputs(self, max_tokens: int) -> List[CompletionOutput]: | ||
self._outputs_number = random.randint(10, 20) | ||
|
||
return [ | ||
CompletionOutput(text=text) | ||
for text in random_strings( | ||
min_chars=0, max_chars=max_tokens, n=self._outputs_number | ||
) | ||
] | ||
|
||
def generate( | ||
self, inputs: List[str], sampling_params: SamplingParams | ||
) -> Optional[List[List[CompletionOutput]]]: | ||
return [ | ||
self._generate_completion_outputs(max_tokens=sampling_params.max_tokens) | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters