From bc9b9ab9966cbdb6070572ebf66cb4ce3ec33537 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20van=20Noord?= <13665637+DanielNoord@users.noreply.github.com> Date: Mon, 4 Nov 2024 21:40:33 +0100 Subject: [PATCH] Add typing to `acquire` --- asyncpg/pool.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 2e4a7b4f..517b76b8 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -7,7 +7,7 @@ from __future__ import annotations import asyncio -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Iterator import functools import inspect import logging @@ -405,7 +405,7 @@ def __init__(self, *connect_args, self._holders = [] self._initialized = False self._initializing = False - self._queue = None + self._queue: Optional[asyncio.LifoQueue[PoolConnectionHolder]] = None self._connection_class = connection_class self._record_class = record_class @@ -838,7 +838,11 @@ async def copy_records_to_table( where=where ) - def acquire(self, *, timeout=None): + def acquire( + self, + *, + timeout: Optional[float] = None, + ) -> PoolAcquireContext: """Acquire a database connection from the pool. :param float timeout: A timeout for acquiring a Connection. @@ -863,11 +867,12 @@ def acquire(self, *, timeout=None): """ return PoolAcquireContext(self, timeout) - async def _acquire(self, timeout): - async def _acquire_impl(): - ch = await self._queue.get() # type: PoolConnectionHolder + async def _acquire(self, timeout: Optional[float]) -> PoolConnectionProxy: + async def _acquire_impl() -> PoolConnectionProxy: + assert self._queue is not None + ch = await self._queue.get() try: - proxy = await ch.acquire() # type: PoolConnectionProxy + proxy = await ch.acquire() except (Exception, asyncio.CancelledError): self._queue.put_nowait(ch) raise @@ -1039,7 +1044,7 @@ def __init__(self, pool: Pool, timeout: Optional[float]) -> None: self.connection = None self.done = False - async def __aenter__(self): + async def __aenter__(self) -> PoolConnectionProxy: if self.connection is not None or self.done: raise exceptions.InterfaceError('a connection is already acquired') self.connection = await self.pool._acquire(self.timeout) @@ -1056,7 +1061,7 @@ async def __aexit__( self.connection = None await self.pool.release(con) - def __await__(self): + def __await__(self) -> Iterator[PoolConnectionProxy]: self.done = True return self.pool._acquire(self.timeout).__await__()