Skip to content

Commit

Permalink
revert removed aenter
Browse files Browse the repository at this point in the history
oops

rm needless diff
  • Loading branch information
zzstoatzz committed Jan 24, 2025
1 parent 1ff6158 commit 5aba2f8
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/prefect/server/services/task_run_recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,13 @@ async def record_task_run_event(event: ReceivedEvent) -> None:
}

db = provide_database_interface()
async with db.session_context(begin_transaction=False) as session:
async with db.session_context() as session:
await _insert_task_run(session, task_run, task_run_attributes)
await _insert_task_run_state(session, task_run)
await _update_task_run_with_state(
session, task_run, denormalized_state_attributes
)
await session.commit()

logger.debug(
"Recorded task run state change",
Expand Down
20 changes: 15 additions & 5 deletions src/prefect/server/utilities/messaging/memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,14 @@
from dataclasses import asdict, dataclass
from datetime import timedelta
from pathlib import Path
from types import TracebackType
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import uuid4

import anyio
from cachetools import TTLCache
from pydantic_core import to_json
from typing_extensions import Self

from prefect.logging import get_logger
from prefect.server.utilities.messaging import Cache as _Cache
Expand Down Expand Up @@ -288,7 +290,15 @@ def __init__(self, topic: str, cache: Cache, deduplicate_by: Optional[str] = Non
self.deduplicate_by = deduplicate_by
self._cache = cache

async def __aexit__(self, *args: Any) -> None:
async def __aenter__(self) -> Self:
return self

async def __aexit__(
self,
exc_type: type[BaseException] | None,
exc_val: BaseException | None,
exc_tb: TracebackType | None,
) -> None:
return None

async def publish_data(self, data: bytes, attributes: Mapping[str, str]) -> None:
Expand Down Expand Up @@ -329,16 +339,16 @@ async def run(self, handler: MessageHandler) -> None:

async def _consume_loop(self, handler: MessageHandler) -> None:
while True:
msg = await self.subscription.get()
message = await self.subscription.get()
try:
await handler(msg)
await handler(message)
await update_metric(self.topic.name, "consumed")
except StopConsumer as e:
if not e.ack:
await self.subscription.retry(msg)
await self.subscription.retry(message)
raise # ends task group
except Exception:
await self.subscription.retry(msg)
await self.subscription.retry(message)


@asynccontextmanager
Expand Down

0 comments on commit 5aba2f8

Please sign in to comment.