From d38f9e96af84044f428fe977a3e7aa9289c8cc82 Mon Sep 17 00:00:00 2001 From: Taras Date: Sun, 4 Oct 2015 18:27:01 +0300 Subject: [PATCH 1/3] Added FlowControl to Protocol, so we can wait for writes if TCP buffers are full. --- src/asynqp/exchange.py | 6 ++++- src/asynqp/protocol.py | 59 ++++++++++++++++++++++++++++++++++++++++-- src/asynqp/routing.py | 8 ++++++ test/exchange_tests.py | 27 ++++++++++++++++++- 4 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/asynqp/exchange.py b/src/asynqp/exchange.py index 531c7e7..47a3fd7 100644 --- a/src/asynqp/exchange.py +++ b/src/asynqp/exchange.py @@ -32,14 +32,18 @@ def __init__(self, reader, synchroniser, sender, name, type, durable, auto_delet self.auto_delete = auto_delete self.internal = internal - def publish(self, message, routing_key, *, mandatory=True): + def publish(self, message, routing_key, *, mandatory=True, flush_buffers=False): """ Publish a message on the exchange, to be asynchronously delivered to queues. :param asynqp.Message message: the message to send :param str routing_key: the routing key with which to publish the message + :param bool flush_buffers: If set to `True` this function will return an `awaitable`, + that will wait for all TCP buffers to be flushed. """ self.sender.send_BasicPublish(self.name, routing_key, mandatory, message) + if flush_buffers: + return self.sender.drain() @asyncio.coroutine def delete(self, *, if_unused=True): diff --git a/src/asynqp/protocol.py b/src/asynqp/protocol.py index 8ac26f2..204019f 100644 --- a/src/asynqp/protocol.py +++ b/src/asynqp/protocol.py @@ -6,8 +6,61 @@ from .log import log -class AMQP(asyncio.Protocol): +class FlowControl(asyncio.Protocol): + """ Basicly took from asyncio.streams """ + + def __init__(self, *, loop): + self._loop = loop + self._paused = False + self._drain_waiter = None + self._connection_lost = False + + def pause_writing(self): + assert not self._paused + self._paused = True + + def resume_writing(self): + assert self._paused + self._paused = False + + waiter = self._drain_waiter + if waiter is not None: + self._drain_waiter = None + if not waiter.done(): + waiter.set_result(None) + + def connection_lost(self, exc): + self._connection_lost = True + # Wake up the writer if currently paused. + if not self._paused: + return + waiter = self._drain_waiter + if waiter is None: + return + self._drain_waiter = None + if waiter.done(): + return + if exc is None: + waiter.set_result(None) + else: + waiter.set_exception(exc) + + @asyncio.coroutine + def _drain_helper(self): + if self._connection_lost: + raise ConnectionResetError('Connection lost') + if not self._paused: + return + waiter = self._drain_waiter + assert waiter is None or waiter.cancelled() + waiter = asyncio.Future(loop=self._loop) + self._drain_waiter = waiter + yield from waiter + + +class AMQP(FlowControl): def __init__(self, dispatcher, loop): + super().__init__(loop=loop) self.dispatcher = dispatcher self.partial_frame = b'' self.frame_reader = FrameReader() @@ -19,7 +72,8 @@ def connection_made(self, transport): def data_received(self, data): while data: - self.heartbeat_monitor.heartbeat_received() # the spec says 'any octet may substitute for a heartbeat' + # the spec says 'any octet may substitute for a heartbeat' + self.heartbeat_monitor.heartbeat_received() try: result = self.frame_reader.read_frame(data) @@ -48,6 +102,7 @@ def start_heartbeat(self, heartbeat_interval): self.heartbeat_monitor.start(heartbeat_interval) def connection_lost(self, exc): + super().connection_lost(exc) # If self._closed=True - we closed the transport ourselves. No need to # dispatch PoisonPillFrame, as we should have closed everything already if not self._closed: diff --git a/src/asynqp/routing.py b/src/asynqp/routing.py index 637f97b..69d7abb 100644 --- a/src/asynqp/routing.py +++ b/src/asynqp/routing.py @@ -33,6 +33,14 @@ def __init__(self, channel_id, protocol): def send_method(self, method): self.protocol.send_method(self.channel_id, method) + @asyncio.coroutine + def drain(self): + """ + Make sure all outgoing data to be passed to OS TCP buffers. + Will wait if OS buffers are full. + """ + return (yield from self.protocol._drain_helper()) + class Actor(object): def __init__(self, synchroniser, sender, *, loop): diff --git a/test/exchange_tests.py b/test/exchange_tests.py index ff7161b..459d770 100644 --- a/test/exchange_tests.py +++ b/test/exchange_tests.py @@ -1,7 +1,9 @@ +import asynqp import asyncio import uuid from datetime import datetime -import asynqp +from unittest.mock import patch + from asynqp import spec from asynqp import frames from asynqp import message @@ -157,6 +159,29 @@ def it_should_send_multiple_body_frames(self): ], any_order=False) +class WhenPublishingWithFlushBuffers(ExchangeContext): + def given_a_message(self): + self.msg = asynqp.Message( + 'body', + ) + + def when_I_publish_the_message(self): + self._expected = object() + fut = asyncio.Future() + fut.set_result(self._expected) + with patch("asynqp.protocol.AMQP._drain_helper", + return_value=fut): + self.result = self.loop.run_until_complete(self.exchange.publish( + self.msg, 'routing.key', flush_buffers=True)) + + def it_should_call_drain(self): + assert self.result is self._expected + + def it_should_send_message_as_normal(self): + expected_body = frames.ContentBodyFrame(self.channel.id, b'body') + self.server.should_have_received_frame(expected_body) + + class WhenDeletingAnExchange(ExchangeContext): def when_I_delete_the_exchange(self): self.async_partial(self.exchange.delete(if_unused=True)) From 781d5003761d0d53494c1504cd47fcf9f57f23a2 Mon Sep 17 00:00:00 2001 From: Taras Date: Sun, 4 Oct 2015 20:45:20 +0300 Subject: [PATCH 2/3] Added simple flow control test for protocol --- test/protocol_tests.py | 48 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/test/protocol_tests.py b/test/protocol_tests.py index e9ef554..3b2167d 100644 --- a/test/protocol_tests.py +++ b/test/protocol_tests.py @@ -1,6 +1,8 @@ from unittest import mock import contexts import asynqp +import asyncio +import socket from asynqp import spec from asynqp import protocol from asynqp.exceptions import ConnectionLostError @@ -155,3 +157,49 @@ def it_should_raise_a_connection_lost_error(self): def cleanup(self): self.loop.set_exception_handler(testing_exception_handler) + + +class WhenWritingAboveLimit: + + DATA_LEN = 10 * 1024 * 1024 # 1Mb should be enough I think + + def given_I_have_a_connection_with_low_water(self): + self.loop = asyncio.get_event_loop() + + # Bind any free port + self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_sock.bind(("127.0.0.1", 0)) + self.port = self.server_sock.getsockname()[1] + + # Listen on bound socket. Note: we set read limit se we hit write limit + # on the other side + self.loop.run_until_complete( + asyncio.start_server( + self._connected, sock=self.server_sock, loop=self.loop, + limit=100)) + + self.transport, self.protocol = self.loop.run_until_complete( + self.loop.create_connection( + lambda: protocol.AMQP(mock.Mock(), self.loop), + host="127.0.0.1", port=self.port)) + self.transport.set_write_buffer_limits(high=0) + + def _connected(self, r, w): + self.reader = r + self.writer = w + + def when_we_many_bytes(self): + data = b'x' * self.DATA_LEN + self.transport.write(data) + + def it_should_pause_writing_correctly(self): + assert self.protocol._paused + # Launch reader + fut = asyncio.async( + self.reader.readexactly(self.DATA_LEN), loop=self.loop) + # Wait for client transport to drain + self.loop.run_until_complete(self.protocol._drain_helper()) + assert not self.protocol._paused + # Destroy reader task + fut.cancel() + del fut From bccea8c0043af52fa036dbe448738263e53ef765 Mon Sep 17 00:00:00 2001 From: Taras Date: Sun, 4 Oct 2015 21:10:12 +0300 Subject: [PATCH 3/3] Another test for flow control --- test/protocol_tests.py | 60 +++++++++++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/test/protocol_tests.py b/test/protocol_tests.py index 3b2167d..b2a408d 100644 --- a/test/protocol_tests.py +++ b/test/protocol_tests.py @@ -161,7 +161,7 @@ def cleanup(self): class WhenWritingAboveLimit: - DATA_LEN = 10 * 1024 * 1024 # 1Mb should be enough I think + DATA_LEN = 10 * 1024 * 1024 # 10Mb should be enough I think def given_I_have_a_connection_with_low_water(self): self.loop = asyncio.get_event_loop() @@ -203,3 +203,61 @@ def it_should_pause_writing_correctly(self): # Destroy reader task fut.cancel() del fut + + +class WhenDisconnectedWritingAboveLimit: + + DATA_LEN = 10 * 1024 * 1024 # 10Mb should be enough I think + + def given_I_have_a_connection_with_low_water(self): + self.loop = asyncio.get_event_loop() + + # Bind any free port + self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.server_sock.bind(("127.0.0.1", 0)) + self.port = self.server_sock.getsockname()[1] + + # Listen on bound socket. Note: we set read limit se we hit write limit + # on the other side + self.loop.run_until_complete( + asyncio.start_server( + self._connected, sock=self.server_sock, loop=self.loop, + limit=100)) + + self.transport, self.protocol = self.loop.run_until_complete( + self.loop.create_connection( + lambda: protocol.AMQP(mock.Mock(), self.loop), + host="127.0.0.1", port=self.port)) + self.transport.set_write_buffer_limits(high=0) + + def _connected(self, r, w): + self.reader = r + self.writer = w + + def when_we_many_bytes_and_disconnect(self): + data = b'x' * self.DATA_LEN + self.transport.write(data) + # Set up a waiter + self.waiter = asyncio.async( + self.protocol._drain_helper(), loop=self.loop) + self.writer.transport.close() + try: + self.loop.run_until_complete(self.waiter) + except ConnectionResetError: + pass + + def it_should_raise_an_exception_on_new_drain(self): + raised = False + try: + self.loop.run_until_complete(self.protocol._drain_helper()) + except ConnectionResetError: + raised = True + assert raised + + def it_should_raise_an_exception_on_old_drain(self): + raised = False + try: + self.waiter.result() + except ConnectionResetError: + raised = True + assert raised