Skip to content

Commit

Permalink
use task groups for concurrent queue consumers (#16850)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Jan 29, 2025
1 parent 180185e commit 85f58f4
Show file tree
Hide file tree
Showing 3 changed files with 174 additions and 22 deletions.
9 changes: 9 additions & 0 deletions src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
MessageHandler,
create_consumer,
)
from prefect.server.utilities.messaging.memory import log_metrics_periodically

if TYPE_CHECKING:
import logging
Expand Down Expand Up @@ -219,6 +220,7 @@ class TaskRunRecorder:
name: str = "TaskRunRecorder"

consumer_task: asyncio.Task[None] | None = None
metrics_task: asyncio.Task[None] | None = None

def __init__(self):
self._started_event: Optional[asyncio.Event] = None
Expand All @@ -239,6 +241,8 @@ async def start(self) -> None:

async with consumer() as handler:
self.consumer_task = asyncio.create_task(self.consumer.run(handler))
self.metrics_task = asyncio.create_task(log_metrics_periodically())

logger.debug("TaskRunRecorder started")
self.started_event.set()

Expand All @@ -250,10 +254,15 @@ async def start(self) -> None:
async def stop(self) -> None:
assert self.consumer_task is not None, "Logger not started"
self.consumer_task.cancel()
if self.metrics_task:
self.metrics_task.cancel()
try:
await self.consumer_task
if self.metrics_task:
await self.metrics_task
except asyncio.CancelledError:
pass
finally:
self.consumer_task = None
self.metrics_task = None
logger.debug("TaskRunRecorder stopped")
95 changes: 86 additions & 9 deletions src/prefect/server/utilities/messaging/memory.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from __future__ import annotations

import asyncio
import copy
import threading
from collections import defaultdict
from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping
from contextlib import asynccontextmanager
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import uuid4

import anyio
from cachetools import TTLCache
from exceptiongroup import BaseExceptionGroup # novermin
from pydantic_core import to_json
from typing_extensions import Self

Expand All @@ -26,6 +31,44 @@

logger: "logging.Logger" = get_logger(__name__)

# Simple global counters by topic with thread-safe access
_metrics_lock: threading.Lock | None = None
METRICS: dict[str, dict[str, int]] = defaultdict(
lambda: {
"published": 0,
"retried": 0,
"consumed": 0,
}
)


async def log_metrics_periodically(interval: float = 2.0) -> None:
if _metrics_lock is None:
return
while True:
await asyncio.sleep(interval)
with _metrics_lock:
for topic, data in METRICS.items():
if data["published"] == 0:
continue
depth = data["published"] - data["consumed"]
logger.debug(
"Topic=%r | published=%d consumed=%d retried=%d depth=%d",
topic,
data["published"],
data["consumed"],
data["retried"],
depth,
)


async def update_metric(topic: str, key: str, amount: int = 1) -> None:
global _metrics_lock
if _metrics_lock is None:
_metrics_lock = threading.Lock()
with _metrics_lock:
METRICS[topic][key] += amount


@dataclass
class MemoryMessage:
Expand Down Expand Up @@ -58,7 +101,7 @@ def __init__(
self,
topic: "Topic",
max_retries: int = 3,
dead_letter_queue_path: Union[Path, str, None] = None,
dead_letter_queue_path: Path | str | None = None,
) -> None:
self.topic = topic
self.max_retries = max_retries
Expand All @@ -78,6 +121,13 @@ async def deliver(self, message: MemoryMessage) -> None:
message: The message to deliver.
"""
await self._queue.put(message)
await update_metric(self.topic.name, "published")
logger.debug(
"Delivered message to topic=%r queue_size=%d retry_queue_size=%d",
self.topic.name,
self._queue.qsize(),
self._retry.qsize(),
)

async def retry(self, message: MemoryMessage) -> None:
"""
Expand All @@ -99,6 +149,14 @@ async def retry(self, message: MemoryMessage) -> None:
await self.send_to_dead_letter_queue(message)
else:
await self._retry.put(message)
await update_metric(self.topic.name, "retried")
logger.debug(
"Retried message on topic=%r retry_count=%d queue_size=%d retry_queue_size=%d",
self.topic.name,
message.retry_count,
self._queue.qsize(),
self._retry.qsize(),
)

