diff --git a/tests/test_connection.py b/tests/test_connection.py index 8a1465d..9902914 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -207,6 +207,51 @@ async def test_client_send_and_receive(echo_conn): assert received_msg == 'This is a test message.' +async def test_client_ping(echo_conn): + async with echo_conn: + await echo_conn.ping(b'A') + with pytest.raises(ConnectionClosed): + await echo_conn.ping(b'B') + + +async def test_client_ping_two_payloads(echo_conn): + pong_count = 0 + async def ping_and_count(): + nonlocal pong_count + await echo_conn.ping() + pong_count += 1 + async with echo_conn: + async with trio.open_nursery() as nursery: + nursery.start_soon(ping_and_count) + nursery.start_soon(ping_and_count) + assert pong_count == 2 + + +async def test_client_ping_same_payload(echo_conn): + # This test verifies that two tasks can't ping with the same payload at the + # same time. One of them should succeed and the other should get an + # exception. + exc_count = 0 + async def ping_and_catch(): + nonlocal exc_count + try: + await echo_conn.ping(b'A') + except Exception: + exc_count += 1 + async with echo_conn: + async with trio.open_nursery() as nursery: + nursery.start_soon(ping_and_catch) + nursery.start_soon(ping_and_catch) + assert exc_count == 1 + + +async def test_client_pong(echo_conn): + async with echo_conn: + await echo_conn.pong(b'A') + with pytest.raises(ConnectionClosed): + await echo_conn.pong(b'B') + + async def test_client_default_close(echo_conn): async with echo_conn: assert not echo_conn.is_closed diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index ed687a5..48b2129 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -1,7 +1,10 @@ +from collections import OrderedDict +from functools import partial import itertools import logging +import random import ssl -from functools import partial +import struct from async_generator import async_generator, yield_, asynccontextmanager import attr @@ -320,6 +323,7 @@ def __init__(self, stream, wsproto, path=None): self._reader_running = True self._path = path self._put_channel, self._get_channel = open_channel(0) + self._pings = OrderedDict() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. self._open_handshake = trio.Event() @@ -398,13 +402,37 @@ async def get_message(self): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload): + async def ping(self, payload=None): ''' - Send WebSocket ping to peer. + Send WebSocket ping to peer and wait for a correspoding pong. + + Each ping is matched to its expected pong by its payload value. An + exception is raised if you call ping with a ``payload`` value equal to + an existing in-flight ping. If the remote endpoint recieves multiple + pings, it is allowed to send a single pong. Therefore, the order of + calls to ``ping()`` is tracked, and a pong will wake up its + corresponding ping _as well as any earlier pings_. + + :param payload: The payload to send. If ``None`` then a random value is + created. + :type payload: str, bytes, or None + :raises ConnectionClosed: if connection is closed + ''' + if self._close_reason: + raise ConnectionClosed(self._close_reason) + if payload in self._pings: + raise Exception('Payload value {} is already in flight.'. + format(payload)) + if payload is None: + payload = struct.pack('!I', random.getrandbits(32)) + self._pings[payload] = trio.Event() + self._wsproto.ping(payload) + await self._write_pending() + await self._pings[payload].wait() - Does not wait for pong reply. (Is this the right behavior? This may - change in the future.) Raises ``ConnectionClosed`` if the connection is - closed. + async def pong(self, payload=None): + ''' + Send an unsolicted pong. :param payload: str or bytes payloads :raises ConnectionClosed: if connection is closed @@ -537,18 +565,37 @@ async def _handle_ping_received_event(self, event): :param event: ''' + logger.debug('conn#%d ping %r', self._id, event.payload) await self._write_pending() async def _handle_pong_received_event(self, event): ''' Handle a PongReceived event. - Currently we don't do anything special for a Pong frame, but this may - change in the future. This handler is here as a placeholder. + When a pong is received, check if we have any ping requests waiting for + this pong response. If the remote endpoint skipped any earlier pings, + then we wake up those skipped pings, too. + + This function is async even though it never awaits, because the other + event handlers are async, too, and event dispatch would be more + complicated if some handlers were sync. :param event: ''' - logger.debug('conn#%d pong %r', self._id, event.payload) + payload = bytes(event.payload) + try: + event = self._pings[payload] + except KeyError: + # We received a pong that doesn't match any in-flight pongs. Nothing + # we can do with it, so ignore it. + return + key, event = self._pings.popitem(0) + while key != payload: + logger.debug('conn#%d pong [skipped] %r', self._id, key) + event.set() + key, event = self._pings.popitem(0) + logger.debug('conn#%d pong %r', self._id, key) + event.set() async def _reader_task(self): ''' A background task that reads network data and generates events. '''