From 57afbfc1eb34db375f59018b16edc91eeb75175d Mon Sep 17 00:00:00 2001 From: dromanov Date: Thu, 15 Aug 2024 00:40:42 +0300 Subject: [PATCH 1/3] add web.Runner context manager --- aiohttp/web.py | 261 ++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 234 insertions(+), 27 deletions(-) diff --git a/aiohttp/web.py b/aiohttp/web.py index 68b29c79d0b..3432dbd1b1c 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -1,12 +1,19 @@ import asyncio +import contextvars +import enum +import functools import logging import os +import signal import socket import sys +import threading import warnings from argparse import ArgumentParser +from asyncio import Task, constants, coroutines, events, exceptions, tasks from collections.abc import Iterable from importlib import import_module +from types import FrameType, TracebackType from typing import ( Any, Awaitable, @@ -18,6 +25,7 @@ Type, Union, cast, + final, ) from .abc import AbstractAccessLogger @@ -263,9 +271,9 @@ "WSMsgType", # web "run_app", + "Runner", ) - try: from ssl import SSLContext except ImportError: # pragma: no cover @@ -277,6 +285,222 @@ HostSequence = TypingIterable[str] +class _State(enum.Enum): + CREATED = "created" + INITIALIZED = "initialized" + CLOSED = "closed" + + +@final +class Runner: + """A context manager that controls event loop life cycle""" + + def __init__( + self, + *, + debug: Optional[bool] = None, + loop_factory: Optional[Callable[[], asyncio.AbstractEventLoop]] = None, + ): + self._state = _State.CREATED + self._debug = debug + self._loop_factory = loop_factory + self._loop = None + self._context = None + self._interrupt_count = 0 + self._set_event_loop = False + + def __enter__(self) -> "Runner": + self._lazy_init() + return self + + def __exit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[TracebackType], + ) -> None: + self.close() + + def close(self) -> None: + """Shutdown and close event loop.""" + if self._state is not _State.INITIALIZED: + return + loop = self._loop + try: + _cancel_tasks(tasks.all_tasks(loop), loop) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete( + loop.shutdown_default_executor(constants.THREAD_JOIN_TIMEOUT) + ) + finally: + if self._set_event_loop: + events.set_event_loop(None) + loop.close() + self._loop = None + self._state = _State.CLOSED + + def get_loop(self) -> asyncio.AbstractEventLoop: + """Return embedded event loop.""" + self._lazy_init() + return self._loop + + def run( + self, coro: Awaitable, *, context: Optional[contextvars.Context] = None + ) -> Any: + """Run a coroutine inside the embedded event loop.""" + if not coroutines.iscoroutine(coro): + raise ValueError(f"a coroutine was expected, got {coro!r}") + + if events._get_running_loop() is not None: + # fail fast with short traceback + raise RuntimeError( + "Runner.run() cannot be called from a running event loop" + ) + + self._lazy_init() + + if context is None: + context = self._context + task = self._loop.create_task(coro, context=context) + + if ( + threading.current_thread() is threading.main_thread() + and signal.getsignal(signal.SIGINT) is signal.default_int_handler + ): + sigint_handler = functools.partial(self._on_sigint, main_task=task) + try: + signal.signal(signal.SIGINT, sigint_handler) + except ValueError: + # `signal.signal` may throw if `threading.main_thread` does + # not support signals (e.g. embedded interpreter with signals + # not registered - see gh-91880) + sigint_handler = None + else: + sigint_handler = None + + self._interrupt_count = 0 + try: + return self._loop.run_until_complete(task) + except exceptions.CancelledError: + if self._interrupt_count > 0: + uncancel = getattr(task, "uncancel", None) + if uncancel is not None and uncancel() == 0: + raise KeyboardInterrupt() + raise # CancelledError + finally: + if ( + sigint_handler is not None + and signal.getsignal(signal.SIGINT) is sigint_handler + ): + signal.signal(signal.SIGINT, signal.default_int_handler) + + def run_app( + self, + app: Union[Application, Awaitable[Application]], + *, + host: Optional[Union[str, HostSequence]] = None, + port: Optional[int] = None, + path: Union[PathLike, TypingIterable[PathLike], None] = None, + sock: Optional[Union[socket.socket, TypingIterable[socket.socket]]] = None, + shutdown_timeout: float = 60.0, + keepalive_timeout: float = 75.0, + ssl_context: Optional[SSLContext] = None, + print: Optional[Callable[..., None]] = print, + backlog: int = 128, + access_log_class: Type[AbstractAccessLogger] = AccessLogger, + access_log_format: str = AccessLogger.LOG_FORMAT, + access_log: Optional[logging.Logger] = access_logger, + handle_signals: bool = True, + reuse_address: Optional[bool] = None, + reuse_port: Optional[bool] = None, + handler_cancellation: bool = False, + ) -> None: + """Run an app locally""" + self._lazy_init() + + self._loop.set_debug(self._debug) + + if ( + self._loop.get_debug() + and access_log + and access_log.name == "aiohttp.access" + ): + if access_log.level == logging.NOTSET: + access_log.setLevel(logging.DEBUG) + if not access_log.hasHandlers(): + access_log.addHandler(logging.StreamHandler()) + + main_task = self._loop.create_task( + _run_app( + app, + host=host, + port=port, + path=path, + sock=sock, + shutdown_timeout=shutdown_timeout, + keepalive_timeout=keepalive_timeout, + ssl_context=ssl_context, + print=print, + backlog=backlog, + access_log_class=access_log_class, + access_log_format=access_log_format, + access_log=access_log, + handle_signals=handle_signals, + reuse_address=reuse_address, + reuse_port=reuse_port, + handler_cancellation=handler_cancellation, + ) + ) + + try: + if self._set_event_loop: + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(main_task) + except (GracefulExit, KeyboardInterrupt): # pragma: no cover + pass + finally: + _cancel_tasks({main_task}, self._loop) + _cancel_tasks(asyncio.all_tasks(self._loop), self._loop) + self._loop.run_until_complete(self._loop.shutdown_asyncgens()) + self.close() + asyncio.set_event_loop(None) + + def _lazy_init(self) -> None: + if self._state is _State.CLOSED: + raise RuntimeError("Runner is closed") + if self._state is _State.INITIALIZED: + return + if self._loop_factory is None: + self._loop = events.new_event_loop() + if not self._set_event_loop: + # Call set_event_loop only once to avoid calling + # attach_loop multiple times on child watchers + events.set_event_loop(self._loop) + self._set_event_loop = True + else: + try: + self._loop = self._loop_factory() + except RuntimeError: + self._loop = events.new_event_loop() + events.set_event_loop(self._loop) + self._set_event_loop = True + if self._debug is not None: + self._loop.set_debug(self._debug) + self._context = contextvars.copy_context() + self._state = _State.INITIALIZED + + def _on_sigint( + self, signum: int, frame: Optional[FrameType], main_task: Task + ) -> None: + self._interrupt_count += 1 + if self._interrupt_count == 1 and not main_task.done(): + main_task.cancel() + # wakeup loop if it is blocked by select() with long timeout + self._loop.call_soon_threadsafe(lambda: None) + return + raise KeyboardInterrupt() + + async def _run_app( app: Union[Application, Awaitable[Application]], *, @@ -462,19 +686,15 @@ def run_app( loop: Optional[asyncio.AbstractEventLoop] = None, ) -> None: """Run an app locally""" - if loop is None: - loop = asyncio.new_event_loop() - loop.set_debug(debug) - - # Configure if and only if in debugging mode and using the default logger - if loop.get_debug() and access_log and access_log.name == "aiohttp.access": - if access_log.level == logging.NOTSET: - access_log.setLevel(logging.DEBUG) - if not access_log.hasHandlers(): - access_log.addHandler(logging.StreamHandler()) - - main_task = loop.create_task( - _run_app( + if loop is not None: + + def loop_factory(): + return loop + + else: + loop_factory = events.get_running_loop + with Runner(debug=debug, loop_factory=loop_factory) as runner: + runner.run_app( app, host=host, port=port, @@ -493,19 +713,6 @@ def run_app( reuse_port=reuse_port, handler_cancellation=handler_cancellation, ) - ) - - try: - asyncio.set_event_loop(loop) - loop.run_until_complete(main_task) - except (GracefulExit, KeyboardInterrupt): # pragma: no cover - pass - finally: - _cancel_tasks({main_task}, loop) - _cancel_tasks(asyncio.all_tasks(loop), loop) - loop.run_until_complete(loop.shutdown_asyncgens()) - loop.close() - asyncio.set_event_loop(None) def main(argv: List[str]) -> None: From 283570f6724a7c3f125e9959c0c92d1e78f66045 Mon Sep 17 00:00:00 2001 From: dromanov Date: Fri, 16 Aug 2024 20:00:07 +0300 Subject: [PATCH 2/3] =?UTF-8?q?=E2=9C=85=20add=20tests=20for=20web.Runner?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_runner.py | 295 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 295 insertions(+) create mode 100644 tests/test_runner.py diff --git a/tests/test_runner.py b/tests/test_runner.py new file mode 100644 index 00000000000..f0ce7e3c1c2 --- /dev/null +++ b/tests/test_runner.py @@ -0,0 +1,295 @@ +import asyncio +import contextvars +import re +import unittest +from unittest import mock + +from aiohttp import web + + +def tearDownModule(): + asyncio.set_event_loop_policy(None) + + +class _TestPolicy(asyncio.AbstractEventLoopPolicy): + + def __init__(self, loop_factory): + self.loop_factory = loop_factory + self.loop = None + + def get_event_loop(self): + # shouldn't ever be called by asyncio.run() + raise RuntimeError + + def new_event_loop(self): + return self.loop_factory() + + def set_event_loop(self, loop): + if loop is not None: + # we want to check if the loop is closed + # in BaseTest.tearDown + self.loop = loop + + +class BaseTest(unittest.TestCase): + + def new_loop(self): + loop = asyncio.BaseEventLoop() + loop._process_events = mock.Mock() + loop._selector = mock.Mock() + loop._selector.select.return_value = () + loop.shutdown_ag_run = False + + async def shutdown_asyncgens(): + loop.shutdown_ag_run = True + + loop.shutdown_asyncgens = shutdown_asyncgens + + return loop + + def setUp(self): + super().setUp() + + policy = _TestPolicy(self.new_loop) + asyncio.set_event_loop_policy(policy) + + def tearDown(self): + policy = asyncio.get_event_loop_policy() + if policy.loop is not None: + self.assertTrue(policy.loop.is_closed()) + self.assertTrue(policy.loop.shutdown_ag_run) + + asyncio.set_event_loop_policy(None) + super().tearDown() + + +class RunTests(BaseTest): + + def test_asyncio_run_return(self): + async def main(): + await asyncio.sleep(0) + return 42 + + self.assertEqual(asyncio.run(main()), 42) + + def test_asyncio_run_raises(self): + async def main(): + await asyncio.sleep(0) + raise ValueError("spam") + + with self.assertRaisesRegex(ValueError, "spam"): + asyncio.run(main()) + + def test_asyncio_run_only_coro(self): + for o in {1, lambda: None}: + with self.subTest(obj=o), self.assertRaisesRegex( + ValueError, "a coroutine was expected" + ): + asyncio.run(o) + + def test_asyncio_run_debug(self): + async def main(expected): + loop = asyncio.get_event_loop() + self.assertIs(loop.get_debug(), expected) + + asyncio.run(main(False)) + asyncio.run(main(True), debug=True) + with mock.patch("asyncio.coroutines._is_debug_mode", lambda: True): + asyncio.run(main(True)) + asyncio.run(main(False), debug=False) + + def test_asyncio_run_from_running_loop(self): + async def main(): + coro = main() + try: + asyncio.run(coro) + finally: + coro.close() # Suppress ResourceWarning + + with self.assertRaisesRegex(RuntimeError, "cannot be called from a running"): + asyncio.run(main()) + + def test_asyncio_run_cancels_hanging_tasks(self): + lo_task = None + + async def leftover(): + await asyncio.sleep(0.1) + + async def main(): + nonlocal lo_task + lo_task = asyncio.create_task(leftover()) + return 123 + + self.assertEqual(asyncio.run(main()), 123) + self.assertTrue(lo_task.done()) + + def test_asyncio_run_reports_hanging_tasks_errors(self): + lo_task = None + call_exc_handler_mock = mock.Mock() + + async def leftover(): + try: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + 1 / 0 + + async def main(): + loop = asyncio.get_running_loop() + loop.call_exception_handler = call_exc_handler_mock + + nonlocal lo_task + lo_task = asyncio.create_task(leftover()) + return 123 + + self.assertEqual(asyncio.run(main()), 123) + self.assertTrue(lo_task.done()) + + def test_asyncio_run_closes_gens_after_hanging_tasks_errors(self): + spinner = None + lazyboy = None + + class FancyExit(Exception): + pass + + async def fidget(): + while True: + yield 1 + await asyncio.sleep(1) + + async def spin(): + nonlocal spinner + spinner = fidget() + try: + async for the_meaning_of_life in spinner: + pass + except asyncio.CancelledError: + 1 / 0 + + async def main(): + loop = asyncio.get_running_loop() + loop.call_exception_handler = mock.Mock() + + nonlocal lazyboy + lazyboy = asyncio.create_task(spin()) + raise FancyExit + + with self.assertRaises(FancyExit): + asyncio.run(main()) + + self.assertTrue(lazyboy.done()) + + self.assertIsNone(spinner.ag_frame) + self.assertFalse(spinner.ag_running) + + +class RunnerTests(BaseTest): + + def test_non_debug(self): + with web.Runner(debug=False) as runner: + self.assertFalse(runner.get_loop().get_debug()) + + def test_debug(self): + with web.Runner(debug=True) as runner: + self.assertTrue(runner.get_loop().get_debug()) + + def test_custom_factory(self): + loop = mock.Mock() + with web.Runner(loop_factory=lambda: loop) as runner: + self.assertIs(runner.get_loop(), loop) + + def test_run(self): + async def f(): + await asyncio.sleep(0) + return "done" + + with web.Runner() as runner: + self.assertEqual("done", runner.run(f())) + loop = runner.get_loop() + + with self.assertRaisesRegex(RuntimeError, "Runner is closed"): + runner.get_loop() + + self.assertTrue(loop.is_closed()) + + def test_run_non_coro(self): + with web.Runner() as runner: + with self.assertRaisesRegex(ValueError, "a coroutine was expected"): + runner.run(123) + + def test_run_future(self): + with web.Runner() as runner: + with self.assertRaisesRegex(ValueError, "a coroutine was expected"): + fut = runner.get_loop().create_future() + runner.run(fut) + + def test_explicit_close(self): + runner = web.Runner() + loop = runner.get_loop() + runner.close() + with self.assertRaisesRegex(RuntimeError, "Runner is closed"): + runner.get_loop() + + self.assertTrue(loop.is_closed()) + + def test_double_close(self): + runner = web.Runner() + loop = runner.get_loop() + + runner.close() + self.assertTrue(loop.is_closed()) + + # the second call is no-op + runner.close() + self.assertTrue(loop.is_closed()) + + def test_second_with_block_raises(self): + ret = [] + + async def f(arg): + ret.append(arg) + + runner = web.Runner() + with runner: + runner.run(f(1)) + + with self.assertRaisesRegex(RuntimeError, "Runner is closed"): + with runner: + runner.run(f(2)) + + self.assertEqual([1], ret) + + def test_run_keeps_context(self): + cvar = contextvars.ContextVar("cvar", default=-1) + + async def f(val): + old = cvar.get() + await asyncio.sleep(0) + cvar.set(val) + return old + + async def get_context(): + return contextvars.copy_context() + + with web.Runner() as runner: + self.assertEqual(-1, runner.run(f(1))) + self.assertEqual(1, runner.run(f(2))) + + def test_recursine_run(self): + async def g(): + pass + + async def f(): + runner.run(g()) + + with web.Runner() as runner: + with self.assertWarnsRegex( + RuntimeWarning, + "coroutine .+ was never awaited", + ): + with self.assertRaisesRegex( + RuntimeError, + re.escape( + "Runner.run() cannot be called from a running event loop" + ), + ): + runner.run(f()) From 77ec5ebb6658feed0add3f1b919702d5936f6921 Mon Sep 17 00:00:00 2001 From: dromanov Date: Fri, 16 Aug 2024 21:46:11 +0300 Subject: [PATCH 3/3] =?UTF-8?q?=F0=9F=93=9D=20add=20newsfragment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CHANGES/8723.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 CHANGES/8723.feature.rst diff --git a/CHANGES/8723.feature.rst b/CHANGES/8723.feature.rst new file mode 100644 index 00000000000..59fc945e45a --- /dev/null +++ b/CHANGES/8723.feature.rst @@ -0,0 +1 @@ +Implement web.Runner context manager -- by :user:`DavidRomanovizc`