Skip to content

Commit

Permalink
rename ConnectionTokenContext to ClientTokenContext, add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
FZambia committed Feb 4, 2024
1 parent a97e2bb commit 951fc39
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 17 deletions.
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ When using Protobuf protocol:
* all payloads received from the library will be `bytes` or `None` if not present.
* don't forget that when using Protobuf protocol you can still have JSON payloads - just encode them to `bytes` before passing to the library.

## Callbacks should not block

Event callbacks are called by SDK using `await` internally, the websocket connection read loop is blocked for the time SDK waits for the callback to be executed. This means that if you need to perform long operations in callbacks consider moving the work to a separate coroutine/task to return fast and continue reading data from the websocket.

## Run tests

To run tests, first start Centrifugo server:
Expand Down
4 changes: 2 additions & 2 deletions centrifuge/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from .contexts import (
ConnectedContext,
ConnectingContext,
ConnectionTokenContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand Down Expand Up @@ -50,9 +50,9 @@
"ClientEventHandler",
"ClientInfo",
"ClientState",
"ClientTokenContext",
"ConnectedContext",
"ConnectingContext",
"ConnectionTokenContext",
"DisconnectedContext",
"DuplicateSubscriptionError",
"ErrorContext",
Expand Down
8 changes: 4 additions & 4 deletions centrifuge/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from centrifuge.contexts import (
ConnectedContext,
ConnectingContext,
ConnectionTokenContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
address: str,
events: Optional[ClientEventHandler] = None,
token: str = "",
get_token: Optional[Callable[[ConnectionTokenContext], Awaitable[str]]] = None,
get_token: Optional[Callable[[ClientTokenContext], Awaitable[str]]] = None,
use_protobuf: bool = False,
timeout: float = 5.0,
max_server_ping_delay: float = 10.0,
Expand Down Expand Up @@ -308,7 +308,7 @@ async def _create_connection(self) -> bool:

if not self._token and self._get_token:
try:
token = await self._get_token(ConnectionTokenContext())
token = await self._get_token(ClientTokenContext())
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _DisconnectedCode.UNAUTHORIZED
Expand Down Expand Up @@ -551,7 +551,7 @@ async def _refresh(self) -> None:
cmd_id = self._next_command_id()

try:
token = await self._get_token(ConnectionTokenContext())
token = await self._get_token(ClientTokenContext())
except Exception as e:
if isinstance(e, UnauthorizedError):
code = _DisconnectedCode.UNAUTHORIZED
Expand Down
4 changes: 2 additions & 2 deletions centrifuge/contexts.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,8 +181,8 @@ class SubscriptionErrorContext:


@dataclass
class ConnectionTokenContext:
"""ConnectionTokenContext is a context passed to get_token callback of connection."""
class ClientTokenContext:
"""ClientTokenContext is a context passed to get_token callback of connection."""


@dataclass
Expand Down
15 changes: 6 additions & 9 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
ClientEventHandler,
ConnectedContext,
ConnectingContext,
ConnectionTokenContext,
ClientTokenContext,
DisconnectedContext,
ErrorContext,
JoinContext,
Expand Down Expand Up @@ -37,7 +37,7 @@
cf_logger.setLevel(logging.DEBUG)


async def get_token(ctx: ConnectionTokenContext) -> str:
async def get_token(ctx: ClientTokenContext) -> str:
# To reject connection raise centrifuge.UnauthorizedError() exception:
# raise centrifuge.UnauthorizedError()

Expand Down Expand Up @@ -67,6 +67,8 @@ async def get_subscription_token(ctx: SubscriptionTokenContext) -> str:


class ClientEventLoggerHandler(ClientEventHandler):
"""Check out comments of ClientEventHandler methods to see when they are called."""

async def on_connecting(self, ctx: ConnectingContext) -> None:
logging.info("connecting: %s", ctx)

Expand Down Expand Up @@ -99,6 +101,8 @@ async def on_leave(self, ctx: ServerLeaveContext) -> None:


class SubscriptionEventLoggerHandler(SubscriptionEventHandler):
"""Check out comments of SubscriptionEventHandler methods to see when they are called."""

async def on_subscribing(self, ctx: SubscribingContext) -> None:
logging.info("subscribing: %s", ctx)

Expand Down Expand Up @@ -174,13 +178,6 @@ async def run():
async def shutdown(received_signal):
logging.info("received exit signal %s...", received_signal.name)
await client.disconnect()

tasks = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
for task in tasks:
task.cancel()

logging.info("Cancelling outstanding tasks")
await asyncio.gather(*tasks, return_exceptions=True)
loop.stop()

signals = (signal.SIGTERM, signal.SIGINT)
Expand Down
68 changes: 68 additions & 0 deletions tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
SubscriptionState,
PublicationContext,
SubscribedContext,
ClientTokenContext,
SubscriptionTokenContext,
DisconnectedContext,
UnsubscribedContext,
)

logging.basicConfig(
Expand Down Expand Up @@ -173,3 +177,67 @@ async def on_subscribed_after_recovery(ctx: SubscribedContext) -> None:
self.assertEqual(len(results), 5)
await client1.disconnect()
await client2.disconnect()


class TestClientToken(unittest.IsolatedAsyncioTestCase):
async def test_client_token(self) -> None:
for use_protobuf in (False, True):
with self.subTest(use_protobuf=use_protobuf):
await self._test_client_token(use_protobuf=use_protobuf)

async def _test_client_token(self, use_protobuf=False) -> None:
future = asyncio.Future()

async def test_get_client_token(ctx: ClientTokenContext) -> str:
self.assertEqual(ctx, ClientTokenContext())
return "invalid_token"

client = Client(
"ws://localhost:8000/connection/websocket",
use_protobuf=use_protobuf,
get_token=test_get_client_token,
)

async def on_disconnected(ctx: DisconnectedContext) -> None:
future.set_result(ctx.code)

client.events.on_disconnected = on_disconnected

await client.connect()
res = await future
self.assertTrue(res == 3500)
self.assertTrue(client.state == ClientState.DISCONNECTED)
await client.disconnect()


class TestSubscriptionToken(unittest.IsolatedAsyncioTestCase):
async def test_client_token(self) -> None:
for use_protobuf in (False, True):
with self.subTest(use_protobuf=use_protobuf):
await self._test_subscription_token(use_protobuf=use_protobuf)

async def _test_subscription_token(self, use_protobuf=False) -> None:
future = asyncio.Future()

async def test_get_subscription_token(ctx: SubscriptionTokenContext) -> str:
self.assertEqual(ctx, SubscriptionTokenContext(channel="channel"))
return "invalid_token"

client = Client(
"ws://localhost:8000/connection/websocket",
use_protobuf=use_protobuf,
)

sub = client.new_subscription("channel", get_token=test_get_subscription_token)

async def on_unsubscribed(ctx: UnsubscribedContext) -> None:
future.set_result(ctx.code)

sub.events.on_unsubscribed = on_unsubscribed

await client.connect()
await sub.subscribe()
res = await future
self.assertTrue(res == 103, res)
self.assertTrue(client.state == ClientState.CONNECTED)
await client.disconnect()

0 comments on commit 951fc39

Please sign in to comment.