Skip to content

Commit

Permalink
adding support for CREATE/DELETE supersteams
Browse files Browse the repository at this point in the history
  • Loading branch information
DanielePalaia committed Apr 12, 2024
1 parent 06f6568 commit eec17fe
Show file tree
Hide file tree
Showing 8 changed files with 152 additions and 105 deletions.
2 changes: 2 additions & 0 deletions rstream/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
from .superstream_producer import ( # noqa: E402
RouteType,
SuperStreamProducer,
SuperStreamCreationOption,
)

from .constants import OffsetType # noqa: E402; noqa: E402
Expand All @@ -70,4 +71,5 @@
"OnClosedErrorInfo",
"SlasMechanism",
"FilterConfiguration",
"SuperStreamCreationOption",
]
30 changes: 30 additions & 0 deletions rstream/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,27 @@ async def create_stream(self, stream: str, arguments: Optional[dict[str, Any]] =
resp_schema=schema.CreateResponse,
)

async def create_super_stream(
self,
super_stream: str,
partitions: list[str],
binding_keys: list[str],
arguments: Optional[dict[str, Any]] = None,
) -> None:
if arguments is None:
arguments = {}

await self.sync_request(
schema.CreateSuperStream(
self._corr_id_seq.next(),
super_stream=super_stream,
partitions=partitions,
binding_keys=binding_keys,
arguments=[schema.Property(key, str(val)) for key, val in arguments.items()],
),
resp_schema=schema.CreateSuperStreamResponse,
)

