diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 26fdec59..8ba4596a 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -505,7 +505,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, async def _connect_addr(*, addr, loop, timeout, params, config, - connection_class): + middlewares, connection_class): assert loop is not None if timeout <= 0: @@ -539,12 +539,12 @@ async def _connect_addr(*, addr, loop, timeout, params, config, tr.close() raise - con = connection_class(pr, tr, loop, addr, config, params) + con = connection_class(pr, tr, loop, addr, config, params, middlewares) pr.set_connection(con) return con -async def _connect(*, loop, timeout, connection_class, **kwargs): +async def _connect(*, loop, timeout, middlewares, connection_class, **kwargs): if loop is None: loop = asyncio.get_event_loop() @@ -558,6 +558,7 @@ async def _connect(*, loop, timeout, connection_class, **kwargs): con = await _connect_addr( addr=addr, loop=loop, timeout=timeout, params=params, config=config, + middlewares=middlewares, connection_class=connection_class) except (OSError, asyncio.TimeoutError, ConnectionError) as ex: last_error = ex diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 8e841871..b8b67e5b 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -42,7 +42,7 @@ class Connection(metaclass=ConnectionMeta): """ __slots__ = ('_protocol', '_transport', '_loop', - '_top_xact', '_aborted', + '_top_xact', '_aborted', '_middlewares' '_pool_release_ctr', '_stmt_cache', '_stmts_to_close', '_listeners', '_server_version', '_server_caps', '_intro_query', '_reset_query', '_proxy', @@ -53,7 +53,8 @@ class Connection(metaclass=ConnectionMeta): def __init__(self, protocol, transport, loop, addr: (str, int) or str, config: connect_utils._ClientConfiguration, - params: connect_utils._ConnectionParameters): + params: connect_utils._ConnectionParameters, + middlewares=None): self._protocol = protocol self._transport = transport self._loop = loop @@ -92,7 +93,7 @@ def __init__(self, protocol, transport, loop, self._reset_query = None self._proxy = None - + self._middlewares = _middlewares # Used to serialize operations that might involve anonymous # statements. Specifically, we want to make the following # operation atomic: @@ -1410,8 +1411,12 @@ async def reload_schema_state(self): async def _execute(self, query, args, limit, timeout, return_status=False): with self._stmt_exclusive_section: - result, _ = await self.__execute( - query, args, limit, timeout, return_status=return_status) + wrapped = self.__execute + if self._middlewares: + for m in reversed(self._middlewares): + wrapped = await m(self, wrapped) + + result, _ = await wrapped(query, args, limit, timeout, return_status=return_status) return result async def __execute(self, query, args, limit, timeout, @@ -1502,6 +1507,7 @@ async def connect(dsn=None, *, max_cacheable_statement_size=1024 * 15, command_timeout=None, ssl=None, + middlewares=None, connection_class=Connection, server_settings=None): r"""A coroutine to establish a connection to a PostgreSQL server. @@ -1618,6 +1624,10 @@ async def connect(dsn=None, *, PostgreSQL documentation for a `list of supported options `_. + :param middlewares: + An optional list of middleware functions. Refer to documentation + on create_pool. + :param Connection connection_class: Class of the returned connection object. Must be a subclass of :class:`~asyncpg.connection.Connection`. @@ -1683,6 +1693,7 @@ async def connect(dsn=None, *, ssl=ssl, database=database, server_settings=server_settings, command_timeout=command_timeout, + middlewares=middlewares, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size) diff --git a/asyncpg/pool.py b/asyncpg/pool.py index 64f4071e..408b3d78 100644 --- a/asyncpg/pool.py +++ b/asyncpg/pool.py @@ -304,7 +304,7 @@ class Pool: Pools are created by calling :func:`~asyncpg.pool.create_pool`. """ - __slots__ = ('_queue', '_loop', '_minsize', '_maxsize', + __slots__ = ('_queue', '_loop', '_minsize', '_maxsize', '_middlewares' '_init', '_connect_args', '_connect_kwargs', '_working_addr', '_working_config', '_working_params', '_holders', '_initialized', '_initializing', '_closing', @@ -317,6 +317,7 @@ def __init__(self, *connect_args, max_inactive_connection_lifetime, setup, init, + middlewares, loop, connection_class, **connect_kwargs): @@ -374,6 +375,7 @@ def __init__(self, *connect_args, self._closed = False self._generation = 0 self._init = init + self._middlewares = middlewares self._connect_args = connect_args self._connect_kwargs = connect_kwargs @@ -460,6 +462,7 @@ async def _get_new_connection(self): *self._connect_args, loop=self._loop, connection_class=self._connection_class, + middlewares=self._middlewares, **self._connect_kwargs) self._working_addr = con._addr @@ -774,7 +777,6 @@ def __await__(self): self.done = True return self.pool._acquire(self.timeout).__await__() - def create_pool(dsn=None, *, min_size=10, max_size=10, @@ -782,6 +784,7 @@ def create_pool(dsn=None, *, max_inactive_connection_lifetime=300.0, setup=None, init=None, + middlewares=None, loop=None, connection_class=connection.Connection, **connect_kwargs): @@ -857,6 +860,19 @@ def create_pool(dsn=None, *, or :meth:`Connection.set_type_codec() <\ asyncpg.connection.Connection.set_type_codec>`. + :param middlewares: + A list of middleware functions to be middleware just + before a connection excecutes a statement. + Syntax of a middleware is as follows: + async def middleware_factory(connection, handler): + async def middleware(query, args. limit, timeout, return_status): + print('do something before') + result, stmt = await handler(query, args, limit, + timeout, return_status) + print('do something after') + return result, stmt + return middleware + :param loop: An asyncio event loop instance. If ``None``, the default event loop will be used. @@ -884,6 +900,8 @@ def create_pool(dsn=None, *, dsn, connection_class=connection_class, min_size=min_size, max_size=max_size, - max_queries=max_queries, loop=loop, setup=setup, init=init, + max_queries=max_queries, loop=loop, setup=setup, + middlewares=middlewares, init=init, max_inactive_connection_lifetime=max_inactive_connection_lifetime, **connect_kwargs) + diff --git a/docs/installation.rst b/docs/installation.rst index 61668663..80d85335 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -29,7 +29,7 @@ If you want to build **asyncpg** from a Git checkout you will need: * CPython header files. These can usually be obtained by installing the relevant Python development package: **python3-dev** on Debian/Ubuntu, **python3-devel** on RHEL/Fedora. - + * Clone the repo with submodules (`git clone --recursive`, or `git submodules init; git submodules update`) Once the above requirements are satisfied, run the following command in the root of the source checkout: diff --git a/tests/test_pool.py b/tests/test_pool.py index b1894f3a..aebd3dc1 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -76,6 +76,23 @@ async def worker(): tasks = [worker() for _ in range(n)] await asyncio.gather(*tasks, loop=self.loop) + async def test_pool_with_middleware(self): + called = False + + async def my_middleware_factory(connection, handler): + async def middleware(query, args, limit, timeout, return_status): + nonlocal called + called = True + return await handler(query, args, limit, + timeout, return_status) + async with self.create_pool(database='postgres', + min_size=1, max_size=1, + middlewares=[my_middleware_factory]) \ + as pool: + con = await pool.acquire(timeout=5) + await con.fetchval('SELECT 1', 1) + assert called + async def test_pool_03(self): pool = await self.create_pool(database='postgres', min_size=1, max_size=1)