Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

avoid collapsing exception groups from user code #2830

Open
wants to merge 14 commits into
base: master
Choose a base branch
from
Open
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,6 @@ combine-as-imports = true

[tool.mypy]
strict = true
python_version = "3.9"

[[tool.mypy.overrides]]
module = "starlette.testclient.*"
Expand Down
42 changes: 32 additions & 10 deletions starlette/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
import functools
import sys
import typing
from contextlib import contextmanager
from contextlib import asynccontextmanager, contextmanager

import anyio.abc

from starlette.types import Scope

Expand All @@ -13,12 +15,14 @@
else: # pragma: no cover
from typing_extensions import TypeGuard

has_exceptiongroups = True
if sys.version_info < (3, 11): # pragma: no cover
try:
from exceptiongroup import BaseExceptionGroup # type: ignore[unused-ignore,import-not-found]
except ImportError:
has_exceptiongroups = False

class BaseExceptionGroup(BaseException): # type: ignore[no-redef]
pass


T = typing.TypeVar("T")
AwaitableCallable = typing.Callable[..., typing.Awaitable[T]]
Expand Down Expand Up @@ -71,15 +75,33 @@ async def __aexit__(self, *args: typing.Any) -> None | bool:


@contextmanager
def collapse_excgroups() -> typing.Generator[None, None, None]:
def _collapse_excgroups() -> typing.Generator[None, None, None]:
try:
yield
except BaseException as exc:
if has_exceptiongroups: # pragma: no cover
while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1:
exc = exc.exceptions[0]

raise exc
except BaseExceptionGroup as excs:
if len(excs.exceptions) != 1:
raise

exc = excs.exceptions[0]
context = exc.__context__
tb = exc.__traceback__
cause = exc.__cause__
sc = exc.__suppress_context__
try:
raise exc
finally:
exc.__traceback__ = tb
exc.__context__ = context
exc.__cause__ = cause
exc.__suppress_context__ = sc
del exc, cause, tb, context


@asynccontextmanager
async def create_collapsing_task_group() -> typing.AsyncGenerator[anyio.abc.TaskGroup, None]:
with _collapse_excgroups():
async with anyio.create_task_group() as tg:
yield tg


def get_route_path(scope: Scope) -> str:
Expand Down
6 changes: 3 additions & 3 deletions starlette/middleware/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import anyio

from starlette._utils import collapse_excgroups
from starlette._utils import create_collapsing_task_group
from starlette.requests import ClientDisconnect, Request
from starlette.responses import AsyncContentStream, Response
from starlette.types import ASGIApp, Message, Receive, Scope, Send
Expand Down Expand Up @@ -173,8 +173,8 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]:
return response

send_stream, recv_stream = anyio.create_memory_object_stream[Message]()
with recv_stream, send_stream, collapse_excgroups():
async with anyio.create_task_group() as task_group:
with recv_stream, send_stream:
async with create_collapsing_task_group() as task_group:
response = await self.dispatch_func(request, call_next)
await response(scope, wrapped_receive, send)
response_sent.set()
Expand Down
3 changes: 2 additions & 1 deletion starlette/middleware/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import anyio
from anyio.abc import ObjectReceiveStream, ObjectSendStream

from starlette._utils import create_collapsing_task_group
from starlette.types import Receive, Scope, Send

warnings.warn(
Expand Down Expand Up @@ -102,7 +103,7 @@ async def __call__(self, receive: Receive, send: Send) -> None:
more_body = message.get("more_body", False)
environ = build_environ(self.scope, body)

async with anyio.create_task_group() as task_group:
async with create_collapsing_task_group() as task_group:
task_group.start_soon(self.sender, send)
async with self.stream_send:
await anyio.to_thread.run_sync(self.wsgi, environ, self.start_response)
Expand Down
3 changes: 2 additions & 1 deletion starlette/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import anyio
import anyio.to_thread

from starlette._utils import create_collapsing_task_group
from starlette.background import BackgroundTask
from starlette.concurrency import iterate_in_threadpool
from starlette.datastructures import URL, Headers, MutableHeaders
Expand Down Expand Up @@ -258,7 +259,7 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
except OSError:
raise ClientDisconnect()
else:
async with anyio.create_task_group() as task_group:
async with create_collapsing_task_group() as task_group:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we also have an issue with StreamingResponse standalone?

With my comment on the other PR I was trying to avoid the changes here. 🤔

Copy link
Member Author

@graingert graingert Dec 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah it's the same sort of issue, but it's wrapped in two TaskGroups before the user gets the exception. Collapsing is ok here because if wait_for_disconnect raises an exception it is always 'catastrophic' eg won't be caught


async def wrap(func: typing.Callable[[], typing.Awaitable[None]]) -> None:
await func()
Expand Down
20 changes: 16 additions & 4 deletions tests/middleware/test_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextvars
import sys
from collections.abc import AsyncGenerator, AsyncIterator, Generator
from contextlib import AsyncExitStack
from typing import Any
Expand All @@ -21,6 +22,9 @@
from starlette.websockets import WebSocket
from tests.types import TestClientFactory

if sys.version_info < (3, 11):
from exceptiongroup import ExceptionGroup


class CustomMiddleware(BaseHTTPMiddleware):
async def dispatch(
Expand All @@ -41,6 +45,10 @@ def exc(request: Request) -> None:
raise Exception("Exc")


def eg(request: Request) -> None:
raise ExceptionGroup("my exception group", [ValueError("TEST")])


def exc_stream(request: Request) -> StreamingResponse:
return StreamingResponse(_generate_faulty_stream())

Expand Down Expand Up @@ -76,6 +84,7 @@ async def websocket_endpoint(session: WebSocket) -> None:
routes=[
Route("/", endpoint=homepage),
Route("/exc", endpoint=exc),
Route("/eg", endpoint=eg),
Route("/exc-stream", endpoint=exc_stream),
Route("/no-response", endpoint=NoResponse),
WebSocketRoute("/ws", endpoint=websocket_endpoint),
Expand All @@ -89,13 +98,16 @@ def test_custom_middleware(test_client_factory: TestClientFactory) -> None:
response = client.get("/")
assert response.headers["Custom-Header"] == "Example"

with pytest.raises(Exception) as ctx:
with pytest.raises(Exception) as ctx1:
response = client.get("/exc")
assert str(ctx.value) == "Exc"
assert str(ctx1.value) == "Exc"

with pytest.raises(Exception) as ctx:
with pytest.raises(Exception) as ctx2:
response = client.get("/exc-stream")
assert str(ctx.value) == "Faulty Stream"
assert str(ctx2.value) == "Faulty Stream"

with pytest.raises(ExceptionGroup, match=r"my exception group \(1 sub-exception\)"):
client.get("/eg")

with pytest.raises(RuntimeError):
response = client.get("/no-response")
Expand Down
3 changes: 1 addition & 2 deletions tests/middleware/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import pytest

from starlette._utils import collapse_excgroups
from starlette.middleware.wsgi import WSGIMiddleware, build_environ
from tests.types import TestClientFactory

Expand Down Expand Up @@ -86,7 +85,7 @@ def test_wsgi_exception(test_client_factory: TestClientFactory) -> None:
# The HTTP protocol implementations would catch this error and return 500.
app = WSGIMiddleware(raise_exception)
client = test_client_factory(app)
with pytest.raises(RuntimeError), collapse_excgroups():
with pytest.raises(RuntimeError):
client.get("/")


Expand Down
Loading