# [Feature] Discard full channels (#1)
## Problem

We have an issue where channels are no longer being consumed from, but they linger around anyway. This requires us to support far more channels than we need to.

Ideally we'd figure out how to discard these channels upon clients disconnecting, but even though we call `group_discard` appropriately, the channels don't seem to be discarded.

## Solution

Add an option that, when enabled, discards any at-capacity channel from its group immediately upon sending a message to that channel.
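
For illustration, the option can be passed directly to the layer constructor (a minimal sketch; the host and capacity values are placeholders):

```python
from channels_redis.core import RedisChannelLayer

channel_layer = RedisChannelLayer(
    hosts=[("localhost", 6379)],
    capacity=100,
    should_auto_discard_full_channels=True,
)
```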

## Testing

- [x] Added an automated test

## Security

No sensitive data is exposed.
nick-merrill authored Feb 16, 2022
1 parent fe27f1a commit a05e92b
Showing 3 changed files with 139 additions and 16 deletions.
7 changes: 7 additions & 0 deletions README.rst
@@ -143,6 +143,13 @@ to 10, and all ``websocket.send!`` channels to 20:
If you want to enforce a matching order, use an ``OrderedDict`` as the
argument; channels will then be matched in the order the dict provides them.

``should_auto_discard_full_channels``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When this option is set to ``True`` and a message is sent to a channel that is at its
maximum capacity (e.g. 100 messages are in a channel whose capacity is 100), the
*entire channel* will be discarded from its group.
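
For example, a minimal Django ``CHANNEL_LAYERS`` setting with the option
enabled (the host value here is illustrative)::

    CHANNEL_LAYERS = {
        "default": {
            "BACKEND": "channels_redis.core.RedisChannelLayer",
            "CONFIG": {
                "hosts": [("localhost", 6379)],
                "should_auto_discard_full_channels": True,
            },
        },
    }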

``symmetric_encryption_keys``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

59 changes: 47 additions & 12 deletions channels_redis/core.py
Expand Up @@ -236,13 +236,15 @@ def __init__(
capacity=100,
channel_capacity=None,
symmetric_encryption_keys=None,
should_auto_discard_full_channels=False,
):
# Store basic information
self.expiry = expiry
self.group_expiry = group_expiry
self.capacity = capacity
self.channel_capacity = self.compile_capacities(channel_capacity or {})
self.prefix = prefix
self.should_auto_discard_full_channels = should_auto_discard_full_channels
assert isinstance(self.prefix, str), "Prefix must be unicode"
# Configure the host objects
self.hosts = self.decode_hosts(hosts)
@@ -682,6 +684,7 @@ async def group_send(self, group, message):
connection_to_channel_keys,
channel_keys_to_message,
channel_keys_to_capacity,
channel_keys_to_channel_name,
) = self._map_channel_keys_to_connection(channel_names, message)

for connection_index, channel_redis_keys in connection_to_channel_keys.items():
@@ -699,18 +702,25 @@ async def group_send(self, group, message):
# __asgi_channel__ key.

group_send_lua = """
local over_capacity = 0
local channels_over_capacity = {}
local current_time = ARGV[#ARGV - 1]
local expiry = ARGV[#ARGV]
local num_messages_in_channel
for i=1,#KEYS do
if redis.call('ZCOUNT', KEYS[i], '-inf', '+inf') < tonumber(ARGV[i + #KEYS]) then
redis.call('ZADD', KEYS[i], current_time, ARGV[i])
redis.call('EXPIRE', KEYS[i], expiry)
local channel_capacity = tonumber(ARGV[i + #KEYS])
local channel_name = KEYS[i]
local member = ARGV[i]
num_messages_in_channel = redis.call('ZCOUNT', channel_name, '-inf', '+inf')
if num_messages_in_channel < channel_capacity then
-- Add the member (the message) to the Redis sorted set (our channel)
redis.call('ZADD', channel_name, current_time, member)
-- Update the channel's expiration time (TTL)
redis.call('EXPIRE', channel_name, expiry)
else
over_capacity = over_capacity + 1
channels_over_capacity[#channels_over_capacity+1] = channel_name
end
end
return over_capacity
return channels_over_capacity
"""

# We need to filter the messages to keep those related to the connection
@@ -725,21 +735,41 @@ async def group_send(self, group, message):
for channel_key in channel_redis_keys
]

# Additional arguments to be accessed by indexes from the end of the args list
args += [time.time(), self.expiry]

# channel_keys does not contain a single redis key more than once
async with self.connection(connection_index) as connection:
channels_over_capacity = await connection.eval(
channel_keys_over_capacity_binary = await connection.eval(
group_send_lua, keys=channel_redis_keys, args=args
)
if channels_over_capacity > 0:
assert isinstance(channel_keys_over_capacity_binary, list)
# The executed Lua script returns strings as binary, so convert to unicode.
channel_keys_over_capacity_unicode = [
val.decode("UTF-8") for val in channel_keys_over_capacity_binary
]
channel_names_over_capacity = [
channel_keys_to_channel_name[val]
for val in channel_keys_over_capacity_unicode
]

if len(channel_names_over_capacity) > 0:
logger.info(
"%s of %s channels over capacity in group %s",
channels_over_capacity,
len(channel_names_over_capacity),
len(channel_names),
group,
)

if self.should_auto_discard_full_channels:
for channel_over_capacity in channel_names_over_capacity:
logger.info(
"Channel %s over capacity. Discarding it from group %s.",
channel_over_capacity,
group,
)
await self.group_discard(group, channel_over_capacity)

def _map_channel_to_connection(self, channel_names, message):
"""
For a list of channel names, bucket each one to a dict keyed by the
@@ -765,7 +795,6 @@ def _map_channel_to_connection(self, channel_names, message):
connection_to_channels[idx].append(channel_key)
channel_to_capacity[channel] = self.get_capacity(channel)
channel_to_message[channel] = self.serialize(message)
# We build a
channel_to_key[channel] = channel_key

return (
@@ -785,14 +814,18 @@ def _map_channel_keys_to_connection(self, channel_names, message):
the list of channels mapped to that redis key in __asgi_channel__ key to the message
3. returns a mapping of redis channel keys to their capacity
4. returns a mapping of redis channel keys to their channel names
"""

# Connection dict keyed by index to list of redis keys mapped on that index
connection_to_channel_keys = collections.defaultdict(list)
# Message dict maps redis key to the message that needs to be send on that key
# Message dict maps redis key to the message that needs to be sent on that key
channel_key_to_message = dict()
# Channel key mapped to its capacity
channel_key_to_capacity = dict()
# Channel key mapped to channel name
channel_key_to_channel_name = dict()

# For each channel
for channel in channel_names:
@@ -801,6 +834,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
channel_non_local_name = self.non_local_name(channel)
# Get its redis key
channel_key = self.prefix + channel_non_local_name
channel_key_to_channel_name[channel_key] = channel
# Have we come across the same redis key?
if channel_key not in channel_key_to_message:
# If not, fill the corresponding dicts
@@ -814,7 +848,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
# Yes, Append the channel in message dict
channel_key_to_message[channel_key]["__asgi_channel__"].append(channel)

# Now that we know what message needs to be send on a redis key we serialize it
# Now that we know what message needs to be sent on a redis key, we serialize it
for key, value in channel_key_to_message.items():
# Serialize the message stored for each redis key
channel_key_to_message[key] = self.serialize(value)
@@ -823,6 +857,7 @@ def _map_channel_keys_to_connection(self, channel_names, message):
connection_to_channel_keys,
channel_key_to_message,
channel_key_to_capacity,
channel_key_to_channel_name,
)

def _group_key(self, group):
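
For reference, the per-key logic the updated `group_send_lua` script runs atomically inside Redis is roughly equivalent to this Python sketch (illustrative only; `zcount`, `zadd`, and `expire` stand in for the Redis commands the script issues):

```python
def group_send_sketch(redis, channel_keys, serialized_messages, capacities, now, expiry):
    """Rough Python equivalent of group_send_lua; not part of the codebase."""
    channels_over_capacity = []
    for key, member, capacity in zip(channel_keys, serialized_messages, capacities):
        # Each channel is a Redis sorted set whose cardinality is the queue depth.
        if redis.zcount(key, "-inf", "+inf") < capacity:
            redis.zadd(key, {member: now})  # enqueue the message, scored by time
            redis.expire(key, expiry)       # refresh the channel's TTL
        else:
            channels_over_capacity.append(key)
    return channels_over_capacity
```

`group_send` then maps the returned keys back to channel names via `channel_keys_to_channel_name` and, when `should_auto_discard_full_channels` is enabled, calls `group_discard` for each of them.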
89 changes: 85 additions & 4 deletions tests/test_core.py
@@ -1,5 +1,7 @@
import asyncio
import logging
import random
import re

import async_timeout
import pytest
@@ -67,6 +69,21 @@ async def channel_layer():
await channel_layer.flush()


@pytest.fixture()
@async_generator
async def channel_layer_with_option_auto_discard_full_channels():
"""
Channel layer fixture, with should_auto_discard_full_channels enabled, that flushes automatically.
"""
channel_layer = RedisChannelLayer(
hosts=TEST_HOSTS,
capacity=3,
should_auto_discard_full_channels=True,
)
await yield_(channel_layer)
await channel_layer.flush()


@pytest.fixture()
@async_generator
async def channel_layer_multiple_hosts():
@@ -368,6 +385,7 @@ async def test_group_send_capacity_multiple_channels(channel_layer, caplog):
Makes sure we don't group_send messages to channels that are over capacity.
Makes sure the number of channels over capacity is logged to help debug errors.
"""
caplog.set_level(logging.INFO)

channel_1 = await channel_layer.new_channel()
channel_2 = await channel_layer.new_channel(prefix="channel_2")
@@ -397,11 +415,74 @@ async def test_group_send_capacity_multiple_channels(channel_layer, caplog):
await channel_layer.receive(channel_2)

# Make sure the number of channels over capacity is logged
for record in caplog.records:
assert record.levelname == "INFO"
assert (
record.getMessage() == "1 of 2 channels over capacity in group test-group"
assert caplog.record_tuples == [
(
"channels_redis.core",
logging.INFO,
"1 of 2 channels over capacity in group test-group",
)
]


@pytest.mark.asyncio
async def test_group_send_with_auto_discard_full_channels(
channel_layer_with_option_auto_discard_full_channels, caplog
):
"""
Tests that a full channel is discarded from its group when the should_auto_discard_full_channels option is enabled.
"""
caplog.set_level(logging.INFO)

channel_layer = channel_layer_with_option_auto_discard_full_channels

channel_1 = await channel_layer.new_channel()
channel_2 = await channel_layer.new_channel(prefix="channel_2")
await channel_layer.group_add("test-group", channel_1)
await channel_layer.group_add("test-group", channel_2)

# Let's help channel_2 get over capacity later in the test
await channel_layer.send(channel_2, {"type": "message.0"})

await channel_layer.group_send("test-group", {"type": "message.1"})
await channel_layer.group_send("test-group", {"type": "message.2"})
await channel_layer.group_send("test-group", {"type": "message.3"})

# Channel_1 should receive all 3 group messages
assert (await channel_layer.receive(channel_1))["type"] == "message.1"
assert (await channel_layer.receive(channel_1))["type"] == "message.2"
assert (await channel_layer.receive(channel_1))["type"] == "message.3"

# Channel_2 should receive the first message + 2 group messages (given the capacity is 3)
assert (await channel_layer.receive(channel_2))["type"] == "message.0"
assert (await channel_layer.receive(channel_2))["type"] == "message.1"
assert (await channel_layer.receive(channel_2))["type"] == "message.2"

# Make sure channel_2 does not receive the 3rd group message
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await channel_layer.receive(channel_2)

# Make sure discarded channels are logged
assert len(caplog.records) == 2
assert caplog.record_tuples[0] == (
"channels_redis.core",
logging.INFO,
"1 of 2 channels over capacity in group test-group",
)

assert caplog.records[1].levelname == "INFO"
assert re.match(
r"Channel channel_2\..* over capacity. Discarding it from group test-group.",
caplog.records[1].message,
)

# Make sure channel_1 still receives a new message, while channel_2 does not (because it is
# no longer part of the group)
await channel_layer.group_send("test-group", {"type": "message.4"})
assert (await channel_layer.receive(channel_1))["type"] == "message.4"
with pytest.raises(asyncio.TimeoutError):
async with async_timeout.timeout(1):
await channel_layer.receive(channel_2)


@pytest.mark.asyncio