async def get(self) -> MemoryMessage:
"""
Expand Down Expand Up @@ -152,8 +210,8 @@ def clear_all(cls) -> None:
topic.clear()
cls._topics = {}

def subscribe(self) -> Subscription:
subscription = Subscription(self)
def subscribe(self, **subscription_kwargs: Any) -> Subscription:
subscription = Subscription(self, **subscription_kwargs)
self._subscriptions.append(subscription)
return subscription

Expand Down Expand Up @@ -243,9 +301,9 @@ async def __aenter__(self) -> Self:

async def __aexit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
return None

Expand All @@ -266,22 +324,41 @@ async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None


class Consumer(_Consumer):
def __init__(self, topic: str, subscription: Optional[Subscription] = None):
def __init__(
self,
topic: str,
subscription: Optional[Subscription] = None,
concurrency: int = 2,
):
self.topic: Topic = Topic.by_name(topic)
if not subscription:
subscription = self.topic.subscribe()
assert subscription.topic is self.topic
self.subscription = subscription
self.concurrency = concurrency

async def run(self, handler: MessageHandler) -> None:
try:
async with anyio.create_task_group() as tg:
for _ in range(self.concurrency):
tg.start_soon(self._consume_loop, handler)
except BaseExceptionGroup as group: # novermin
if all(isinstance(exc, StopConsumer) for exc in group.exceptions):
logger.debug("StopConsumer received")
return # Exit cleanly when all tasks stop
# Re-raise if any non-StopConsumer exceptions
raise group

async def _consume_loop(self, handler: MessageHandler) -> None:
while True:
message = await self.subscription.get()
try:
await handler(message)
await update_metric(self.topic.name, "consumed")
except StopConsumer as e:
if not e.ack:
await self.subscription.retry(message)
return
raise # Propagate to task group
except Exception:
await self.subscription.retry(message)

Expand Down
92 changes: 79 additions & 13 deletions tests/server/utilities/test_messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
AsyncGenerator,
Callable,
Generator,
List,
Optional,
)

Expand Down Expand Up @@ -124,11 +123,11 @@ async def publisher(broker: str, cache: Cache) -> Publisher:

@pytest.fixture
async def consumer(broker: str, clear_topics: None) -> Consumer:
return create_consumer("my-topic")
return create_consumer("my-topic", concurrency=1)


async def drain_one(consumer: Consumer) -> Optional[Message]:
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand All @@ -143,7 +142,7 @@ async def handler(message: Message):
async def test_publishing_and_consuming_a_single_message(
publisher: Publisher, consumer: Consumer
) -> None:
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand All @@ -169,7 +168,7 @@ async def handler(message: Message):
async def test_stopping_consumer_without_acking(
publisher: Publisher, consumer: Consumer
) -> None:
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand All @@ -195,7 +194,7 @@ async def handler(message: Message):
async def test_erroring_handler_does_not_ack(
publisher: Publisher, consumer: Consumer
) -> None:
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -228,7 +227,7 @@ def deduplicating_publisher(broker: str, cache: Cache) -> Publisher:
async def test_publisher_will_avoid_sending_duplicate_messages_in_same_batch(
deduplicating_publisher: Publisher, consumer: Consumer
):
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -259,7 +258,7 @@ async def handler(message: Message):
async def test_publisher_will_avoid_sending_duplicate_messages_in_different_batches(
deduplicating_publisher: Publisher, consumer: Consumer
):
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -328,7 +327,7 @@ async def test_publisher_will_forget_duplicate_messages_on_error(
assert not remaining_message

# but on a subsequent attempt, the message is published and not considered duplicate
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -356,7 +355,7 @@ async def handler(message: Message):
async def test_publisher_does_not_interfere_with_duplicate_messages_without_id(
deduplicating_publisher: Publisher, consumer: Consumer
):
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -402,7 +401,7 @@ async def test_publisher_does_not_interfere_with_duplicate_messages_without_id_o
assert not remaining_message

# but on a subsequent attempt, the message is published
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand All @@ -427,7 +426,7 @@ async def handler(message: Message):


async def test_ephemeral_subscription(broker: str, publisher: Publisher):
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -461,7 +460,7 @@ async def test_repeatedly_failed_message_is_moved_to_dead_letter_queue(
consumer: MemoryConsumer,
tmp_path: Path,
):
captured_messages: List[Message] = []
captured_messages: list[Message] = []

async def handler(message: Message):
captured_messages.append(message)
Expand Down Expand Up @@ -534,3 +533,70 @@ async def handler(message: Message):
await consumer_task

assert captured_events == [emitted_event]


@pytest.mark.usefixtures("broker", "clear_topics")
@pytest.mark.parametrize("concurrency,num_messages", [(2, 4), (4, 8)])
async def test_concurrent_consumers_process_messages(
publisher: Publisher, concurrency: int, num_messages: int
) -> None:
"""Test that messages are fairly distributed across concurrent consumers"""
concurrent_consumer = create_consumer("my-topic", concurrency=concurrency)
processed_messages: list[Message] = []
processed_by_consumer: dict[int, list[Message]] = {
i + 1: [] for i in range(concurrency)
}
consumer_seen = 0
processing_order: list[int] = []

async def handler(message: Message):
nonlocal consumer_seen
# Track which consumer got the message
consumer_id = consumer_seen % concurrency + 1
consumer_seen += 1
processed_by_consumer[consumer_id].append(message)
processed_messages.append(message)
processing_order.append(consumer_id)

# First consumer is slow but should still get its fair share
if consumer_id == 1:
await asyncio.sleep(0.1)

if len(processed_messages) >= num_messages:
raise StopConsumer(ack=True)

consumer_task = asyncio.create_task(concurrent_consumer.run(handler))

try:
async with publisher as p:
# Send multiple messages
for i in range(num_messages):
await p.publish_data(f"message-{i}".encode(), {"index": str(i)})
finally:
await consumer_task

# Verify total messages processed
assert len(processed_messages) == num_messages

# Verify all consumers processed equal number of messages
messages_per_consumer = num_messages // concurrency
for consumer_id in range(1, concurrency + 1):
assert (
len(processed_by_consumer[consumer_id]) == messages_per_consumer
), f"Consumer {consumer_id} should process exactly {messages_per_consumer} messages"

# Verify messages were processed in round-robin order
expected_order = [(i % concurrency) + 1 for i in range(num_messages)]
assert (
processing_order == expected_order
), "Messages should be distributed in round-robin fashion"

# Verify each consumer got the correct messages
for consumer_id in range(1, concurrency + 1):
expected_indices = list(range(consumer_id - 1, num_messages, concurrency))
actual_indices = [
int(msg.attributes["index"]) for msg in processed_by_consumer[consumer_id]
]
assert (
actual_indices == expected_indices
), f"Consumer {consumer_id} should process messages {expected_indices}"

0 comments on commit 85f58f4

Please sign in to comment.