Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added FlowControl to Protocol #51

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/asynqp/exchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,18 @@ def __init__(self, reader, synchroniser, sender, name, type, durable, auto_delet
self.auto_delete = auto_delete
self.internal = internal

def publish(self, message, routing_key, *, mandatory=True):
def publish(self, message, routing_key, *, mandatory=True, flush_buffers=False):
"""
Publish a message on the exchange, to be asynchronously delivered to queues.

:param asynqp.Message message: the message to send
:param str routing_key: the routing key with which to publish the message
:param bool flush_buffers: If set to `True` this function will return an `awaitable`,
that will wait for all TCP buffers to be flushed.
"""
self.sender.send_BasicPublish(self.name, routing_key, mandatory, message)
if flush_buffers:
return self.sender.drain()

@asyncio.coroutine
def delete(self, *, if_unused=True):
Expand Down
59 changes: 57 additions & 2 deletions src/asynqp/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,61 @@
from .log import log


class AMQP(asyncio.Protocol):
class FlowControl(asyncio.Protocol):
""" Basicly took from asyncio.streams """

def __init__(self, *, loop):
self._loop = loop
self._paused = False
self._drain_waiter = None
self._connection_lost = False

def pause_writing(self):
assert not self._paused
self._paused = True

def resume_writing(self):
assert self._paused
self._paused = False

waiter = self._drain_waiter
if waiter is not None:
self._drain_waiter = None
if not waiter.done():
waiter.set_result(None)

def connection_lost(self, exc):
self._connection_lost = True
# Wake up the writer if currently paused.
if not self._paused:
return
waiter = self._drain_waiter
if waiter is None:
return
self._drain_waiter = None
if waiter.done():
return
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)

@asyncio.coroutine
def _drain_helper(self):
if self._connection_lost:
raise ConnectionResetError('Connection lost')
if not self._paused:
return
waiter = self._drain_waiter
assert waiter is None or waiter.cancelled()
waiter = asyncio.Future(loop=self._loop)
self._drain_waiter = waiter
yield from waiter


class AMQP(FlowControl):
def __init__(self, dispatcher, loop):
super().__init__(loop=loop)
self.dispatcher = dispatcher
self.partial_frame = b''
self.frame_reader = FrameReader()
Expand All @@ -19,7 +72,8 @@ def connection_made(self, transport):

def data_received(self, data):
while data:
self.heartbeat_monitor.heartbeat_received() # the spec says 'any octet may substitute for a heartbeat'
# the spec says 'any octet may substitute for a heartbeat'
self.heartbeat_monitor.heartbeat_received()

try:
result = self.frame_reader.read_frame(data)
Expand Down Expand Up @@ -48,6 +102,7 @@ def start_heartbeat(self, heartbeat_interval):
self.heartbeat_monitor.start(heartbeat_interval)

def connection_lost(self, exc):
super().connection_lost(exc)
# If self._closed=True - we closed the transport ourselves. No need to
# dispatch PoisonPillFrame, as we should have closed everything already
if not self._closed:
Expand Down
8 changes: 8 additions & 0 deletions src/asynqp/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ def __init__(self, channel_id, protocol):
def send_method(self, method):
self.protocol.send_method(self.channel_id, method)

@asyncio.coroutine
def drain(self):
"""
Make sure all outgoing data to be passed to OS TCP buffers.
Will wait if OS buffers are full.
"""
return (yield from self.protocol._drain_helper())


class Actor(object):
def __init__(self, synchroniser, sender, *, loop):
Expand Down
27 changes: 26 additions & 1 deletion test/exchange_tests.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import asynqp
import asyncio
import uuid
from datetime import datetime
import asynqp
from unittest.mock import patch

from asynqp import spec
from asynqp import frames
from asynqp import message
Expand Down Expand Up @@ -157,6 +159,29 @@ def it_should_send_multiple_body_frames(self):
], any_order=False)


class WhenPublishingWithFlushBuffers(ExchangeContext):
def given_a_message(self):
self.msg = asynqp.Message(
'body',
)

