diff --git a/newsfragments/3175.doc.rst b/newsfragments/3175.doc.rst new file mode 100644 index 0000000000..7c8475f4e7 --- /dev/null +++ b/newsfragments/3175.doc.rst @@ -0,0 +1 @@ +Warn if a user forgot to implement ``.derive`` for an ExceptionGroup subclass. diff --git a/src/trio/_core/_run.py b/src/trio/_core/_run.py index 5dbaa18cab..0f243c4ff5 100644 --- a/src/trio/_core/_run.py +++ b/src/trio/_core/_run.py @@ -52,7 +52,7 @@ ) if sys.version_info < (3, 11): - from exceptiongroup import BaseExceptionGroup + from exceptiongroup import BaseExceptionGroup, ExceptionGroup if TYPE_CHECKING: @@ -639,6 +639,31 @@ def _close(self, exc: BaseException | None) -> BaseException | None: self.cancelled_caught = True exc = None elif isinstance(exc, BaseExceptionGroup): + # sanity check users + egs = [exc] + visited = set() + while egs: + next_eg = egs.pop() + if next_eg in visited: + continue + visited.add(next_eg) + if ( + "derive" not in type(next_eg).__dict__ + and type(next_eg) is not ExceptionGroup + ): + warnings.warn( + f"derive not implemented for {type(next_eg).__name__}, results may be unexpected", + stacklevel=1, + ) + + egs.extend( + [ + e + for e in next_eg.exceptions + if isinstance(e, BaseExceptionGroup) + ] + ) + matched, exc = exc.split(Cancelled) if matched: self.cancelled_caught = True diff --git a/src/trio/_core/_tests/test_run.py b/src/trio/_core/_tests/test_run.py index 75e5457d78..073ac0f281 100644 --- a/src/trio/_core/_tests/test_run.py +++ b/src/trio/_core/_tests/test_run.py @@ -44,8 +44,11 @@ Awaitable, Callable, Generator, + Sequence, ) + from typing_extensions import Self + if sys.version_info < (3, 11): from exceptiongroup import BaseExceptionGroup, ExceptionGroup @@ -2855,3 +2858,65 @@ def run(self, fn: Callable[[], object]) -> object: with mock.patch("trio._core._run.copy_context", return_value=Context()): assert _count_context_run_tb_frames() == 1 + + +def test_run_with_custom_exception_group() -> None: + class ExceptionGroupForTest(ExceptionGroup[Exception]): + @staticmethod + def for_test(message: str, excs: list[Exception]) -> ExceptionGroupForTest: + raise NotImplementedError() + + async def check1(exception_group_type: type[ExceptionGroupForTest]) -> None: + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + async def check2(exception_group_type: type[ExceptionGroupForTest]) -> None: + with _core.CancelScope(): + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + async def check3(exception_group_type: type[ExceptionGroupForTest]) -> None: + async with _core.open_nursery(): + raise exception_group_type.for_test("test message", [ValueError("uh oh")]) + + class HasDerive(ExceptionGroupForTest): + def derive(self, excs: Sequence[Exception]) -> HasDerive: + return HasDerive(self.message, excs) + + @staticmethod + def for_test(message: str, excs: list[Exception]) -> HasDerive: + return HasDerive(message, excs) + + class NormalNew(ExceptionGroupForTest): + @staticmethod + def for_test(message: str, excs: list[Exception]) -> NormalNew: + return NormalNew(message, excs) + + class AbnormalNew(ExceptionGroupForTest): + def __new__(cls, excs: Sequence[Exception]) -> Self: + return super().__new__(cls, f"has {len(excs)} exceptions", excs) + + @staticmethod + def for_test(message: str, excs: list[Exception]) -> AbnormalNew: + return AbnormalNew(excs) + + for check in (check1, check2, check3): + for error in [HasDerive, NormalNew, AbnormalNew]: + if check is check3: + if error in (NormalNew, AbnormalNew): + with ( + pytest.warns(UserWarning, match="^derive not implemented"), + pytest.raises(ExceptionGroup) as e, + ): + _core.run(check, error) + + error = ExceptionGroup # we don't provide something better + else: + with pytest.raises(ExceptionGroup) as e: + _core.run(check, error) + + assert len(e.value.exceptions) == 1 + assert isinstance(e.value.exceptions[0], error) + else: + with pytest.raises(error): + _core.run(check, error) + + print(f"{check} + {error} PASSED")