Skip to content

Commit

Permalink
Fixed spurious failure of test/test_adjoint.py::test_implicit
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 2, 2024
1 parent 0a59c9d commit c6cc85c
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,14 @@ def _solve(inputs):
)


# Unwrap jaxtyping decorator during tests, so that these are global functions.
# This is needed to ensure `optx.implicit_jvp` is happy.
if _vf.__globals__["__name__"].startswith("jaxtyping"):
_vf = _vf.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]
if _solve.__globals__["__name__"].startswith("jaxtyping"):
_solve = _solve.__wrapped__ # pyright: ignore[reportFunctionMemberAccess]


def _frozenset(x: Union[object, Iterable[object]]) -> frozenset[object]:
try:
iter_x = iter(x) # pyright: ignore
Expand Down

0 comments on commit c6cc85c

Please sign in to comment.