diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index f12e8509..13be59cf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,4 +11,4 @@ repos: rev: v1.1.350 hooks: - id: pyright - additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions] + additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig] diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 70ec5a1a..7e08aa1b 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -112,6 +112,14 @@ class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea): K: BM +AbstractBrownianIncrement.__module__ = "diffrax" +AbstractSpaceTimeLevyArea.__module__ = "diffrax" +AbstractSpaceTimeTimeLevyArea.__module__ = "diffrax" +BrownianIncrement.__module__ = "diffrax" +SpaceTimeLevyArea.__module__ = "diffrax" +SpaceTimeTimeLevyArea.__module__ = "diffrax" + + def levy_tree_transpose( tree_shape, tree: PyTree[AbstractBrownianIncrement] ) -> AbstractBrownianIncrement: diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6a31fe59..6e5d278c 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -22,6 +22,7 @@ import lineax.internal as lxi import numpy as np import optimistix as optx +import wadler_lindig as wl from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint @@ -184,7 +185,11 @@ def _check(term_cls, term, term_contr_kwargs, yi): better_isinstance, control_type, control_type_expected ) if not control_type_compatible: - raise ValueError(f"Control term {term} is incompatible.") + raise ValueError( + "Control term is incompatible: the returned control (e.g. " + f"Brownian motion for an SDE) was {control_type}, but this " + f"solver expected {control_type_expected}." + ) else: assert False, "Malformed term structure" # If we've got to this point then the term is compatible @@ -194,7 +199,13 @@ def _check(term_cls, term, term_contr_kwargs, yi): jtu.tree_map(_check, term_structure, terms, contr_kwargs, y) except Exception as e: # ValueError may also arise from mismatched tree structures - raise ValueError("Terms are not compatible with solver!") from e + pretty_term = wl.pformat(terms) + pretty_expected = wl.pformat(term_structure) + raise ValueError( + f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:" + f"\n{pretty_expected}\nNote that terms are checked recursively: if you " + "scroll up you may find a root-cause error that is more specific." + ) from e def _is_subsaveat(x: Any) -> bool: diff --git a/diffrax/_term.py b/diffrax/_term.py index bacaef9d..efa28d29 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -981,3 +981,12 @@ def prod( self, vf: UnderdampedLangevinTuple, control: RealScalarLike ) -> UnderdampedLangevinTuple: return jtu.tree_map(lambda _vf: control * _vf, vf) + + +AbstractTerm.__module__ = "diffrax" +ODETerm.__module__ = "diffrax" +ControlTerm.__module__ = "diffrax" +WeaklyDiagonalControlTerm.__module__ = "diffrax" +MultiTerm.__module__ = "diffrax" +UnderdampedLangevinDriftTerm.__module__ = "diffrax" +UnderdampedLangevinDiffusionTerm.__module__ = "diffrax" diff --git a/pyproject.toml b/pyproject.toml index 6b82ac17..01cacf52 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10"] +dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"] [build-system] requires = ["hatchling"]