Skip to content

Commit

Permalink
Improves error messages for mismatched terms.
Browse files Browse the repository at this point in the history
Serendipitously, this can also use the new `wadler_lindig` library for pretty-printing complicated types.
  • Loading branch information
patrick-kidger committed Jan 5, 2025
1 parent c9a214d commit 4a308b8
Show file tree
Hide file tree
Showing 5 changed files with 32 additions and 4 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions]
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]
8 changes: 8 additions & 0 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea):
K: BM


AbstractBrownianIncrement.__module__ = "diffrax"
AbstractSpaceTimeLevyArea.__module__ = "diffrax"
AbstractSpaceTimeTimeLevyArea.__module__ = "diffrax"
BrownianIncrement.__module__ = "diffrax"
SpaceTimeLevyArea.__module__ = "diffrax"
SpaceTimeTimeLevyArea.__module__ = "diffrax"


def levy_tree_transpose(
tree_shape, tree: PyTree[AbstractBrownianIncrement]
) -> AbstractBrownianIncrement:
Expand Down
15 changes: 13 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import lineax.internal as lxi
import numpy as np
import optimistix as optx
import wadler_lindig as wl
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
Expand Down Expand Up @@ -184,7 +185,11 @@ def _check(term_cls, term, term_contr_kwargs, yi):
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError(f"Control term {term} is incompatible.")
raise ValueError(
"Control term is incompatible: the returned control (e.g. "
f"Brownian motion for an SDE) was {control_type}, but this "
f"solver expected {control_type_expected}."
)
else:
assert False, "Malformed term structure"
# If we've got to this point then the term is compatible
Expand All @@ -194,7 +199,13 @@ def _check(term_cls, term, term_contr_kwargs, yi):
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
except Exception as e:
# ValueError may also arise from mismatched tree structures
raise ValueError("Terms are not compatible with solver!") from e
pretty_term = wl.pformat(terms)
pretty_expected = wl.pformat(term_structure)
raise ValueError(
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
"scroll up you may find a root-cause error that is more specific."
) from e


def _is_subsaveat(x: Any) -> bool:
Expand Down
9 changes: 9 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -981,3 +981,12 @@ def prod(
self, vf: UnderdampedLangevinTuple, control: RealScalarLike
) -> UnderdampedLangevinTuple:
return jtu.tree_map(lambda _vf: control * _vf, vf)


AbstractTerm.__module__ = "diffrax"
ODETerm.__module__ = "diffrax"
ControlTerm.__module__ = "diffrax"
WeaklyDiagonalControlTerm.__module__ = "diffrax"
MultiTerm.__module__ = "diffrax"
UnderdampedLangevinDriftTerm.__module__ = "diffrax"
UnderdampedLangevinDiffusionTerm.__module__ = "diffrax"
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"]

[build-system]
requires = ["hatchling"]
Expand Down

0 comments on commit 4a308b8

Please sign in to comment.