Skip to content

Commit

Permalink
Refactor pipeline helper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
mthrok committed Jan 17, 2025
1 parent 3573933 commit 2c319ed
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 44 deletions.
41 changes: 12 additions & 29 deletions src/spdl/pipeline/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,14 @@

from . import _convert
from ._convert import _to_async_gen, Callables
from ._hook import _stage_hooks, _task_hooks, PipelineHook, TaskStatsHook
from ._hook import (
_stage_hooks,
_task_hooks,
_time_str,
PipelineHook,
StatsCounter,
TaskStatsHook,
)
from ._pipeline import Pipeline
from ._utils import create_task

Expand Down Expand Up @@ -136,7 +143,7 @@ def _pipe(
if inspect.iscoroutinefunction(afunc):

async def _wrap(coro: Awaitable[U]) -> None:
async with _task_hooks(hooks): # pyre-ignore: [16]
async with _task_hooks(hooks):
result = await coro

await output_queue.put(result)
Expand Down Expand Up @@ -272,7 +279,7 @@ def _ordered_pipe(

async def _wrap(item: T) -> asyncio.Task[U]:
async def _with_hooks():
async with _task_hooks(hooks): # pyre-ignore: [16]
async with _task_hooks(hooks):
return await afunc(item)

return create_task(_with_hooks())
Expand Down Expand Up @@ -342,34 +349,10 @@ async def enqueue():
################################################################################


def _time_str(val: float) -> str:
return "{:.4f} [{:>3s}]".format(
val * 1000 if val < 1 else val,
"ms" if val < 1 else "sec",
)


class _Counter:
def __init__(self):
self.num_items = 0
self.ave_time = 0

@contextmanager
def count(self):
t0 = time.monotonic()
yield
elapsed = time.monotonic() - t0
self.num_items += 1
self.ave_time += (elapsed - self.ave_time) / self.num_items

def __str__(self):
return _time_str(self.ave_time)


@contextmanager
def _sink_stats():
get_counter = _Counter()
put_counter = _Counter()
get_counter = StatsCounter()
put_counter = StatsCounter()
t0 = time.monotonic()
try:
yield get_counter, put_counter
Expand Down
63 changes: 48 additions & 15 deletions src/spdl/pipeline/_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from contextlib import asynccontextmanager, AsyncExitStack
from typing import TypeVar
from collections.abc import AsyncIterator, Iterator, Sequence
from contextlib import asynccontextmanager, AsyncExitStack, contextmanager
from typing import AsyncContextManager, TypeVar

from ._utils import create_task

__all__ = [
"_stage_hooks",
"_task_hooks",
"_time_str",
"PipelineHook",
"TaskStatsHook",
"StatsCounter",
]

_LG = logging.getLogger(__name__)
Expand All @@ -27,6 +31,37 @@
T = TypeVar("T")


def _time_str(val: float) -> str:
return "{:.4f} [{:>3s}]".format(
val * 1000 if val < 1 else val,
"ms" if val < 1 else "sec",
)


class StatsCounter:
    """Accumulate the count and running-average duration of timed operations.

    The average is maintained incrementally (running mean), so no per-item
    history is stored.
    """

    def __init__(self):
        self.num_items: int = 0  # number of measurements folded in so far
        self.ave_time: float = 0.0  # running average duration in seconds

    def reset(self) -> None:
        """Discard all accumulated statistics."""
        self.num_items = 0
        # Use a float literal to stay consistent with the annotation and
        # with the initial value set in __init__ (was the int ``0``).
        self.ave_time = 0.0

    def update(self, val: float) -> None:
        """Fold one measurement (seconds) into the running average."""
        self.num_items += 1
        # Incremental running-mean update: avg += (x - avg) / n
        self.ave_time += (val - self.ave_time) / self.num_items

    @contextmanager
    def count(self) -> Iterator[None]:
        """Time the enclosed block and record its duration via ``update``.

        Note: the duration is recorded only when the block exits without
        raising, matching the original behavior.
        """
        t0 = time.monotonic()
        yield
        self.update(time.monotonic() - t0)

    def __str__(self):
        return _time_str(self.ave_time)


class PipelineHook(ABC):
"""Base class for hooks to be used in the pipeline.
Expand Down Expand Up @@ -149,7 +184,7 @@ async def stage_hook(self):

@abstractmethod
@asynccontextmanager
async def task_hook(self):
async def task_hook(self) -> AsyncIterator[None]:
"""Perform custom action before and after task is executed.
.. important::
Expand Down Expand Up @@ -183,7 +218,7 @@ def _stage_hooks(hooks: Sequence[PipelineHook]):
)

@asynccontextmanager
async def stage_hooks():
async def stage_hooks() -> AsyncIterator[None]:
async with AsyncExitStack() as stack:
for h in hs:
await stack.enter_async_context(h)
Expand All @@ -192,7 +227,7 @@ async def stage_hooks():
return stage_hooks()


def _task_hooks(hooks: Sequence[PipelineHook]):
def _task_hooks(hooks: Sequence[PipelineHook]) -> AsyncContextManager[None]:
hs = [hook.task_hook() for hook in hooks]

if not all(hasattr(h, "__aenter__") or hasattr(h, "__aexit__") for h in hs):
Expand All @@ -203,7 +238,7 @@ def _task_hooks(hooks: Sequence[PipelineHook]):
)

@asynccontextmanager
async def task_hooks():
async def task_hooks() -> AsyncIterator[None]:
async with AsyncExitStack() as stack:
for h in hs:
await stack.enter_async_context(h)
Expand Down Expand Up @@ -241,7 +276,7 @@ def __init__(self, name: str, concurrency: int, interval: float | None = None):

# For interval
self._int_task = None
self._int_t0 = 0
self._int_t0 = 0.0
self._int_num_tasks = 0
self._int_num_success = 0
self._int_ave_time = 0.0
Expand All @@ -262,7 +297,7 @@ async def stage_hook(self):
yield
finally:
elapsed = time.monotonic() - t0
if self.interval is not None:
if self._int_task is not None:
self._int_task.cancel()
self._log_stats(elapsed, self.num_tasks, self.num_success, self.ave_time)

Expand Down Expand Up @@ -312,17 +347,15 @@ async def _log_interval_stats(self):

def _log_stats(self, elapsed, num_tasks, num_success, ave_time):
_LG.info(
"[%s]\tCompleted %5d tasks (%3d failed) in %.4f [%3s]. "
"[%s]\tCompleted %5d tasks (%3d failed) in %s. "
"QPS: %.2f (Concurrency: %3d). "
"Average task time: %.4f [%3s].",
"Average task time: %s.",
self.name,
num_tasks,
num_tasks - num_success,
elapsed * 1000 if elapsed < 1 else elapsed,
"ms" if elapsed < 1 else "sec",
_time_str(elapsed),
num_success / elapsed if elapsed > 0.001 else float("nan"),
self.concurrency,
ave_time * 1000 if ave_time < 1 else ave_time,
"ms" if ave_time < 1 else "sec",
_time_str(ave_time),
stacklevel=2,
)

0 comments on commit 2c319ed

Please sign in to comment.