diff --git a/README.md b/README.md
index bc4bc27..6a71880 100644
--- a/README.md
+++ b/README.md
@@ -34,7 +34,7 @@ pip install streamstore
 
 ## Examples
 
-`examples/streaming` directory in the [repo](https://github.com/s2-streamstore/s2-sdk-python/tree/main/examples/streaming) contain examples for streaming APIs.
+The `examples/` directory in the [repo](https://github.com/s2-streamstore/s2-sdk-python/tree/main/examples/) contains examples for streaming APIs.
 
 ## Get in touch
 
diff --git a/examples/streaming/producer.py b/examples/append_session.py
similarity index 80%
rename from examples/streaming/producer.py
rename to examples/append_session.py
index 5a1efae..9d0dbd2 100644
--- a/examples/streaming/producer.py
+++ b/examples/append_session.py
@@ -1,15 +1,17 @@
 import asyncio
 import os
 import random
+from typing import AsyncIterable
+
 from streamstore import S2
-from streamstore.schemas import Record, AppendInput
+from streamstore.schemas import AppendInput, Record
 
 AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
 MY_BASIN = os.getenv("MY_BASIN")
 MY_STREAM = os.getenv("MY_STREAM")
 
 
-async def append_inputs():
+async def append_inputs_gen() -> AsyncIterable[AppendInput]:
     num_inputs = random.randint(1, 100)
     for _ in range(num_inputs):
         num_records = random.randint(1, 100)
@@ -26,7 +28,7 @@ async def append_inputs():
 
 async def producer():
     async with S2(auth_token=AUTH_TOKEN) as s2:
         stream = s2[MY_BASIN][MY_STREAM]
-        async for output in stream.append_session(append_inputs()):
+        async for output in stream.append_session(append_inputs_gen()):
             num_appended_records = output.end_seq_num - output.start_seq_num
             print(f"appended {num_appended_records} records")
diff --git a/examples/append_session_with_auto_batching.py b/examples/append_session_with_auto_batching.py
new file mode 100644
index 0000000..cf10672
--- /dev/null
+++ b/examples/append_session_with_auto_batching.py
@@ -0,0 +1,40 @@
+import asyncio
+import os
+import random
+from datetime import timedelta
+from typing import AsyncIterable
+
+from streamstore import S2
+from streamstore.schemas import Record
+from streamstore.utils import append_inputs_gen
+
+AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
+MY_BASIN = os.getenv("MY_BASIN")
+MY_STREAM = os.getenv("MY_STREAM")
+
+
+async def records_gen() -> AsyncIterable[Record]:
+    num_records = random.randint(1, 100)
+    for _ in range(num_records):
+        body_size = random.randint(1, 1024)
+        if random.random() < 0.5:
+            await asyncio.sleep(random.random() * 2.5)
+        yield Record(body=os.urandom(body_size))
+
+
+async def producer():
+    async with S2(auth_token=AUTH_TOKEN) as s2:
+        stream = s2[MY_BASIN][MY_STREAM]
+        async for output in stream.append_session(
+            append_inputs_gen(
+                records=records_gen(),
+                max_records_per_batch=10,
+                max_linger_per_batch=timedelta(milliseconds=5),
+            )
+        ):
+            num_appended_records = output.end_seq_num - output.start_seq_num
+            print(f"appended {num_appended_records} records")
+
+
+if __name__ == "__main__":
+    asyncio.run(producer())
diff --git a/examples/streaming/consumer.py b/examples/read_session.py
similarity index 99%
rename from examples/streaming/consumer.py
rename to examples/read_session.py
index e6a6b0c..dc0ca09 100644
--- a/examples/streaming/consumer.py
+++ b/examples/read_session.py
@@ -1,5 +1,6 @@
 import asyncio
 import os
+
 from streamstore import S2
 
 AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
diff --git a/src/streamstore/_client.py b/src/streamstore/_client.py
index ec98294..1ca3d65 100644
--- a/src/streamstore/_client.py
+++ b/src/streamstore/_client.py
@@ -6,10 +6,10 @@
 from datetime import timedelta
 from typing import Self, TypedDict, cast
 
-import stamina
 from google.protobuf.field_mask_pb2 import FieldMask
 from grpc import StatusCode, ssl_channel_credentials
 from grpc.aio import AioRpcError, Channel, secure_channel
+from stamina import AsyncRetryingCaller, retry_context
 
 from streamstore import schemas
 from streamstore._exceptions import fallible
@@ -205,9 +205,9 @@ def __init__(
             enable_append_retries=enable_append_retries,
         )
         self._stub = AccountServiceStub(self._account_channel)
-        self._retrying_caller = stamina.AsyncRetryingCaller(
-            **self._config.retry_kwargs
-        ).on(_grpc_retry_on)
+        self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
+            _grpc_retry_on
+        )
 
     async def __aenter__(self) -> Self:
         return self
@@ -226,7 +226,7 @@ async def close(self) -> None:
         Close all open connections to S2 service endpoints.
 
         Tip:
-            ``S2`` supports async context manager protocol, so you could also do the following instead of
+            ``S2`` supports async context manager protocol, so you can also do the following instead of
             explicitly closing:
 
             .. code-block:: python
@@ -297,7 +297,7 @@ def basin(self, name: str) -> "Basin":
                 async with S2(..) as s2:
                     basin = s2.basin("your-basin-name")
 
-        :class:`.S2` implements the ``getitem`` magic method, so you could also do the following instead:
+        :class:`.S2` implements the ``getitem`` magic method, so you can also do the following instead:
 
         .. code-block:: python
 
@@ -438,9 +438,9 @@ def __init__(
     ) -> None:
         self._channel = channel
         self._config = config
-        self._retrying_caller = stamina.AsyncRetryingCaller(
-            **self._config.retry_kwargs
-        ).on(_grpc_retry_on)
+        self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
+            _grpc_retry_on
+        )
         self._stub = BasinServiceStub(self._channel)
         self._name = name
 
@@ -508,7 +508,7 @@ def stream(self, name: str) -> "Stream":
                 async with S2(..) as s2:
                     stream = s2.basin("your-basin-name").stream("your-stream-name")
 
-        :class:`.Basin` implements the ``getitem`` magic method, so you could also do the following instead:
+        :class:`.Basin` implements the ``getitem`` magic method, so you can also do the following instead:
 
        .. code-block:: python
 
@@ -636,9 +636,9 @@ class Stream:
     def __init__(self, name: str, channel: Channel, config: _Config) -> None:
        self._name = name
        self._config = config
-        self._retrying_caller = stamina.AsyncRetryingCaller(
-            **self._config.retry_kwargs
-        ).on(_grpc_retry_on)
+        self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
+            _grpc_retry_on
+        )
         self._stub = StreamServiceStub(channel)
 
     def __repr__(self) -> str:
@@ -710,9 +710,7 @@ async def _retrying_append_session(
         self, inputs: AsyncIterable[schemas.AppendInput]
     ) -> AsyncIterable[schemas.AppendOutput]:
         inflight_inputs: deque[schemas.AppendInput] = deque()
-        async for attempt in stamina.retry_context(
-            _grpc_retry_on, **self._config.retry_kwargs
-        ):
+        async for attempt in retry_context(_grpc_retry_on, **self._config.retry_kwargs):
             with attempt:
                 if len(inflight_inputs) != 0:
                     async for output in self._retrying_append_session_inner(
@@ -734,6 +732,10 @@ async def append_session(
         Append batches of records to a stream continuously, while guaranteeing pipelined inputs
         are processed in order.
 
+        Tip:
+            You can use :func:`.append_inputs_gen` for automatic batching of records, instead of explicitly
+            preparing and passing batches of records.
+
         Yields:
             :class:`.AppendOutput` for each corresponding :class:`.AppendInput`.
 
@@ -864,9 +866,7 @@ async def read_session(
             start_seq_num=start_seq_num,
             limit=read_limit_message(limit),
         )
-        async for attempt in stamina.retry_context(
-            _grpc_retry_on, **self._config.retry_kwargs
-        ):
+        async for attempt in retry_context(_grpc_retry_on, **self._config.retry_kwargs):
             with attempt:
                 async for response in self._read_session(request):
                     output = response.output
diff --git a/src/streamstore/schemas.py b/src/streamstore/schemas.py
index ce4ed39..ea8a22d 100644
--- a/src/streamstore/schemas.py
+++ b/src/streamstore/schemas.py
@@ -54,7 +54,7 @@ class AppendInput:
     """
 
     #: Batch of records to append atomically, which must contain at least one record,
-    #: and no more than 1000. The total size of the batch must not exceed 1MiB of :func:`.metered_bytes`.
+    #: and no more than 1000. The size of the batch must not exceed 1MiB of :func:`.metered_bytes`.
     records: list[Record]
     #: Enforce that the sequence number issued to the first record in the batch matches this value.
     match_seq_num: int | None = None
diff --git a/src/streamstore/utils.py b/src/streamstore/utils.py
index 4131a51..4e22d37 100644
--- a/src/streamstore/utils.py
+++ b/src/streamstore/utils.py
@@ -1,16 +1,16 @@
-__all__ = [
-    "CommandRecord",
-    "metered_bytes",
-]
+__all__ = ["CommandRecord", "metered_bytes", "append_inputs_gen"]
 
-from typing import Iterable
+from asyncio import Queue, Task, create_task, sleep
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta
+from typing import AsyncIterable, Iterable, Self
 
-from streamstore.schemas import Record, SequencedRecord
+from streamstore.schemas import ONE_MIB, AppendInput, Record, SequencedRecord
 
 
 class CommandRecord:
     """
-    Helper class for creating `command records `_.
+    Factory class for creating `command records `_.
""" FENCE = b"fence" @@ -67,3 +67,163 @@ def metered_bytes(records: Iterable[Record | SequencedRecord]) -> int: ) for record in records ) + + +@dataclass(slots=True) +class _AutoBatcher: + _next_batch_idx: int = field(init=False) + _next_batch: list[Record] = field(init=False) + _next_batch_count: int = field(init=False) + _next_batch_bytes: int = field(init=False) + _linger_queue: Queue[tuple[int, datetime]] | None = field(init=False) + _linger_handler_task: Task | None = field(init=False) + _limits_handler_task: Task | None = field(init=False) + + append_input_queue: Queue[AppendInput | None] + match_seq_num: int | None + fencing_token: bytes | None + max_records_per_batch: int + max_bytes_per_batch: int + max_linger_per_batch: timedelta | None + + def __post_init__(self) -> None: + self._next_batch_idx = 0 + self._next_batch = [] + self._next_batch_count = 0 + self._next_batch_bytes = 0 + self._linger_queue = Queue() if self.max_linger_per_batch is not None else None + self._linger_handler_task = None + self._limits_handler_task = None + + def _accumulate(self, record: Record) -> None: + self._next_batch.append(record) + self._next_batch_count += 1 + self._next_batch_bytes += metered_bytes([record]) + + def _next_append_input(self) -> AppendInput: + append_input = AppendInput( + records=list(self._next_batch), + match_seq_num=self.match_seq_num, + fencing_token=self.fencing_token, + ) + self._next_batch.clear() + self._next_batch_count = 0 + self._next_batch_bytes = 0 + self._next_batch_idx += 1 + if self.match_seq_num is not None: + self.match_seq_num = self.match_seq_num + len(append_input.records) + return append_input + + async def linger_handler(self) -> None: + if self.max_linger_per_batch is None: + return + if self._linger_queue is None: + return + linger_duration = self.max_linger_per_batch.total_seconds() + prev_linger_start = None + while True: + batch_idx, linger_start = await self._linger_queue.get() + if batch_idx < self._next_batch_idx: + continue + if prev_linger_start is None: + prev_linger_start = linger_start + missed_duration = (linger_start - prev_linger_start).total_seconds() + await sleep(max(linger_duration - missed_duration, 0)) + if batch_idx == self._next_batch_idx: + append_input = self._next_append_input() + await self.append_input_queue.put(append_input) + prev_linger_start = linger_start + + def _limits_met(self, record: Record) -> bool: + if ( + self._next_batch_count + 1 <= self.max_records_per_batch + and self._next_batch_bytes + metered_bytes([record]) + <= self.max_bytes_per_batch + ): + return False + return True + + async def limits_handler(self, records: AsyncIterable[Record]) -> None: + async for record in records: + if self._limits_met(record): + append_input = self._next_append_input() + await self.append_input_queue.put(append_input) + self._accumulate(record) + if self._linger_queue is not None and len(self._next_batch) == 1: + await self._linger_queue.put((self._next_batch_idx, datetime.now())) + if len(self._next_batch) != 0: + append_input = self._next_append_input() + await self.append_input_queue.put(append_input) + await self.append_input_queue.put(None) + + def run(self, records: AsyncIterable[Record]) -> None: + if self.max_linger_per_batch is not None: + self._linger_handler_task = create_task(self.linger_handler()) + self._limits_handler_task = create_task(self.limits_handler(records)) + + def cancel(self) -> None: + if self._linger_handler_task is not None: + self._linger_handler_task.cancel() + if self._limits_handler_task is not 
None: + self._limits_handler_task.cancel() + + +@dataclass(slots=True) +class _AppendInputAsyncIterator: + append_input_queue: Queue[AppendInput | None] + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> AppendInput: + append_input = await self.append_input_queue.get() + if append_input is None: + raise StopAsyncIteration + return append_input + + +async def append_inputs_gen( + records: AsyncIterable[Record], + match_seq_num: int | None = None, + fencing_token: bytes | None = None, + max_records_per_batch: int = 1000, + max_bytes_per_batch: int = ONE_MIB, + max_linger_per_batch: timedelta | None = None, +) -> AsyncIterable[AppendInput]: + """ + Generator function for batching records and yielding :class:`.AppendInput`. + + Returned generator object can be used as the parameter to :meth:`.Stream.append_session`. + + Yields: + :class:`.AppendInput` + + Args: + records: Records that have to be appended to a stream. + match_seq_num: If it is not ``None``, it is used in the first yield of :class:`.AppendInput` + and is automatically advanced for subsequent yields. + fencing_token: Used in each yield of :class:`.AppendInput`. + max_records_per_batch: Maximum number of records in each batch. + max_bytes_per_batch: Maximum size of each batch calculated using :func:`.metered_bytes`. + max_linger_per_batch: Maximum duration for each batch to accumulate records before yielding. + + Note: + If **max_linger_per_batch** is ``None``, appending will not occur until one of the other two + limits -- **max_records_per_batch** or **max_bytes_per_batch** -- is met. + """ + append_input_queue: Queue[AppendInput | None] = Queue() + append_input_aiter = _AppendInputAsyncIterator(append_input_queue) + batcher = _AutoBatcher( + append_input_queue, + match_seq_num, + fencing_token, + max_records_per_batch, + max_bytes_per_batch, + max_linger_per_batch, + ) + batcher.run(records) + try: + async for input in append_input_aiter: + yield input + finally: + batcher.cancel()
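
For quick reference, a minimal sketch of how the new `append_inputs_gen` helper feeds `Stream.append_session`; it condenses the `examples/append_session_with_auto_batching.py` file added above. The environment variables and the record source are placeholders from the examples, not part of the SDK.

```python
import asyncio
import os
from datetime import timedelta

from streamstore import S2
from streamstore.schemas import Record
from streamstore.utils import append_inputs_gen


async def records():
    # Placeholder record source; any async iterable of Record works here.
    for body in (b"hello", b"world"):
        yield Record(body=body)


async def main():
    async with S2(auth_token=os.getenv("S2_AUTH_TOKEN")) as s2:
        stream = s2[os.getenv("MY_BASIN")][os.getenv("MY_STREAM")]
        # append_inputs_gen accumulates records into AppendInput batches, yielding a
        # batch when a record/byte limit is hit or the linger duration elapses.
        async for output in stream.append_session(
            append_inputs_gen(
                records(),
                max_records_per_batch=100,
                max_linger_per_batch=timedelta(milliseconds=50),
            )
        ):
            print(f"appended {output.end_seq_num - output.start_seq_num} records")


asyncio.run(main())
```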