def when_I_publish_the_message(self):
self._expected = object()
fut = asyncio.Future()
fut.set_result(self._expected)
with patch("asynqp.protocol.AMQP._drain_helper",
return_value=fut):
self.result = self.loop.run_until_complete(self.exchange.publish(
self.msg, 'routing.key', flush_buffers=True))

def it_should_call_drain(self):
assert self.result is self._expected

def it_should_send_message_as_normal(self):
expected_body = frames.ContentBodyFrame(self.channel.id, b'body')
self.server.should_have_received_frame(expected_body)


class WhenDeletingAnExchange(ExchangeContext):
def when_I_delete_the_exchange(self):
self.async_partial(self.exchange.delete(if_unused=True))
Expand Down
106 changes: 106 additions & 0 deletions test/protocol_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from unittest import mock
import contexts
import asynqp
import asyncio
import socket
from asynqp import spec
from asynqp import protocol
from asynqp.exceptions import ConnectionLostError
Expand Down Expand Up @@ -155,3 +157,107 @@ def it_should_raise_a_connection_lost_error(self):

def cleanup(self):
self.loop.set_exception_handler(testing_exception_handler)


class WhenWritingAboveLimit:

DATA_LEN = 10 * 1024 * 1024 # 10Mb should be enough I think

def given_I_have_a_connection_with_low_water(self):
self.loop = asyncio.get_event_loop()

# Bind any free port
self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_sock.bind(("127.0.0.1", 0))
self.port = self.server_sock.getsockname()[1]

# Listen on bound socket. Note: we set read limit se we hit write limit
# on the other side
self.loop.run_until_complete(
asyncio.start_server(
self._connected, sock=self.server_sock, loop=self.loop,
limit=100))

self.transport, self.protocol = self.loop.run_until_complete(
self.loop.create_connection(
lambda: protocol.AMQP(mock.Mock(), self.loop),
host="127.0.0.1", port=self.port))
self.transport.set_write_buffer_limits(high=0)

def _connected(self, r, w):
self.reader = r
self.writer = w

def when_we_many_bytes(self):
data = b'x' * self.DATA_LEN
self.transport.write(data)

def it_should_pause_writing_correctly(self):
assert self.protocol._paused
# Launch reader
fut = asyncio.async(
self.reader.readexactly(self.DATA_LEN), loop=self.loop)
# Wait for client transport to drain
self.loop.run_until_complete(self.protocol._drain_helper())
assert not self.protocol._paused
# Destroy reader task
fut.cancel()
del fut


class WhenDisconnectedWritingAboveLimit:

DATA_LEN = 10 * 1024 * 1024 # 10Mb should be enough I think

def given_I_have_a_connection_with_low_water(self):
self.loop = asyncio.get_event_loop()

# Bind any free port
self.server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.server_sock.bind(("127.0.0.1", 0))
self.port = self.server_sock.getsockname()[1]

# Listen on bound socket. Note: we set read limit se we hit write limit
# on the other side
self.loop.run_until_complete(
asyncio.start_server(
self._connected, sock=self.server_sock, loop=self.loop,
limit=100))

self.transport, self.protocol = self.loop.run_until_complete(
self.loop.create_connection(
lambda: protocol.AMQP(mock.Mock(), self.loop),
host="127.0.0.1", port=self.port))
self.transport.set_write_buffer_limits(high=0)

def _connected(self, r, w):
self.reader = r
self.writer = w

def when_we_many_bytes_and_disconnect(self):
data = b'x' * self.DATA_LEN
self.transport.write(data)
# Set up a waiter
self.waiter = asyncio.async(
self.protocol._drain_helper(), loop=self.loop)
self.writer.transport.close()
try:
self.loop.run_until_complete(self.waiter)
except ConnectionResetError:
pass

def it_should_raise_an_exception_on_new_drain(self):
raised = False
try:
self.loop.run_until_complete(self.protocol._drain_helper())
except ConnectionResetError:
raised = True
assert raised

def it_should_raise_an_exception_on_old_drain(self):
raised = False
try:
self.waiter.result()
except ConnectionResetError:
raised = True
assert raised