diff --git a/src/prefect/server/services/task_run_recorder.py b/src/prefect/server/services/task_run_recorder.py
index 78ef9c3a6003..e40c4a23dd3f 100644
--- a/src/prefect/server/services/task_run_recorder.py
+++ b/src/prefect/server/services/task_run_recorder.py
@@ -25,6 +25,7 @@
     MessageHandler,
     create_consumer,
 )
+from prefect.server.utilities.messaging.memory import log_metrics_periodically
 
 if TYPE_CHECKING:
     import logging
@@ -219,6 +220,7 @@ class TaskRunRecorder:
     name: str = "TaskRunRecorder"
 
     consumer_task: asyncio.Task[None] | None = None
+    metrics_task: asyncio.Task[None] | None = None
 
     def __init__(self):
         self._started_event: Optional[asyncio.Event] = None
@@ -239,6 +241,8 @@ async def start(self) -> None:
 
         async with consumer() as handler:
             self.consumer_task = asyncio.create_task(self.consumer.run(handler))
+            self.metrics_task = asyncio.create_task(log_metrics_periodically())
+
             logger.debug("TaskRunRecorder started")
             self.started_event.set()
 
@@ -250,10 +254,15 @@ async def start(self) -> None:
     async def stop(self) -> None:
         assert self.consumer_task is not None, "Logger not started"
         self.consumer_task.cancel()
+        if self.metrics_task:
+            self.metrics_task.cancel()
         try:
             await self.consumer_task
+            if self.metrics_task:
+                await self.metrics_task
         except asyncio.CancelledError:
             pass
         finally:
             self.consumer_task = None
+            self.metrics_task = None
             logger.debug("TaskRunRecorder stopped")
diff --git a/src/prefect/server/utilities/messaging/memory.py b/src/prefect/server/utilities/messaging/memory.py
index ca3d3168d589..07249ba7ae4c 100644
--- a/src/prefect/server/utilities/messaging/memory.py
+++ b/src/prefect/server/utilities/messaging/memory.py
@@ -1,16 +1,21 @@
+from __future__ import annotations
+
 import asyncio
 import copy
+import threading
+from collections import defaultdict
 from collections.abc import AsyncGenerator, Iterable, Mapping, MutableMapping
 from contextlib import asynccontextmanager
 from dataclasses import asdict, dataclass
 from datetime import timedelta
 from pathlib import Path
 from types import TracebackType
-from typing import TYPE_CHECKING, Any, Optional, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
 from uuid import uuid4
 
 import anyio
 from cachetools import TTLCache
+from exceptiongroup import BaseExceptionGroup  # novermin
 from pydantic_core import to_json
 from typing_extensions import Self
 
@@ -26,6 +31,44 @@
 
 logger: "logging.Logger" = get_logger(__name__)
 
+# Simple global counters by topic with thread-safe access
+_metrics_lock: threading.Lock | None = None
+METRICS: dict[str, dict[str, int]] = defaultdict(
+    lambda: {
+        "published": 0,
+        "retried": 0,
+        "consumed": 0,
+    }
+)
+
+
+async def log_metrics_periodically(interval: float = 2.0) -> None:
+    if _metrics_lock is None:
+        return
+    while True:
+        await asyncio.sleep(interval)
+        with _metrics_lock:
+            for topic, data in METRICS.items():
+                if data["published"] == 0:
+                    continue
+                depth = data["published"] - data["consumed"]
+                logger.debug(
+                    "Topic=%r | published=%d consumed=%d retried=%d depth=%d",
+                    topic,
+                    data["published"],
+                    data["consumed"],
+                    data["retried"],
+                    depth,
+                )
+
+
+async def update_metric(topic: str, key: str, amount: int = 1) -> None:
+    global _metrics_lock
+    if _metrics_lock is None:
+        _metrics_lock = threading.Lock()
+    with _metrics_lock:
+        METRICS[topic][key] += amount
+
 
 @dataclass
 class MemoryMessage:
@@ -58,7 +101,7 @@ def __init__(
         self,
         topic: "Topic",
         max_retries: int = 3,
-        dead_letter_queue_path: Union[Path, str, None] = None,
+        dead_letter_queue_path: Path | str | None = None,
     ) -> None:
         self.topic = topic
         self.max_retries = max_retries
@@ -78,6 +121,13 @@ async def deliver(self, message: MemoryMessage) -> None:
             message: The message to deliver.
         """
         await self._queue.put(message)
+        await update_metric(self.topic.name, "published")
+        logger.debug(
+            "Delivered message to topic=%r queue_size=%d retry_queue_size=%d",
+            self.topic.name,
+            self._queue.qsize(),
+            self._retry.qsize(),
+        )
 
     async def retry(self, message: MemoryMessage) -> None:
         """
@@ -99,6 +149,14 @@ async def retry(self, message: MemoryMessage) -> None:
             await self.send_to_dead_letter_queue(message)
         else:
             await self._retry.put(message)
+            await update_metric(self.topic.name, "retried")
+            logger.debug(
+                "Retried message on topic=%r retry_count=%d queue_size=%d retry_queue_size=%d",
+                self.topic.name,
+                message.retry_count,
+                self._queue.qsize(),
+                self._retry.qsize(),
+            )
 
     async def get(self) -> MemoryMessage:
         """
@@ -152,8 +210,8 @@ def clear_all(cls) -> None:
             topic.clear()
         cls._topics = {}
 
-    def subscribe(self) -> Subscription:
-        subscription = Subscription(self)
+    def subscribe(self, **subscription_kwargs: Any) -> Subscription:
+        subscription = Subscription(self, **subscription_kwargs)
         self._subscriptions.append(subscription)
         return subscription
 
@@ -243,9 +301,9 @@ async def __aenter__(self) -> Self:
 
     async def __aexit__(
        self,
-        exc_type: Optional[Type[BaseException]],
-        exc_val: Optional[BaseException],
-        exc_tb: Optional[TracebackType],
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
     ) -> None:
         return None
 
@@ -266,22 +324,41 @@ async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None
 
 
 class Consumer(_Consumer):
-    def __init__(self, topic: str, subscription: Optional[Subscription] = None):
+    def __init__(
+        self,
+        topic: str,
+        subscription: Optional[Subscription] = None,
+        concurrency: int = 2,
+    ):
         self.topic: Topic = Topic.by_name(topic)
         if not subscription:
             subscription = self.topic.subscribe()
         assert subscription.topic is self.topic
         self.subscription = subscription
+        self.concurrency = concurrency
 
     async def run(self, handler: MessageHandler) -> None:
+        try:
+            async with anyio.create_task_group() as tg:
+                for _ in range(self.concurrency):
+                    tg.start_soon(self._consume_loop, handler)
+        except BaseExceptionGroup as group:  # novermin
+            if all(isinstance(exc, StopConsumer) for exc in group.exceptions):
+                logger.debug("StopConsumer received")
+                return  # Exit cleanly when all tasks stop
+            # Re-raise if any non-StopConsumer exceptions
+            raise group
+
+    async def _consume_loop(self, handler: MessageHandler) -> None:
         while True:
             message = await self.subscription.get()
             try:
                 await handler(message)
+                await update_metric(self.topic.name, "consumed")
             except StopConsumer as e:
                 if not e.ack:
                     await self.subscription.retry(message)
-                return
+                raise  # Propagate to task group
             except Exception:
                 await self.subscription.retry(message)
 
diff --git a/tests/server/utilities/test_messaging.py b/tests/server/utilities/test_messaging.py
index 862e3a547238..5664c5a9a5d7 100644
--- a/tests/server/utilities/test_messaging.py
+++ b/tests/server/utilities/test_messaging.py
@@ -9,7 +9,6 @@
     AsyncGenerator,
     Callable,
     Generator,
-    List,
     Optional,
 )
 
@@ -124,11 +123,11 @@ async def publisher(broker: str, cache: Cache) -> Publisher:
 
 @pytest.fixture
 async def consumer(broker: str, clear_topics: None) -> Consumer:
-    return create_consumer("my-topic")
+    return create_consumer("my-topic", concurrency=1)
 
 
 async def drain_one(consumer: Consumer) -> Optional[Message]:
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -143,7 +142,7 @@ async def handler(message: Message):
 async def test_publishing_and_consuming_a_single_message(
     publisher: Publisher, consumer: Consumer
 ) -> None:
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -169,7 +168,7 @@ async def handler(message: Message):
 async def test_stopping_consumer_without_acking(
     publisher: Publisher, consumer: Consumer
 ) -> None:
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -195,7 +194,7 @@ async def handler(message: Message):
 async def test_erroring_handler_does_not_ack(
     publisher: Publisher, consumer: Consumer
 ) -> None:
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -228,7 +227,7 @@ def deduplicating_publisher(broker: str, cache: Cache) -> Publisher:
 async def test_publisher_will_avoid_sending_duplicate_messages_in_same_batch(
     deduplicating_publisher: Publisher, consumer: Consumer
 ):
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -259,7 +258,7 @@ async def handler(message: Message):
 async def test_publisher_will_avoid_sending_duplicate_messages_in_different_batches(
     deduplicating_publisher: Publisher, consumer: Consumer
 ):
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -328,7 +327,7 @@ async def test_publisher_will_forget_duplicate_messages_on_error(
     assert not remaining_message
 
     # but on a subsequent attempt, the message is published and not considered duplicate
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -356,7 +355,7 @@ async def handler(message: Message):
 async def test_publisher_does_not_interfere_with_duplicate_messages_without_id(
     deduplicating_publisher: Publisher, consumer: Consumer
 ):
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -402,7 +401,7 @@ async def test_publisher_does_not_interfere_with_duplicate_messages_without_id_o
     assert not remaining_message
 
     # but on a subsequent attempt, the message is published
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -427,7 +426,7 @@ async def handler(message: Message):
 
 
 async def test_ephemeral_subscription(broker: str, publisher: Publisher):
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -461,7 +460,7 @@ async def test_repeatedly_failed_message_is_moved_to_dead_letter_queue(
     consumer: MemoryConsumer,
     tmp_path: Path,
 ):
-    captured_messages: List[Message] = []
+    captured_messages: list[Message] = []
 
     async def handler(message: Message):
         captured_messages.append(message)
@@ -534,3 +533,70 @@ async def handler(message: Message):
         await consumer_task
 
     assert captured_events == [emitted_event]
+
+
+@pytest.mark.usefixtures("broker", "clear_topics")
+@pytest.mark.parametrize("concurrency,num_messages", [(2, 4), (4, 8)])
+async def test_concurrent_consumers_process_messages(
+    publisher: Publisher, concurrency: int, num_messages: int
+) -> None:
+    """Test that messages are fairly distributed across concurrent consumers"""
+    concurrent_consumer = create_consumer("my-topic", concurrency=concurrency)
+    processed_messages: list[Message] = []
+    processed_by_consumer: dict[int, list[Message]] = {
+        i + 1: [] for i in range(concurrency)
+    }
+    consumer_seen = 0
+    processing_order: list[int] = []
+
+    async def handler(message: Message):
+        nonlocal consumer_seen
+        # Track which consumer got the message
+        consumer_id = consumer_seen % concurrency + 1
+        consumer_seen += 1
+        processed_by_consumer[consumer_id].append(message)
+        processed_messages.append(message)
+        processing_order.append(consumer_id)
+
+        # First consumer is slow but should still get its fair share
+        if consumer_id == 1:
+            await asyncio.sleep(0.1)
+
+        if len(processed_messages) >= num_messages:
+            raise StopConsumer(ack=True)
+
+    consumer_task = asyncio.create_task(concurrent_consumer.run(handler))
+
+    try:
+        async with publisher as p:
+            # Send multiple messages
+            for i in range(num_messages):
+                await p.publish_data(f"message-{i}".encode(), {"index": str(i)})
+    finally:
+        await consumer_task
+
+    # Verify total messages processed
+    assert len(processed_messages) == num_messages
+
+    # Verify all consumers processed equal number of messages
+    messages_per_consumer = num_messages // concurrency
+    for consumer_id in range(1, concurrency + 1):
+        assert (
+            len(processed_by_consumer[consumer_id]) == messages_per_consumer
+        ), f"Consumer {consumer_id} should process exactly {messages_per_consumer} messages"
+
+    # Verify messages were processed in round-robin order
+    expected_order = [(i % concurrency) + 1 for i in range(num_messages)]
+    assert (
+        processing_order == expected_order
+    ), "Messages should be distributed in round-robin fashion"
+
+    # Verify each consumer got the correct messages
+    for consumer_id in range(1, concurrency + 1):
+        expected_indices = list(range(consumer_id - 1, num_messages, concurrency))
+        actual_indices = [
+            int(msg.attributes["index"]) for msg in processed_by_consumer[consumer_id]
+        ]
+        assert (
+            actual_indices == expected_indices
+        ), f"Consumer {consumer_id} should process messages {expected_indices}"
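
A note on the consumer change above: the new Consumer.run fans a single subscription out to `concurrency` copies of `_consume_loop` inside an anyio task group, and `_consume_loop` now raises StopConsumer instead of returning, so the task group cancels the sibling loops and `run` inspects the resulting exception group, treating "only StopConsumers" as a clean shutdown. The sketch below is not the Prefect code, just a minimal standalone illustration of that shape; the names StopWorker, worker, and run_consumers are invented for this example, and it assumes Python 3.11+ and anyio >= 4 so task-group failures surface as the built-in BaseExceptionGroup (the diff itself imports the exceptiongroup backport for older interpreters).

import asyncio

import anyio


class StopWorker(Exception):
    """Raised by a worker loop to request shutdown, analogous to StopConsumer above."""


async def worker(queue: asyncio.Queue[str], handled: list[str], limit: int) -> None:
    # Each worker shares one queue, mirroring how every _consume_loop in the
    # diff pulls from the same Subscription.
    while True:
        item = await queue.get()
        handled.append(item)
        if len(handled) >= limit:
            # Raising (not returning) lets the task group unwind the siblings.
            raise StopWorker


async def run_consumers(queue: asyncio.Queue[str], concurrency: int, limit: int) -> list[str]:
    handled: list[str] = []
    try:
        async with anyio.create_task_group() as tg:
            for _ in range(concurrency):
                tg.start_soon(worker, queue, handled, limit)
    except BaseExceptionGroup as group:
        # Clean exit only if every failure was a deliberate StopWorker;
        # anything else is a real error and is re-raised.
        if not all(isinstance(exc, StopWorker) for exc in group.exceptions):
            raise
    return handled


async def main() -> None:
    queue: asyncio.Queue[str] = asyncio.Queue()
    for i in range(8):
        queue.put_nowait(f"message-{i}")
    print(await run_consumers(queue, concurrency=2, limit=8))


if __name__ == "__main__":
    asyncio.run(main())

This shape also shows why `_consume_loop` raises rather than returns: a loop that simply returned would end quietly while its siblings kept waiting on the queue, whereas an exception gives the task group a signal to cancel the rest and gives `run` one place to decide whether the shutdown was intentional.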