Commit: 🚧 WIP
Dmytro Parfeniuk committed Jul 17, 2024
1 parent 87ca6b1 commit 65aace3
Showing 8 changed files with 122 additions and 104 deletions.
1 change: 0 additions & 1 deletion pyproject.toml
@@ -95,7 +95,6 @@ lint.select = ["E", "F", "W"]

[tool.pytest.ini_options]
addopts = '-s -vvv --cache-clear'
asyncio_mode = 'auto'
markers = [
"smoke: quick tests to check basic functionality",
"sanity: detailed tests to ensure major functions work correctly",
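Note: with `asyncio_mode = 'auto'` removed, pytest-asyncio (if it remains installed) no longer collects `async def` tests automatically; each one needs an explicit marker. A minimal sketch, assuming pytest-asyncio stays a dev dependency:

```python
import asyncio

import pytest


@pytest.mark.asyncio  # required explicitly once asyncio_mode = 'auto' is gone
async def test_sleep_passes_result_through():
    # Trivial async assertion; stands in for a real async test case.
    assert await asyncio.sleep(0, result=42) == 42
```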
4 changes: 3 additions & 1 deletion src/guidellm/backend/base.py
@@ -89,7 +89,9 @@ def submit(self, request: TextGenerationRequest) -> TextGenerationResult:
:rtype: TextGenerationResult
"""

logger.info(f"Submitting request with prompt: {request.prompt}")
logger.info(
    f"Submitting request with prompt: {request.prompt}"
)

result = TextGenerationResult(TextGenerationRequest(prompt=request.prompt))
result.start(request.prompt)
22 changes: 14 additions & 8 deletions src/guidellm/core/result.py
@@ -1,6 +1,6 @@
from dataclasses import dataclass
from time import perf_counter, time
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Type, Union

from loguru import logger

@@ -270,7 +270,11 @@ class TextGenerationError:
:type error: Exception
"""

def __init__(self, request: TextGenerationRequest, error: Exception):
def __init__(
self,
request: TextGenerationRequest,
error_class: Type[BaseException],
):
"""
Initialize the TextGenerationError with a unique identifier.
@@ -279,10 +283,10 @@ def __init__(self, request: TextGenerationRequest, error: Exception):
:param error: The exception that occurred during the text generation.
:type error: Exception
"""
self._request = request
self._error = error
self._request: TextGenerationRequest = request
self._error_class: Type[BaseException] = error_class

logger.error(f"Error occurred for request: {self._request}: {error}")
logger.error(f"Error occurred for request: {self._request}: {error_class}")

def __repr__(self) -> str:
"""
@@ -291,7 +295,9 @@ def __repr__(self) -> str:
:return: String representation of the TextGenerationError.
:rtype: str
"""
return f"TextGenerationError(request={self._request}, error={self._error})"
return (
f"TextGenerationError(request={self._request}, error={self._error_class})"
)

@property
def request(self) -> TextGenerationRequest:
@@ -304,14 +310,14 @@ def request(self) -> TextGenerationRequest:
return self._request

@property
def error(self) -> Exception:
def error(self) -> Type[BaseException]:
"""
Get the exception that occurred during the text generation.
:return: The exception.
:rtype: Exception
"""
return self._error
return self._error_class


@dataclass
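Note: the net effect of this file's changes is that `TextGenerationError` now stores the exception class (`Type[BaseException]`) rather than an exception instance. A minimal usage sketch, assuming both classes are importable from `guidellm.core` (the prompt value is illustrative):

```python
from guidellm.core import TextGenerationError, TextGenerationRequest

request = TextGenerationRequest(prompt="What is 2 + 2?")

# The constructor now takes the exception *class*, not an instance of it.
error = TextGenerationError(request=request, error_class=TimeoutError)

assert error.request is request
assert error.error is TimeoutError  # .error now returns the class
```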
102 changes: 79 additions & 23 deletions src/guidellm/scheduler/scheduler.py
@@ -1,6 +1,8 @@
import asyncio
import time
from typing import Generator, List, Optional
from typing import Generator, List, Optional, Tuple

from loguru import logger

from guidellm.backend import Backend
from guidellm.core import TextGenerationBenchmark, TextGenerationError
@@ -36,6 +38,28 @@ def __init__(
self._max_requests = max_requests
self._max_duration = max_duration

# Tasks that the scheduler is going to manage.
# NOTE: tasks are populated synchronously or asynchronously and are
# bounded by the max number of requests and the max duration of the run.
self._tasks: List[Tuple[asyncio.Task, Task]] = []

def __len__(self) -> int:
"""
The length of the scheduler is the total number of tasks
currently being processed.
"""

return len(self._tasks)

@property
def event_loop(self) -> asyncio.AbstractEventLoop:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
return asyncio.get_event_loop()
else:
return loop

def run(self) -> TextGenerationBenchmark:
if self._load_gen_mode == LoadGenerationMode.SYNCHRONOUS:
report = self._run_sync()
@@ -67,6 +91,16 @@ def _run_sync(self) -> TextGenerationBenchmark:
return benchmark

async def _run_async(self) -> TextGenerationBenchmark:
"""
Running in async mode performs the following steps:

* Iterate through all the tasks with the generated load attached
* Check that the execution time does not exceed the max duration
* Check that the number of requests does not exceed max requests

If the max duration is not specified for the scheduler, only the
max requests limit is checked, and the loop breaks without cancelling tasks.
"""

benchmark: TextGenerationBenchmark = TextGenerationBenchmark(
mode=self._load_gen_mode.value, rate=self._load_gen_rate
)
@@ -75,16 +109,18 @@ async def _run_async(self) -> TextGenerationBenchmark:
load_gen = LoadGenerator(self._load_gen_mode, self._load_gen_rate)

start_time: float = time.time()
requests_counter = 0
tasks: List[asyncio.Task] = []
requests_counter: int = 0

for _task, task_start_time in zip(self._task_iterator(), load_gen.times()):
for task, task_start_time in zip(self._task_iterator(), load_gen.times()):
if (
self._max_requests is not None
and requests_counter >= self._max_requests
) or (
self._max_duration is not None
and time.time() - start_time >= self._max_duration
):
self.cancel_running_tasks(benchmark)
break
elif (
self._max_requests is not None
and requests_counter >= self._max_requests
):
break

@@ -93,32 +129,52 @@ async def _run_async(self) -> TextGenerationBenchmark:
if pending_time > 0:
await asyncio.sleep(pending_time)

tasks.append(
asyncio.create_task(self._run_task_async(_task, benchmark)),
self._tasks.append(
(asyncio.create_task(self._run_task_async(task, benchmark)), task)
)

requests_counter += 1

# Dispatch the task-execution strategy
if self._max_duration is None:
# Ensure all the asyncio tasks are done
await asyncio.gather(*tasks)
return benchmark
await asyncio.gather(
*(asyncio_task for asyncio_task, _ in self._tasks),
return_exceptions=False,
)
else:
try:
# Wait for tasks execution if the self.max_duration is specified
await asyncio.wait_for(asyncio.gather(*tasks), self._max_duration)
except asyncio.TimeoutError:
breakpoint() # TODO: remove
# Return not fully filled benchmark if Task TTL is end
for task in tasks:
if not task.done():
task.cancel()
finally:
return benchmark
# Set the timeout if the max duration is specified
await asyncio.wait_for(
asyncio.gather(
*(asyncio_task for asyncio_task, _ in self._tasks),
return_exceptions=True,
),
self._max_duration,
)
except asyncio.TimeoutError:
self.cancel_running_tasks(benchmark)

return benchmark

def cancel_running_tasks(self, benchmark: TextGenerationBenchmark) -> None:
"""
Cancel all the running tasks for the scheduler
"""

for asyncio_task, guidellm_task in self._tasks:
if not asyncio_task.done():
logger.debug(f"Cancelling running task {asyncio_task}")
asyncio_task.cancel()
benchmark.errors.append(
TextGenerationError(
**guidellm_task._params,
error_class=asyncio.CancelledError,
)
)

async def _run_task_async(self, task: Task, benchmark: TextGenerationBenchmark):
benchmark.request_started()
res = await task.run_async()
res = await task.run_async(self.event_loop)
benchmark.request_completed(res)

def _task_iterator(self) -> Generator[Task, None, None]:
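Note: the rewritten `_run_async` bounds the whole batch with a single `asyncio.wait_for(asyncio.gather(...), max_duration)` and cancels whatever is still pending on timeout. A self-contained sketch of that pattern, with an illustrative worker in place of real text-generation tasks:

```python
import asyncio
from typing import List


async def worker(seconds: float) -> float:
    # Stand-in for one text-generation request.
    await asyncio.sleep(seconds)
    return seconds


async def run_with_deadline(max_duration: float) -> List[float]:
    tasks = [asyncio.create_task(worker(s)) for s in (0.1, 0.2, 5.0)]
    results: List[float] = []
    try:
        # Mirror the scheduler: one timeout around the whole gather, with
        # return_exceptions=True so a failed task doesn't abort the rest.
        results = await asyncio.wait_for(
            asyncio.gather(*tasks, return_exceptions=True), max_duration
        )
    except asyncio.TimeoutError:
        # Deadline hit: cancel anything still running,
        # as cancel_running_tasks does in the scheduler.
        for task in tasks:
            if not task.done():
                task.cancel()
    return [r for r in results if not isinstance(r, BaseException)]


print(asyncio.run(run_with_deadline(1.0)))  # the 5.0s worker gets cancelled
```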
54 changes: 9 additions & 45 deletions src/guidellm/scheduler/task.py
@@ -36,39 +36,26 @@ def __init__(
f"params: {self._params}"
)

async def run_async(self) -> Any:
async def run_async(self, event_loop: asyncio.AbstractEventLoop) -> Any:
"""
Run the task asynchronously.
:return: The output of the function.
:rtype: Any
"""

logger.info(f"Running task asynchronously with function: {self._func.__name__}")
try:
loop = asyncio.get_running_loop()

result = await asyncio.gather(
loop.run_in_executor(
None, functools.partial(self._func, **self._params)
),
return_exceptions=True,
try:
result = await event_loop.run_in_executor(
None, functools.partial(self._func, **self._params)
)
if isinstance(result[0], Exception):
raise result[0]
if isinstance(result, Exception):
raise result

if self.cancelled is True:
raise asyncio.CancelledError("Task was cancelled")

logger.info(f"Task completed with result: {result[0]}")
logger.info(f"Task completed with result: {result}")

return result[0]
except asyncio.CancelledError as cancel_err:
logger.warning("Task was cancelled")
return (
cancel_err
if not self._err_container
else self._err_container(**self._params, error=cancel_err)
)
return result
except Exception as err:
logger.error(f"Task failed with error: {err}")
return (
@@ -96,26 +83,3 @@ def run_sync(self) -> Any:
if not self._err_container
else self._err_container(**self._params, error=err)
)

def cancel(self) -> None:
"""
Cancel the task.
"""
logger.info("Cancelling task")
self._cancel_event.set()

async def _check_cancelled(self):
"""
Check if the task is cancelled.
"""
await self._cancel_event.wait()

@property
def cancelled(self) -> bool:
"""
Check if the task is cancelled.
:return: True if the task is cancelled, False otherwise.
:rtype: bool
"""
return self._cancel_event.is_set()
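Note: `run_async` now receives the event loop from the scheduler and awaits `run_in_executor` directly instead of wrapping a single call in `asyncio.gather`. A rough, self-contained sketch of the pattern, with a throwaway blocking function standing in for the backend call:

```python
import asyncio
import functools
import time
from typing import Any, Dict


def blocking_call(prompt: str, delay: float = 0.1) -> str:
    # Stands in for the blocking backend request (self._func in Task).
    time.sleep(delay)
    return f"echo: {prompt}"


async def run_async(loop: asyncio.AbstractEventLoop, params: Dict[str, Any]) -> Any:
    # functools.partial binds the keyword params so the executor
    # receives a zero-argument callable.
    return await loop.run_in_executor(
        None, functools.partial(blocking_call, **params)
    )


async def main() -> None:
    loop = asyncio.get_running_loop()
    print(await run_async(loop, {"prompt": "hello"}))


asyncio.run(main())
```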
1 change: 0 additions & 1 deletion tests/unit/backend/test_openai_backend.py
@@ -79,7 +79,6 @@ def test_make_request(openai_backend_factory, openai_completion_create_patch):
backend_service.make_request(request=request),
openai_completion_create_patch,
):

total_generative_responses += 1
expected_token: Optional[str] = getattr(completion_patch, "content") or None

16 changes: 0 additions & 16 deletions tests/unit/conftest.py
@@ -21,22 +21,6 @@ def openai_completion_create_patch(
return cast(openai.Stream[openai.types.Completion], items)


@pytest.fixture(autouse=True)
def openai_async_completion_create_patch(
mocker,
):
"""
Mock available models function to avoid OpenAI API call.
"""

async def callback(*args, **kwargs):
return dummy.data.openai_completion_factory_async()

return mocker.patch(
"openai.resources.completions.AsyncCompletions.create", side_effect=callback
)


@pytest.fixture(autouse=True)
def openai_models_list_patch(mocker) -> List[openai.types.Model]:
"""
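Note: for reference, the removed fixture's approach — patching an async client method with an `async def` side effect — still works with pytest-mock. A minimal sketch, assuming the `mocker` fixture (pytest-mock) and an explicit pytest-asyncio marker; the class and payload are illustrative:

```python
import pytest


class FakeCompletions:
    async def create(self, **kwargs) -> dict:
        raise RuntimeError("network call; patched out in tests")


@pytest.mark.asyncio  # explicit now that asyncio_mode = 'auto' is gone
async def test_create_is_patched(mocker):
    async def callback(*args, **kwargs):
        return {"choices": []}  # illustrative payload

    # patch.object detects the async target and substitutes an AsyncMock.
    mocker.patch.object(FakeCompletions, "create", side_effect=callback)
    assert await FakeCompletions().create(model="x") == {"choices": []}
```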