async def delete_stream(self, stream: str) -> None:
await self.sync_request(
schema.Delete(
Expand All @@ -461,6 +482,15 @@ async def delete_stream(self, stream: str) -> None:
resp_schema=schema.DeleteResponse,
)

async def delete_super_stream(self, super_stream: str) -> None:
await self.sync_request(
schema.DeleteSuperStream(
self._corr_id_seq.next(),
super_stream=super_stream,
),
resp_schema=schema.DeleteSuperStreamResponse,
)

async def stream_exists(self, stream: str) -> bool:
try:
await self.sync_request(
Expand Down
2 changes: 2 additions & 0 deletions rstream/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ class Key(enum.Enum):
Partitions = 25
ConsumerUpdate = 26
CommandExchangeCommandVersion = 27
CommandCreateSuperStream = 29
CommandDeleteSuperStream = 30
ConsumerUpdateRequest = 32794


Expand Down
31 changes: 31 additions & 0 deletions rstream/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,20 @@ class DeleteResponse(Frame, is_response=True):
response_code: int = field(metadata={"type": T.uint16})


@dataclass
class DeleteSuperStream(Frame):
key = Key.CommandDeleteSuperStream
correlation_id: int = field(metadata={"type": T.uint32})
super_stream: str = field(metadata={"type": T.string})


@dataclass
class DeleteSuperStreamResponse(Frame, is_response=True):
key = Key.CommandDeleteSuperStream
correlation_id: int = field(metadata={"type": T.uint32})
response_code: int = field(metadata={"type": T.uint16})


@dataclass
class DeclarePublisher(Frame):
key = Key.DeclarePublisher
Expand Down Expand Up @@ -583,6 +597,23 @@ class ConsumerUpdateServerResponse(Frame, is_response=True):
offset_specification: OffsetSpecification


@dataclass
class CreateSuperStream(Frame):
key = Key.CommandCreateSuperStream
correlation_id: int = field(metadata={"type": T.uint32})
super_stream: str = field(metadata={"type": T.string})
partitions: list[str] = field(metadata={"type": [T.string]})
binding_keys: list[str] = field(metadata={"type": [T.string]})
arguments: list[Property]


@dataclass
class CreateSuperStreamResponse(Frame, is_response=True):
key = Key.CommandCreateSuperStream
correlation_id: int = field(metadata={"type": T.uint32})
response_code: int = field(metadata={"type": T.uint16})


def is_struct(obj: Any) -> bool:
return hasattr(obj, "flds_meta")

Expand Down
72 changes: 69 additions & 3 deletions rstream/superstream_producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import logging
import ssl
from dataclasses import dataclass
from enum import Enum
from typing import (
Annotated,
Expand All @@ -13,6 +14,7 @@
TypeVar,
)

from . import exceptions
from .amqp import _MessageProtocol
from .client import Client, ClientPool
from .producer import ConfirmationStatus, Producer
Expand All @@ -37,6 +39,13 @@ class RouteType(Enum):
Key = 1


@dataclass
class SuperStreamCreationOption:
n_partitions: int
binding_keys: Optional[list[str]] = None
arguments: Optional[dict[str, any]] = None


class SuperStreamProducer:
def __init__(
self,
Expand All @@ -48,6 +57,7 @@ def __init__(
username: str,
password: str,
super_stream: str,
super_stream_creation_option: Optional[SuperStreamCreationOption] = None,
routing_extractor: CB[Any],
routing: RouteType = RouteType.Hash,
frame_max: int = 1 * 1024 * 1024,
Expand Down Expand Up @@ -87,11 +97,11 @@ def __init__(
self._default_client: Optional[Client] = None
self._producer: Producer | None = None
self._routing_strategy: RoutingStrategy
# self._on_close_handler = on_close_handler
self._connection_name = connection_name
if self._connection_name is None:
self._connection_name = "rstream-producer"
self._filter_value_extractor: Optional[CB_F[Any]] = filter_value_extractor
self.super_stream_creation_option = super_stream_creation_option

async def _get_producer(self) -> Producer:
logger.debug("_get_producer() Making or getting a producer")
Expand Down Expand Up @@ -128,9 +138,9 @@ async def send(
await self._producer.send(stream=stream, message=message, on_publish_confirm=on_publish_confirm)

@property
def default_client(self) -> Client:
async def default_client(self) -> Client:
if self._default_client is None:
raise ValueError("Producer is not started")
self._default_client = await self._pool.get(connection_name="rstream-locator")
return self._default_client

async def __aenter__(self):
Expand All @@ -148,6 +158,15 @@ async def start(self) -> None:
else:
self._routing_strategy = RoutingKeyRoutingStrategy(self.routing_extractor)

if self.super_stream_creation_option is not None:
await self.create_super_stream(
self.super_stream,
self.super_stream_creation_option.n_partitions,
self.super_stream_creation_option.binding_keys,
self.super_stream_creation_option.arguments,
True,
)

async def close(self) -> None:
if self._default_client is not None:
await self._default_client.close()
Expand All @@ -159,3 +178,50 @@ async def close(self) -> None:
async def stream_exists(self, stream: str) -> bool:
producer = await self._get_producer()
return await producer.stream_exists(stream)

async def create_super_stream(
self,
super_stream: str,
n_partitions: int = 0,
binding_keys: list[str] = None,
arguments: Optional[dict[str, Any]] = None,
exists_ok: bool = False,
) -> None:
if binding_keys is not None and n_partitions != 0:
raise ValueError("Just one between n_partitions and binding_keys can be specified")

partitions = []
new_binding_key = []
if binding_keys is None:
for i in range(n_partitions):
partitions.append(super_stream + "-" + str(i))
new_binding_key.append(str(i))
else:
for i in range(len(binding_keys)):
new_binding_key = binding_keys
partitions.append(super_stream + "-" + binding_keys[i])

try:
await (await self.default_client).create_super_stream(
super_stream, partitions, new_binding_key, arguments
)
except exceptions.StreamAlreadyExists:
if not exists_ok:
raise
finally:
await self._close_locator_connection()

async def delete_super_stream(self, super_stream: str, missing_ok: bool = False) -> None:
try:
await (await self.default_client).delete_super_stream(super_stream)
except exceptions.StreamDoesNotExist:
if not missing_ok:
raise
finally:
await self._close_locator_connection()

async def _close_locator_connection(self):
if self._default_client is not None:
if await (await self.default_client).get_stream_count() == 0:
await (await self.default_client).close()
self._default_client = None
39 changes: 5 additions & 34 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@
)
from rstream.client import Client

from .http_requests import (
create_binding,
create_exchange,
delete_exchange,
)
from .util import (
filter_value_extractor,
routing_extractor,
Expand Down Expand Up @@ -156,39 +151,15 @@ async def producer_with_filtering(pytestconfig, ssl_context):

@pytest.fixture()
async def super_stream(client: Client):
# create an exchange to connect the 3 supersteams
super_stream = "test-super-stream"
status_code = create_exchange(exchange_name=super_stream)
assert status_code == 201 or status_code == 204

await client.create_stream(super_stream + "-0")
await client.create_stream(super_stream + "-1")
await client.create_stream(super_stream + "-2")

# create binding with exchange
status_code = create_binding(
exchange_name=super_stream, routing_key="key1", stream_name=super_stream + "-0"
)
assert status_code == 201 or status_code == 204
status_code = create_binding(
exchange_name=super_stream, routing_key="key2", stream_name=super_stream + "-1"
)
assert status_code == 201 or status_code == 204
status_code = create_binding(
exchange_name=super_stream, routing_key="key3", stream_name=super_stream + "-2"
await client.create_super_stream(
"test-super-stream",
["test-super-stream-0", "test-super-stream-1", "test-super-stream-2"],
["key1", "key2", "key3"],
)
assert status_code == 201 or status_code == 204

try:
yield "test-super-stream"
#
finally:
await client.delete_stream(super_stream + "-0")
await client.delete_stream(super_stream + "-1")
await client.delete_stream(super_stream + "-2")

status_code = delete_exchange(exchange_name=super_stream)
assert status_code == 201 or status_code == 204
await client.delete_super_stream("test-super-stream")


@pytest.fixture()
Expand Down
22 changes: 0 additions & 22 deletions tests/http_requests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,6 @@
from requests.auth import HTTPBasicAuth


def create_exchange(exchange_name: str) -> int:
request = "http://guest:guest@localhost:15672/api/exchanges/%2f" + "/" + exchange_name
response = requests.put(request)
return response.status_code


def delete_exchange(exchange_name: str) -> int:
request = "http://guest:guest@localhost:15672/api/exchanges/%2f" + "/" + exchange_name
response = requests.delete(request)
return response.status_code


def create_binding(exchange_name: str, routing_key: str, stream_name: str):
data = {
"routing_key": routing_key,
}
request = "http://guest:guest@localhost:15672/api/bindings/%2f/e/" + exchange_name + "/q/" + stream_name

response = requests.post(request, json=data)
return response.status_code


def get_connections() -> list:
request = "http://localhost:15672/api/connections"
response = requests.get(request, auth=HTTPBasicAuth("guest", "guest"))
Expand Down
Loading

0 comments on commit eec17fe

Please sign in to comment.