diff --git a/README.rst b/README.rst index 938bf38b..f0a759ba 100644 --- a/README.rst +++ b/README.rst @@ -47,10 +47,51 @@ Usage example await asyncio.sleep(5.0) # not all scheduled jobs are finished at the moment - # gracefully close spawned jobs - await scheduler.close() + # gracefully wait on tasks before closing any remaining spawned jobs + await scheduler.wait_and_close() - asyncio.get_event_loop().run_until_complete(main()) + asyncio.run(main()) + +Shielding tasks with a scheduler +================================ + +It is typically recommended to use ``asyncio.shield`` to protect tasks +from cancellation. However, the inner shielded tasks can't be tracked and +are therefore at risk of being cancelled during application shutdown. + +To resolve this issue aiojobs includes a ``aiojobs.Scheduler.shield`` +method to shield tasks while also keeping track of them in the scheduler. +In combination with the ``aiojobs.Scheduler.wait_and_close`` method, +this allows shielded tasks the required time to complete successfully +during application shutdown. + +For example: + +.. code-block:: python + + import asyncio + import aiojobs + from contextlib import suppress + + async def important(): + print("START") + await asyncio.sleep(5) + print("DONE") + + async def run_something(scheduler): + # If we use asyncio.shield() here, then the task doesn't complete and DONE is never printed. + await scheduler.shield(important()) + + async def main(): + scheduler = aiojobs.Scheduler() + t = asyncio.create_task(run_something(scheduler)) + await asyncio.sleep(0.1) + t.cancel() + with suppress(asyncio.CancelledError): + await t + await scheduler.wait_and_close() + + asyncio.run(main()) Integration with aiohttp.web diff --git a/aiojobs/_scheduler.py b/aiojobs/_scheduler.py index 3a938c0c..0a4abc6b 100644 --- a/aiojobs/_scheduler.py +++ b/aiojobs/_scheduler.py @@ -1,6 +1,9 @@ import asyncio +import sys +from contextlib import suppress from typing import ( Any, + Awaitable, Callable, Collection, Coroutine, @@ -9,14 +12,34 @@ Optional, Set, TypeVar, + Union, ) from ._job import Job +if sys.version_info >= (3, 11): + from asyncio import timeout as asyncio_timeout +else: + from async_timeout import timeout as asyncio_timeout + _T = TypeVar("_T") +_FutureLike = Union["asyncio.Future[_T]", Awaitable[_T]] ExceptionHandler = Callable[["Scheduler", Dict[str, Any]], None] +def _get_loop( # pragma: no cover + fut: "asyncio.Task[object]", +) -> asyncio.AbstractEventLoop: + # https://github.com/python/cpython/blob/bb802db8cfa35a88582be32fae05fe1cf8f237b1/Lib/asyncio/futures.py#L300 + try: + get_loop = fut.get_loop + except AttributeError: + pass + else: + return get_loop() + return fut._loop + + class Scheduler(Collection[Job[object]]): def __init__( self, @@ -33,6 +56,7 @@ def __init__( ) self._jobs: Set[Job[object]] = set() + self._shields: Set[asyncio.Task[object]] = set() self._close_timeout = close_timeout self._limit = limit self._exception_handler = exception_handler @@ -104,19 +128,72 @@ async def spawn( self._jobs.add(job) return job + def shield(self, arg: _FutureLike[_T]) -> "asyncio.Future[_T]": + inner = asyncio.ensure_future(arg) + if inner.done(): + return inner + + # This function is a copy of asyncio.shield(), except for the addition of + # the below 2 lines. + self._shields.add(inner) + inner.add_done_callback(self._shields.discard) + + loop = _get_loop(inner) + outer = loop.create_future() + + def _inner_done_callback(inner: "asyncio.Task[object]") -> None: + if outer.cancelled(): + if not inner.cancelled(): + inner.exception() + return + + if inner.cancelled(): + outer.cancel() + else: + exc = inner.exception() + if exc is not None: + outer.set_exception(exc) + else: + outer.set_result(inner.result()) + + def _outer_done_callback(outer: "asyncio.Future[object]") -> None: + if not inner.done(): + inner.remove_done_callback(_inner_done_callback) + + inner.add_done_callback(_inner_done_callback) + outer.add_done_callback(_outer_done_callback) + return outer + + async def wait_and_close(self, timeout: float = 60) -> None: + with suppress(asyncio.TimeoutError): + async with asyncio_timeout(timeout): + while self._jobs or self._shields: + gather = asyncio.gather( + *(job.wait() for job in self._jobs), + *self._shields, + return_exceptions=True, + ) + await asyncio.shield(gather) + await self.close() + async def close(self) -> None: if self._closed: return self._closed = True # prevent adding new jobs jobs = self._jobs - if jobs: + if jobs or self._shields: # cleanup pending queue # all job will be started on closing while not self._pending.empty(): self._pending.get_nowait() + + for f in self._shields: + f.cancel() + await asyncio.gather( - *[job._close(self._close_timeout) for job in jobs], + *(job._close(self._close_timeout) for job in jobs), + *(asyncio.wait_for(f, self._close_timeout) for f in self._shields), return_exceptions=True, ) self._jobs.clear() diff --git a/aiojobs/aiohttp.py b/aiojobs/aiohttp.py index 58a62e11..87adf559 100644 --- a/aiojobs/aiohttp.py +++ b/aiojobs/aiohttp.py @@ -1,3 +1,4 @@ +import asyncio from functools import wraps from typing import ( Any, @@ -18,6 +19,7 @@ __all__ = ("setup", "spawn", "get_scheduler", "get_scheduler_from_app", "atomic") _T = TypeVar("_T") +_FutureLike = Union["asyncio.Future[_T]", Awaitable[_T]] _RequestView = TypeVar("_RequestView", bound=Union[web.Request, web.View]) @@ -43,6 +45,10 @@ async def spawn(request: web.Request, coro: Coroutine[object, object, _T]) -> Jo return await get_scheduler(request).spawn(coro) +def shield(request: web.Request, arg: _FutureLike[_T]) -> "asyncio.Future[_T]": + return get_scheduler(request).shield(arg) + + def atomic( coro: Callable[[_RequestView], Coroutine[object, object, _T]] ) -> Callable[[_RequestView], Awaitable[_T]]: @@ -65,6 +71,6 @@ def setup(app: web.Application, **kwargs: Any) -> None: async def cleanup_context(app: web.Application) -> AsyncIterator[None]: app[AIOJOBS_SCHEDULER] = scheduler = Scheduler(**kwargs) yield - await scheduler.close() + await scheduler.wait_and_close() app.cleanup_ctx.append(cleanup_context) diff --git a/docs/api.rst b/docs/api.rst index a4ce8b9d..8a4fb02d 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -103,14 +103,31 @@ Scheduler The method respects :attr:`pending_limit` now. + .. py:method:: shield(coro) + :async: + + Protect an awaitable from being cancelled. + + This is a drop-in replacement for :func:`asyncio.shield`, with the + addition of tracking the shielded task in the scheduler. This can be + used to ensure that shielded tasks will actually be completed on + application shutdown. + + .. py:method:: wait_and_close(timeout=60) + :async: + + Wait for currently scheduled tasks to finish gracefully for the given + *timeout*. Then proceed with closing the scheduler, where any + remaining tasks will be cancelled. + .. py:method:: close() :async: - Close scheduler and all its jobs. + Close scheduler and all its jobs by cancelling the tasks and then + waiting on them. - It finishing time for particular job exceeds - :attr:`close_timeout` this job is logged by - :meth:`call_exception_handler`. + It finishing time for a particular job exceeds :attr:`close_timeout` + the job is logged by :meth:`call_exception_handler`. .. attribute:: exception_handler @@ -221,6 +238,15 @@ jobs. Return :class:`aiojobs.Job` instance +.. function:: shield(request, coro) + :async: + + Protect an awaitable from being cancelled while registering the shielded + task into the registered scheduler. + + Any shielded tasks will then be run to completion when the web app shuts + down (assuming it doesn't exceed the shutdown timeout). + Helpers diff --git a/docs/index.rst b/docs/index.rst index 1660ae11..e1ab465a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -36,14 +36,56 @@ Usage example await asyncio.sleep(5.0) # not all scheduled jobs are finished at the moment - # gracefully close spawned jobs - await scheduler.close() + # gracefully wait on tasks before closing any remaining spawned jobs + await scheduler.wait_and_close() asyncio.run(main()) For further information read :ref:`aiojobs-quickstart`, :ref:`aiojobs-intro` and :ref:`aiojobs-api`. +Shielding tasks with a scheduler +-------------------------------- + +It is typically recommended to use :func:`asyncio.shield` to protect tasks +from cancellation. However, the inner shielded tasks can't be tracked and +are therefore at risk of being cancelled during application shutdown. + +To resolve this issue aiojobs includes a :meth:`aiojobs.Scheduler.shield` +method to shield tasks while also keeping track of them in the scheduler. +In combination with the :meth:`aiojobs.Scheduler.wait_and_close` method, +this allows shielded tasks the required time to complete successfully +during application shutdown. + +For example: + +.. code-block:: python + + import asyncio + import aiojobs + from contextlib import suppress + + async def important(): + print("START") + await asyncio.sleep(5) + print("DONE") + + async def run_something(scheduler): + # If we use asyncio.shield() here, then the task doesn't complete and DONE is never printed. + await scheduler.shield(important()) + + async def main(): + scheduler = aiojobs.Scheduler() + t = asyncio.create_task(run_something(scheduler)) + await asyncio.sleep(0.1) + t.cancel() + with suppress(asyncio.CancelledError): + await t + await scheduler.wait_and_close() + + asyncio.run(main()) + + Integration with aiohttp.web ---------------------------- diff --git a/tests/test_aiohttp.py b/tests/test_aiohttp.py index deb1b3d6..15d40264 100644 --- a/tests/test_aiohttp.py +++ b/tests/test_aiohttp.py @@ -12,6 +12,7 @@ get_scheduler_from_app, get_scheduler_from_request, setup as aiojobs_setup, + shield, spawn, ) @@ -23,6 +24,10 @@ async def test_plugin(aiohttp_client: _Client) -> None: job = None + async def shielded() -> str: + await asyncio.sleep(0) + return "TEST" + async def coro() -> None: await asyncio.sleep(10) @@ -31,7 +36,9 @@ async def handler(request: web.Request) -> web.Response: job = await spawn(request, coro()) assert not job.closed - return web.Response() + + res = await shield(request, shielded()) + return web.Response(text=res) app = web.Application() app.router.add_get("/", handler) @@ -40,6 +47,7 @@ async def handler(request: web.Request) -> web.Response: client = await aiohttp_client(app) resp = await client.get("/") assert resp.status == 200 + assert await resp.text() == "TEST" assert job is not None assert job.active diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py index 5e2a3a1c..b365c8f4 100644 --- a/tests/test_scheduler.py +++ b/tests/test_scheduler.py @@ -426,6 +426,149 @@ async def f() -> None: del coro +async def test_shield(scheduler: Scheduler) -> None: + async def coro() -> str: + await asyncio.sleep(0) + return "TEST" + + result = await scheduler.shield(coro()) + assert result == "TEST" + assert len(scheduler._shields) == 0 + + +async def test_shielded_task_continues(scheduler: Scheduler) -> None: + completed = False + + async def inner() -> None: + nonlocal completed + await asyncio.sleep(0.1) + completed = True + + async def outer() -> None: + await scheduler.shield(inner()) + + t = asyncio.create_task(outer()) + await asyncio.sleep(0) + t.cancel() + with pytest.raises(asyncio.CancelledError): + await t + + assert not completed + assert len(scheduler._shields) == 1 + await asyncio.sleep(0.11) + assert completed + assert len(scheduler._shields) == 0 # type: ignore[unreachable] + + +async def test_wait_and_close(scheduler: Scheduler) -> None: + inner_done = outer_done = False + + async def inner() -> None: + nonlocal inner_done + await asyncio.sleep(0.1) + inner_done = True + + async def outer() -> None: + nonlocal outer_done + await scheduler.shield(inner()) + await asyncio.sleep(0.1) + outer_done = True + + await scheduler.spawn(outer()) + await asyncio.sleep(0) + assert not inner_done and not outer_done + assert len(scheduler._shields) == 1 + assert len(scheduler._jobs) == 1 + + await scheduler.wait_and_close() + assert inner_done and outer_done # type: ignore[unreachable] + assert len(scheduler._shields) == 0 # type: ignore[unreachable] + assert len(scheduler._jobs) == 0 + assert scheduler.closed + + +async def test_wait_and_close_timeout(scheduler: Scheduler) -> None: + inner_done = outer_cancelled = False + + async def inner() -> None: + nonlocal inner_done + await asyncio.sleep(0.1) + inner_done = True + + async def outer() -> None: + nonlocal outer_cancelled + await scheduler.shield(inner()) + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + outer_cancelled = True + + await scheduler.spawn(outer()) + await asyncio.sleep(0) + assert not inner_done and not outer_cancelled + assert len(scheduler._shields) == 1 + assert len(scheduler._jobs) == 1 + + await scheduler.wait_and_close(0.2) + assert inner_done and outer_cancelled # type: ignore[unreachable] + assert len(scheduler._shields) == 0 # type: ignore[unreachable] + assert len(scheduler._jobs) == 0 + assert scheduler.closed + + +async def test_wait_and_close_timeout_shield(scheduler: Scheduler) -> None: + inner_cancelled = outer_cancelled = False + + async def inner() -> None: + nonlocal inner_cancelled + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + inner_cancelled = True + raise + + async def outer() -> None: + nonlocal outer_cancelled + try: + await scheduler.shield(inner()) + except asyncio.CancelledError: + outer_cancelled = True + + await scheduler.spawn(outer()) + await asyncio.sleep(0) + assert not inner_cancelled and not outer_cancelled + assert len(scheduler._shields) == 1 + assert len(scheduler._jobs) == 1 + + await scheduler.wait_and_close(0.1) + assert inner_cancelled and outer_cancelled # type: ignore[unreachable] + assert len(scheduler._shields) == 0 # type: ignore[unreachable] + assert len(scheduler._jobs) == 0 + assert scheduler.closed + + +async def test_wait_and_close_spawn(scheduler: Scheduler) -> None: + another_spawned = another_done = False + + async def another() -> None: + nonlocal another_done + await scheduler.shield(asyncio.sleep(0.1)) + another_done = True + + async def coro() -> None: + nonlocal another_spawned + await asyncio.sleep(0.1) + another_spawned = True + await scheduler.spawn(another()) + + await scheduler.spawn(coro()) + await asyncio.sleep(0) + + assert not another_spawned and not another_done + await scheduler.wait_and_close() + assert another_spawned and another_done # type: ignore[unreachable] + + def test_scheduler_must_be_created_within_running_loop() -> None: with pytest.raises(RuntimeError) as exc_info: Scheduler(close_timeout=0, limit=0, pending_limit=0, exception_handler=None)