From e8e26a903f49c4467f4c046d7ac1c0571083400e Mon Sep 17 00:00:00 2001 From: "Mark E. Haase" Date: Tue, 16 Oct 2018 11:39:09 -0400 Subject: [PATCH] Improve ping handling (#22) As suggested in this issue thread, there needs to be a way for clients that send infrequent messages to make sure that the connection is still open. Ping/pong are the natural way to do this, but the ping behavior needs to be a bit more robust. This commit adds a `wait_pong()` method that waits for a pong to arrive and returns its payload. --- tests/test_connection.py | 10 +++++++ trio_websocket/__init__.py | 60 +++++++++++++++++++++++++++++++++----- 2 files changed, 63 insertions(+), 7 deletions(-) diff --git a/tests/test_connection.py b/tests/test_connection.py index 8a1465d..f6b775b 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -207,6 +207,16 @@ async def test_client_send_and_receive(echo_conn): assert received_msg == 'This is a test message.' +async def test_client_ping_pong(echo_conn): + async with echo_conn: + await echo_conn.ping(b'test-payload-1') + pong1 = await echo_conn.wait_pong() + assert pong1 == b'test-payload-1' + await echo_conn.ping(b'test-payload-2') + pong2 = await echo_conn.wait_pong() + assert pong2 == b'test-payload-2' + + async def test_client_default_close(echo_conn): async with echo_conn: assert not echo_conn.is_closed diff --git a/trio_websocket/__init__.py b/trio_websocket/__init__.py index ed687a5..f5206c0 100644 --- a/trio_websocket/__init__.py +++ b/trio_websocket/__init__.py @@ -297,6 +297,36 @@ def __repr__(self): self.code, self.name, self.reason) +class Future: + ''' Represents a future value. ''' + def __init__(self): + ''' Constructor. ''' + self._event = trio.Event() + self._value = None + + def set(self, value): + ''' + Set the future's value. + + This can only be called once. + + :param value: The value to set. + ''' + if self._event.is_set(): + raise Exception('This future is already set.') + self._value = value + self._event.set() + + async def wait(self): + ''' + Wait for the future to have a value. + + :returns: The future's value. + ''' + await self._event.wait() + return self._value + + class WebSocketConnection(trio.abc.AsyncResource): ''' A WebSocket connection. ''' @@ -320,6 +350,7 @@ def __init__(self, stream, wsproto, path=None): self._reader_running = True self._path = path self._put_channel, self._get_channel = open_channel(0) + self._pong_future = Future() # Set once the WebSocket open handshake takes place, i.e. # ConnectionRequested for server or ConnectedEstablished for client. self._open_handshake = trio.Event() @@ -398,15 +429,15 @@ async def get_message(self): raise ConnectionClosed(self._close_reason) from None return message - async def ping(self, payload): + async def ping(self, payload=None): ''' Send WebSocket ping to peer. - Does not wait for pong reply. (Is this the right behavior? This may - change in the future.) Raises ``ConnectionClosed`` if the connection is - closed. + Does not wait for pong reply. Use the ``wait_pong()`` method if you + want to wait for the pong. - :param payload: str or bytes payloads + :param payload: + :type payload: str, bytes, or None :raises ConnectionClosed: if connection is closed ''' if self._close_reason: @@ -428,6 +459,16 @@ async def send_message(self, message): self._wsproto.send_data(message) await self._write_pending() + async def wait_pong(self): + ''' + Wait for a pong. + + :returns: The pong's payload. + :rtype: bytes + ''' + value = await self._pong_future.wait() + return value + def _abort_web_socket(self): ''' If a stream is closed outside of this class, e.g. due to network @@ -537,18 +578,23 @@ async def _handle_ping_received_event(self, event): :param event: ''' + logger.debug('conn#%d ping %r', self._id, event.payload) await self._write_pending() async def _handle_pong_received_event(self, event): ''' Handle a PongReceived event. - Currently we don't do anything special for a Pong frame, but this may - change in the future. This handler is here as a placeholder. + This will send the pong's payload to all tasks that are waiting for + ``wait_pong()``. This method is async even though it never awaits, + because all of the other event handlers are async and this simplifies + event dispatch. :param event: ''' logger.debug('conn#%d pong %r', self._id, event.payload) + self._pong_future.set(event.payload) + self._pong_future = Future() async def _reader_task(self): ''' A background task that reads network data and generates events. '''