Skip to content

Commit

Permalink
better test
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz committed Jan 28, 2025
1 parent 6a2d90e commit 725a538
Showing 1 changed file with 33 additions and 27 deletions.
60 changes: 33 additions & 27 deletions tests/server/utilities/test_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,18 +536,23 @@ async def handler(message: Message):


@pytest.mark.usefixtures("broker", "clear_topics")
async def test_concurrent_consumers_process_messages(publisher: Publisher) -> None:
@pytest.mark.parametrize("concurrency,num_messages", [(2, 4), (4, 8)])
async def test_concurrent_consumers_process_messages(
publisher: Publisher, concurrency: int, num_messages: int
) -> None:
"""Test that messages are fairly distributed across concurrent consumers"""
concurrent_consumer = create_consumer("my-topic", concurrency=2)
concurrent_consumer = create_consumer("my-topic", concurrency=concurrency)
processed_messages: list[Message] = []
processed_by_consumer: dict[int, list[Message]] = {1: [], 2: []}
processed_by_consumer: dict[int, list[Message]] = {
i + 1: [] for i in range(concurrency)
}
consumer_seen = 0
processing_order: list[int] = []

async def handler(message: Message):
nonlocal consumer_seen
# Track which consumer got the message
consumer_id = consumer_seen % 2 + 1
consumer_id = consumer_seen % concurrency + 1
consumer_seen += 1
processed_by_consumer[consumer_id].append(message)
processed_messages.append(message)
Expand All @@ -557,40 +562,41 @@ async def handler(message: Message):
if consumer_id == 1:
await asyncio.sleep(0.1)

if len(processed_messages) >= 4:
if len(processed_messages) >= num_messages:
raise StopConsumer(ack=True)

consumer_task = asyncio.create_task(concurrent_consumer.run(handler))

try:
async with publisher as p:
# Send multiple messages
for i in range(4):
for i in range(num_messages):
await p.publish_data(f"message-{i}".encode(), {"index": str(i)})
finally:
await consumer_task

# Verify total messages processed
assert len(processed_messages) == 4
assert len(processed_messages) == num_messages

# Verify both consumers processed messages
assert (
len(processed_by_consumer[1]) == 2
), "First consumer should process exactly 2 messages"
# Verify all consumers processed equal number of messages
messages_per_consumer = num_messages // concurrency
for consumer_id in range(1, concurrency + 1):
assert (
len(processed_by_consumer[consumer_id]) == messages_per_consumer
), f"Consumer {consumer_id} should process exactly {messages_per_consumer} messages"

# Verify messages were processed in round-robin order
expected_order = [(i % concurrency) + 1 for i in range(num_messages)]
assert (
len(processed_by_consumer[2]) == 2
), "Second consumer should process exactly 2 messages"

# Verify messages were processed in alternating order despite speed difference
assert processing_order == [
1,
2,
1,
2,
], "Messages should be distributed in round-robin fashion"

# Verify each consumer got alternating messages
assert processed_by_consumer[1][0].attributes["index"] == "0"
assert processed_by_consumer[1][1].attributes["index"] == "2"
assert processed_by_consumer[2][0].attributes["index"] == "1"
assert processed_by_consumer[2][1].attributes["index"] == "3"
processing_order == expected_order
), "Messages should be distributed in round-robin fashion"

# Verify each consumer got the correct messages
for consumer_id in range(1, concurrency + 1):
expected_indices = list(range(consumer_id - 1, num_messages, concurrency))
actual_indices = [
int(msg.attributes["index"]) for msg in processed_by_consumer[consumer_id]
]
assert (
actual_indices == expected_indices
), f"Consumer {consumer_id} should process messages {expected_indices}"

0 comments on commit 725a538

Please sign in to comment.