feat: add append_inputs_gen util for automatic batching of records (#…
quettabit authored Jan 21, 2025
1 parent f077b75 commit 82f65c7
Showing 7 changed files with 234 additions and 31 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -34,7 +34,7 @@ pip install streamstore

## Examples

`examples/streaming` directory in the [repo](https://github.com/s2-streamstore/s2-sdk-python/tree/main/examples/streaming) contains examples for streaming APIs.
`examples/` directory in the [repo](https://github.com/s2-streamstore/s2-sdk-python/tree/main/examples/) contains examples for streaming APIs.

## Get in touch

8 changes: 5 additions & 3 deletions examples/streaming/producer.py → examples/append_session.py
@@ -1,15 +1,17 @@
import asyncio
import os
import random
from typing import AsyncIterable

from streamstore import S2
from streamstore.schemas import Record, AppendInput
from streamstore.schemas import AppendInput, Record

AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
MY_BASIN = os.getenv("MY_BASIN")
MY_STREAM = os.getenv("MY_STREAM")


async def append_inputs():
async def append_inputs_gen() -> AsyncIterable[AppendInput]:
num_inputs = random.randint(1, 100)
for _ in range(num_inputs):
num_records = random.randint(1, 100)
@@ -26,7 +28,7 @@ async def append_inputs():
async def producer():
async with S2(auth_token=AUTH_TOKEN) as s2:
stream = s2[MY_BASIN][MY_STREAM]
async for output in stream.append_session(append_inputs()):
async for output in stream.append_session(append_inputs_gen()):
num_appended_records = output.end_seq_num - output.start_seq_num
print(f"appended {num_appended_records} records")

40 changes: 40 additions & 0 deletions examples/append_session_with_auto_batching.py
@@ -0,0 +1,40 @@
import asyncio
import os
import random
from datetime import timedelta
from typing import AsyncIterable

from streamstore import S2
from streamstore.schemas import Record
from streamstore.utils import append_inputs_gen

AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
MY_BASIN = os.getenv("MY_BASIN")
MY_STREAM = os.getenv("MY_STREAM")


async def records_gen() -> AsyncIterable[Record]:
num_records = random.randint(1, 100)
for _ in range(num_records):
body_size = random.randint(1, 1024)
if random.random() < 0.5:
await asyncio.sleep(random.random() * 2.5)
yield Record(body=os.urandom(body_size))


async def producer():
async with S2(auth_token=AUTH_TOKEN) as s2:
stream = s2[MY_BASIN][MY_STREAM]
async for output in stream.append_session(
append_inputs_gen(
records=records_gen(),
max_records_per_batch=10,
max_linger_per_batch=timedelta(milliseconds=5),
)
):
num_appended_records = output.end_seq_num - output.start_seq_num
print(f"appended {num_appended_records} records")


if __name__ == "__main__":
asyncio.run(producer())
@@ -1,5 +1,6 @@
import asyncio
import os

from streamstore import S2

AUTH_TOKEN = os.getenv("S2_AUTH_TOKEN")
38 changes: 19 additions & 19 deletions src/streamstore/_client.py
@@ -6,10 +6,10 @@
from datetime import timedelta
from typing import Self, TypedDict, cast

import stamina
from google.protobuf.field_mask_pb2 import FieldMask
from grpc import StatusCode, ssl_channel_credentials
from grpc.aio import AioRpcError, Channel, secure_channel
from stamina import AsyncRetryingCaller, retry_context

from streamstore import schemas
from streamstore._exceptions import fallible
@@ -205,9 +205,9 @@ def __init__(
enable_append_retries=enable_append_retries,
)
self._stub = AccountServiceStub(self._account_channel)
self._retrying_caller = stamina.AsyncRetryingCaller(
**self._config.retry_kwargs
).on(_grpc_retry_on)
self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
_grpc_retry_on
)

async def __aenter__(self) -> Self:
return self
@@ -226,7 +226,7 @@ async def close(self) -> None:
Close all open connections to S2 service endpoints.
Tip:
``S2`` supports async context manager protocol, so you could also do the following instead of
``S2`` supports async context manager protocol, so you can also do the following instead of
explicitly closing:
.. code-block:: python
@@ -297,7 +297,7 @@ def basin(self, name: str) -> "Basin":
async with S2(..) as s2:
basin = s2.basin("your-basin-name")
:class:`.S2` implements the ``getitem`` magic method, so you could also do the following instead:
:class:`.S2` implements the ``getitem`` magic method, so you can also do the following instead:
.. code-block:: python
@@ -438,9 +438,9 @@ def __init__(
) -> None:
self._channel = channel
self._config = config
self._retrying_caller = stamina.AsyncRetryingCaller(
**self._config.retry_kwargs
).on(_grpc_retry_on)
self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
_grpc_retry_on
)
self._stub = BasinServiceStub(self._channel)
self._name = name

@@ -508,7 +508,7 @@ def stream(self, name: str) -> "Stream":
async with S2(..) as s2:
stream = s2.basin("your-basin-name").stream("your-stream-name")
:class:`.Basin` implements the ``getitem`` magic method, so you could also do the following instead:
:class:`.Basin` implements the ``getitem`` magic method, so you can also do the following instead:
.. code-block:: python
@@ -636,9 +636,9 @@ class Stream:
def __init__(self, name: str, channel: Channel, config: _Config) -> None:
self._name = name
self._config = config
self._retrying_caller = stamina.AsyncRetryingCaller(
**self._config.retry_kwargs
).on(_grpc_retry_on)
self._retrying_caller = AsyncRetryingCaller(**self._config.retry_kwargs).on(
_grpc_retry_on
)
self._stub = StreamServiceStub(channel)

def __repr__(self) -> str:
@@ -710,9 +710,7 @@ async def _retrying_append_session(
self, inputs: AsyncIterable[schemas.AppendInput]
) -> AsyncIterable[schemas.AppendOutput]:
inflight_inputs: deque[schemas.AppendInput] = deque()
async for attempt in stamina.retry_context(
_grpc_retry_on, **self._config.retry_kwargs
):
async for attempt in retry_context(_grpc_retry_on, **self._config.retry_kwargs):
with attempt:
if len(inflight_inputs) != 0:
async for output in self._retrying_append_session_inner(
@@ -734,6 +732,10 @@ async def append_session(
Append batches of records to a stream continuously, while guaranteeing pipelined inputs are
processed in order.
Tip:
You can use :func:`.append_inputs_gen` for automatic batching of records instead of explicitly
preparing and passing batches of records.
Yields:
:class:`.AppendOutput` for each corresponding :class:`.AppendInput`.
@@ -864,9 +866,7 @@ async def read_session(
start_seq_num=start_seq_num,
limit=read_limit_message(limit),
)
async for attempt in stamina.retry_context(
_grpc_retry_on, **self._config.retry_kwargs
):
async for attempt in retry_context(_grpc_retry_on, **self._config.retry_kwargs):
with attempt:
async for response in self._read_session(request):
output = response.output
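
For context on the import change above (module-level `stamina.` calls replaced by `AsyncRetryingCaller` and `retry_context` imported directly), here is a minimal sketch of the two stamina retry patterns this file relies on. The predicate, coroutine, and keyword arguments are hypothetical stand-ins for `_grpc_retry_on` and `self._config.retry_kwargs`:

```python
import asyncio

from stamina import AsyncRetryingCaller, retry_context


def _retry_on(exc: Exception) -> bool:
    # hypothetical predicate: retry only transient connection failures
    return isinstance(exc, ConnectionError)


async def unary_call() -> str:
    # hypothetical fallible coroutine standing in for a gRPC call
    return "ok"


async def main() -> None:
    # pattern 1: bind the retry policy once, then wrap individual awaitables
    caller = AsyncRetryingCaller(attempts=3, timeout=10.0).on(_retry_on)
    print(await caller(unary_call))

    # pattern 2: retry a whole block, as the streaming sessions do
    async for attempt in retry_context(_retry_on, attempts=3, timeout=10.0):
        with attempt:
            print(await unary_call())


if __name__ == "__main__":
    asyncio.run(main())
```

The bound caller suits unary RPCs, while `retry_context` lets the streaming sessions re-drive a whole block of work after a retryable failure.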
2 changes: 1 addition & 1 deletion src/streamstore/schemas.py
@@ -54,7 +54,7 @@ class AppendInput:
"""

#: Batch of records to append atomically, which must contain at least one record,
#: and no more than 1000. The total size of the batch must not exceed 1MiB of :func:`.metered_bytes`.
#: and no more than 1000. The size of the batch must not exceed 1MiB of :func:`.metered_bytes`.
records: list[Record]
#: Enforce that the sequence number issued to the first record in the batch matches this value.
match_seq_num: int | None = None
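
The batch limits documented above (1 to 1000 records, at most 1 MiB of metered bytes) can be checked client-side with `metered_bytes`. A minimal sketch assuming the `Record`, `AppendInput`, `ONE_MIB`, and `metered_bytes` names from this changeset; the batch contents are hypothetical:

```python
import os

from streamstore.schemas import ONE_MIB, AppendInput, Record
from streamstore.utils import metered_bytes

# hypothetical batch of ten small records
records = [Record(body=os.urandom(256)) for _ in range(10)]

# mirror the documented limits before building the batch
assert 1 <= len(records) <= 1000, "a batch must hold between 1 and 1000 records"
assert metered_bytes(records) <= ONE_MIB, "a batch must not exceed 1 MiB of metered bytes"

batch = AppendInput(records=records, match_seq_num=None, fencing_token=None)
```

The `append_inputs_gen` utility added in `utils.py` below performs this bookkeeping automatically.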
174 changes: 167 additions & 7 deletions src/streamstore/utils.py
@@ -1,16 +1,16 @@
__all__ = [
"CommandRecord",
"metered_bytes",
]
__all__ = ["CommandRecord", "metered_bytes", "append_inputs_gen"]

from typing import Iterable
from asyncio import Queue, Task, create_task, sleep
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import AsyncIterable, Iterable, Self

from streamstore.schemas import Record, SequencedRecord
from streamstore.schemas import ONE_MIB, AppendInput, Record, SequencedRecord


class CommandRecord:
"""
Helper class for creating `command records <https://s2.dev/docs/stream#command-records>`_.
Factory class for creating `command records <https://s2.dev/docs/stream#command-records>`_.
"""

FENCE = b"fence"
@@ -67,3 +67,163 @@ def metered_bytes(records: Iterable[Record | SequencedRecord]) -> int:
)
for record in records
)


@dataclass(slots=True)
class _AutoBatcher:
_next_batch_idx: int = field(init=False)
_next_batch: list[Record] = field(init=False)
_next_batch_count: int = field(init=False)
_next_batch_bytes: int = field(init=False)
_linger_queue: Queue[tuple[int, datetime]] | None = field(init=False)
_linger_handler_task: Task | None = field(init=False)
_limits_handler_task: Task | None = field(init=False)

append_input_queue: Queue[AppendInput | None]
match_seq_num: int | None
fencing_token: bytes | None
max_records_per_batch: int
max_bytes_per_batch: int
max_linger_per_batch: timedelta | None

def __post_init__(self) -> None:
self._next_batch_idx = 0
self._next_batch = []
self._next_batch_count = 0
self._next_batch_bytes = 0
self._linger_queue = Queue() if self.max_linger_per_batch is not None else None
self._linger_handler_task = None
self._limits_handler_task = None

def _accumulate(self, record: Record) -> None:
self._next_batch.append(record)
self._next_batch_count += 1
self._next_batch_bytes += metered_bytes([record])

def _next_append_input(self) -> AppendInput:
append_input = AppendInput(
records=list(self._next_batch),
match_seq_num=self.match_seq_num,
fencing_token=self.fencing_token,
)
self._next_batch.clear()
self._next_batch_count = 0
self._next_batch_bytes = 0
self._next_batch_idx += 1
if self.match_seq_num is not None:
self.match_seq_num = self.match_seq_num + len(append_input.records)
return append_input

async def linger_handler(self) -> None:
if self.max_linger_per_batch is None:
return
if self._linger_queue is None:
return
linger_duration = self.max_linger_per_batch.total_seconds()
prev_linger_start = None
while True:
batch_idx, linger_start = await self._linger_queue.get()
if batch_idx < self._next_batch_idx:
continue
if prev_linger_start is None:
prev_linger_start = linger_start
missed_duration = (linger_start - prev_linger_start).total_seconds()
await sleep(max(linger_duration - missed_duration, 0))
if batch_idx == self._next_batch_idx:
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
prev_linger_start = linger_start

def _limits_met(self, record: Record) -> bool:
if (
self._next_batch_count + 1 <= self.max_records_per_batch
and self._next_batch_bytes + metered_bytes([record])
<= self.max_bytes_per_batch
):
return False
return True

async def limits_handler(self, records: AsyncIterable[Record]) -> None:
async for record in records:
if self._limits_met(record):
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
self._accumulate(record)
if self._linger_queue is not None and len(self._next_batch) == 1:
await self._linger_queue.put((self._next_batch_idx, datetime.now()))
if len(self._next_batch) != 0:
append_input = self._next_append_input()
await self.append_input_queue.put(append_input)
await self.append_input_queue.put(None)

def run(self, records: AsyncIterable[Record]) -> None:
if self.max_linger_per_batch is not None:
self._linger_handler_task = create_task(self.linger_handler())
self._limits_handler_task = create_task(self.limits_handler(records))

def cancel(self) -> None:
if self._linger_handler_task is not None:
self._linger_handler_task.cancel()
if self._limits_handler_task is not None:
self._limits_handler_task.cancel()


@dataclass(slots=True)
class _AppendInputAsyncIterator:
append_input_queue: Queue[AppendInput | None]

def __aiter__(self) -> Self:
return self

async def __anext__(self) -> AppendInput:
append_input = await self.append_input_queue.get()
if append_input is None:
raise StopAsyncIteration
return append_input


async def append_inputs_gen(
records: AsyncIterable[Record],
match_seq_num: int | None = None,
fencing_token: bytes | None = None,
max_records_per_batch: int = 1000,
max_bytes_per_batch: int = ONE_MIB,
max_linger_per_batch: timedelta | None = None,
) -> AsyncIterable[AppendInput]:
"""
Generator function for batching records and yielding :class:`.AppendInput`.
The returned generator object can be used as the parameter to :meth:`.Stream.append_session`.
Yields:
:class:`.AppendInput`
Args:
records: Records that have to be appended to a stream.
match_seq_num: If it is not ``None``, it is used in the first yield of :class:`.AppendInput`
and is automatically advanced for subsequent yields.
fencing_token: Used in each yield of :class:`.AppendInput`.
max_records_per_batch: Maximum number of records in each batch.
max_bytes_per_batch: Maximum size of each batch calculated using :func:`.metered_bytes`.
max_linger_per_batch: Maximum duration for each batch to accumulate records before yielding.
Note:
If **max_linger_per_batch** is ``None``, appending will not occur until one of the other two
limits -- **max_records_per_batch** or **max_bytes_per_batch** -- is met.
"""
append_input_queue: Queue[AppendInput | None] = Queue()
append_input_aiter = _AppendInputAsyncIterator(append_input_queue)
batcher = _AutoBatcher(
append_input_queue,
match_seq_num,
fencing_token,
max_records_per_batch,
max_bytes_per_batch,
max_linger_per_batch,
)
batcher.run(records)
try:
async for input in append_input_aiter:
yield input
finally:
batcher.cancel()
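
A minimal consumer sketch for the new generator, along the lines of `examples/append_session_with_auto_batching.py` added in this commit but also exercising `match_seq_num` and `max_bytes_per_batch`. The environment variables, record source, and the assumption that the stream is empty (so sequence numbers start at 0) are hypothetical:

```python
import asyncio
import os
from datetime import timedelta
from typing import AsyncIterable

from streamstore import S2
from streamstore.schemas import ONE_MIB, Record
from streamstore.utils import append_inputs_gen


async def records_gen() -> AsyncIterable[Record]:
    # hypothetical record source
    for _ in range(100):
        yield Record(body=os.urandom(128))


async def producer() -> None:
    async with S2(auth_token=os.getenv("S2_AUTH_TOKEN")) as s2:
        stream = s2[os.getenv("MY_BASIN")][os.getenv("MY_STREAM")]
        inputs = append_inputs_gen(
            records=records_gen(),
            match_seq_num=0,  # assumes an empty stream; advanced automatically per batch
            max_records_per_batch=100,
            max_bytes_per_batch=ONE_MIB,
            max_linger_per_batch=timedelta(milliseconds=50),
        )
        async for output in stream.append_session(inputs):
            print(f"appended {output.end_seq_num - output.start_seq_num} records")


if __name__ == "__main__":
    asyncio.run(producer())
```

With `max_linger_per_batch` set, a partially filled batch is yielded once the linger duration elapses, so slow producers still make progress instead of waiting for the record or byte limits.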
