Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a sanity check for a missing derive for exception groups #3176

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions newsfragments/3175.doc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Warn if a user forgot to implement ``.derive`` for an ExceptionGroup subclass.
27 changes: 26 additions & 1 deletion src/trio/_core/_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup
from exceptiongroup import BaseExceptionGroup, ExceptionGroup


if TYPE_CHECKING:
Expand Down Expand Up @@ -639,6 +639,31 @@ def _close(self, exc: BaseException | None) -> BaseException | None:
self.cancelled_caught = True
exc = None
elif isinstance(exc, BaseExceptionGroup):
# sanity check users
egs = [exc]
visited = set()
while egs:
next_eg = egs.pop()
if next_eg in visited:
continue
visited.add(next_eg)
if (
"derive" not in type(next_eg).__dict__
and type(next_eg) is not ExceptionGroup
):
warnings.warn(
f"derive not implemented for {type(next_eg).__name__}, results may be unexpected",
stacklevel=1,
)

egs.extend(
[
e
for e in next_eg.exceptions
if isinstance(e, BaseExceptionGroup)
]
)

matched, exc = exc.split(Cancelled)
if matched:
self.cancelled_caught = True
Expand Down
65 changes: 65 additions & 0 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,11 @@
Awaitable,
Callable,
Generator,
Sequence,
)

from typing_extensions import Self

if sys.version_info < (3, 11):
from exceptiongroup import BaseExceptionGroup, ExceptionGroup

Expand Down Expand Up @@ -2855,3 +2858,65 @@

with mock.patch("trio._core._run.copy_context", return_value=Context()):
assert _count_context_run_tb_frames() == 1


def test_run_with_custom_exception_group() -> None:
class ExceptionGroupForTest(ExceptionGroup[Exception]):
@staticmethod
def for_test(message: str, excs: list[Exception]) -> ExceptionGroupForTest:
raise NotImplementedError()

async def check1(exception_group_type: type[ExceptionGroupForTest]) -> None:
raise exception_group_type.for_test("test message", [ValueError("uh oh")])

async def check2(exception_group_type: type[ExceptionGroupForTest]) -> None:
with _core.CancelScope():
raise exception_group_type.for_test("test message", [ValueError("uh oh")])

async def check3(exception_group_type: type[ExceptionGroupForTest]) -> None:
async with _core.open_nursery():
raise exception_group_type.for_test("test message", [ValueError("uh oh")])

class HasDerive(ExceptionGroupForTest):
def derive(self, excs: Sequence[Exception]) -> HasDerive:
return HasDerive(self.message, excs)

Check failure on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): Signature of "derive" incompatible with supertype "BaseExceptionGroup" [override]

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): Superclass:

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Linux+Mac+Mac+Windows+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): @overload

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): def [_ExceptionT: Exception] derive(self, Sequence[_ExceptionT], /) -> ExceptionGroup[_ExceptionT]

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): def [_BaseExceptionT: BaseException] derive(self, Sequence[_BaseExceptionT], /) -> BaseExceptionGroup[_BaseExceptionT]

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): Subclass:

Check notice on line 2882 in src/trio/_core/_tests/test_run.py

View workflow job for this annotation

GitHub Actions / Ubuntu (3.13, check formatting)

Mypy-Linux+Mac+Windows

src/trio/_core/_tests/test_run.py:(2881:9 - 2882:48): def derive(self, excs: Sequence[Exception]) -> HasDerive

@staticmethod
def for_test(message: str, excs: list[Exception]) -> HasDerive:
return HasDerive(message, excs)

class NormalNew(ExceptionGroupForTest):
@staticmethod
def for_test(message: str, excs: list[Exception]) -> NormalNew:
return NormalNew(message, excs)

class AbnormalNew(ExceptionGroupForTest):
def __new__(cls, excs: Sequence[Exception]) -> Self:
return super().__new__(cls, f"has {len(excs)} exceptions", excs)

@staticmethod
def for_test(message: str, excs: list[Exception]) -> AbnormalNew:
return AbnormalNew(excs)

for check in (check1, check2, check3):
for error in [HasDerive, NormalNew, AbnormalNew]:
if check is check3:
if error in (NormalNew, AbnormalNew):
with (
pytest.warns(UserWarning, match="^derive not implemented"),
pytest.raises(ExceptionGroup) as e,
):
_core.run(check, error)

error = ExceptionGroup # we don't provide something better
else:
with pytest.raises(ExceptionGroup) as e:
_core.run(check, error)

assert len(e.value.exceptions) == 1
assert isinstance(e.value.exceptions[0], error)
else:
with pytest.raises(error):
_core.run(check, error)

print(f"{check} + {error} PASSED")
Loading