diff --git a/README.md b/README.md
index 98655ca7..c8dea370 100644
--- a/README.md
+++ b/README.md
@@ -61,16 +61,24 @@ If you found this library useful in academic research, please cite: [(arXiv link
 
 ## See also: other libraries in the JAX ecosystem
 
+[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+
 [Equinox](https://github.com/patrick-kidger/equinox): neural networks.
 
 [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
 
-[Lineax](https://github.com/google/lineax): linear solvers and linear least squares.
+[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
 
-[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+[Lineax](https://github.com/google/lineax): linear solvers.
 
-[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
+
+[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
 
 [sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
 
+[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+
 [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
+
+[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)
diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py
index 59a19b0c..e64e4970 100644
--- a/diffrax/adjoint.py
+++ b/diffrax/adjoint.py
@@ -530,6 +530,7 @@ def _loop_backsolve_bwd(
     throw,
     init_state,
 ):
+    assert discrete_terminating_event is None
 
     #
     # Unpack our various arguments. Delete a lot of things just to make sure we're not
@@ -565,7 +566,6 @@ def _loop_backsolve_bwd(
         adjoint=self,
         solver=solver,
         stepsize_controller=stepsize_controller,
-        discrete_terminating_event=discrete_terminating_event,
         terms=adjoint_terms,
         dt0=None if dt0 is None else -dt0,
         max_steps=max_steps,
@@ -744,6 +744,7 @@ def loop(
         init_state,
         passed_solver_state,
         passed_controller_state,
+        discrete_terminating_event,
         **kwargs,
     ):
         if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -785,6 +786,10 @@ def loop(
                 "`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
                 "a single term."
             )
+        if discrete_terminating_event is not None:
+            raise NotImplementedError(
+                "`diffrax.BacksolveAdjoint` is not compatible with events."
+            )
 
         y = init_state.y
         init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -798,6 +803,7 @@ def loop(
             saveat=saveat,
             init_state=init_state,
             solver=solver,
+            discrete_terminating_event=discrete_terminating_event,
             **kwargs,
         )
         final_state = _only_transpose_ys(final_state)
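The hunks below (and the `nonlinear_heat_pde.ipynb` hunk further down) migrate the deprecated `eqx.static_field()` to `eqx.field(static=True)`. As a minimal sketch of what a static field does (the class and field names here are illustrative, not from the diffrax source):

```python
import equinox as eqx
import jax
import jax.numpy as jnp

class Example(eqx.Module):
    # A static field is stored as part of the pytree *structure* rather than
    # as a leaf, so it is invisible to jax.grad / eqx.filter_grad and to vmap.
    label: str = eqx.field(static=True)
    weight: jax.Array

ex = Example(label="demo", weight=jnp.ones(3))
print(jax.tree_util.tree_leaves(ex))  # only `weight`; `label` is static metadata
```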
diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py
index c1d5f95f..a844450a 100644
--- a/diffrax/brownian/path.py
+++ b/diffrax/brownian/path.py
@@ -32,7 +32,7 @@ class UnsafeBrownianPath(AbstractBrownianPath):
     correlation structure isn't needed.)
     """
 
-    shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
     # Handled as a string because PRNGKey is actually a function, not a class, which
     # makes it appear badly in autogenerated documentation.
     key: "jax.random.PRNGKey"  # noqa: F821
diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py
index ad93f641..577092ce 100644
--- a/diffrax/brownian/tree.py
+++ b/diffrax/brownian/tree.py
@@ -60,7 +60,7 @@ class VirtualBrownianTree(AbstractBrownianPath):
     t0: Scalar = field(init=True)
     t1: Scalar = field(init=True)  # override init=False in AbstractPath
     tol: Scalar
-    shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
     key: "jax.random.PRNGKey"  # noqa: F821
 
     def __init__(
diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py
index 5900bb45..e24224d6 100644
--- a/diffrax/global_interpolation.py
+++ b/diffrax/global_interpolation.py
@@ -67,6 +67,8 @@ class LinearInterpolation(AbstractGlobalInterpolation):
     ys: PyTree[Array["times", ...]]  # noqa: F821
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(_ys):
             if _ys.shape[0] != self.ts.shape[0]:
                 raise ValueError(
@@ -179,6 +181,8 @@ class CubicInterpolation(AbstractGlobalInterpolation):
     ]
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(d, c, b, a):
             error_msg = (
                 "Each cubic coefficient must have `times - 1` entries, where "
@@ -287,12 +291,14 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree:
 
 class DenseInterpolation(AbstractGlobalInterpolation):
     ts_size: Int  # Takes values in {1, 2, 3, ...}
     infos: DenseInfos
-    interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field()
+    interpolation_cls: Type[AbstractLocalInterpolation] = eqx.field(static=True)
     direction: Scalar
     t0_if_trivial: Array
     y0_if_trivial: PyTree[Array]
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(_infos):
             assert _infos.shape[0] + 1 == self.ts.shape[0]
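The `super().__post_init__()` calls added above matter because Python dataclasses do not chain `__post_init__` automatically: a subclass that defines its own silently shadows the parent's. A toy illustration of the pitfall (classes hypothetical, not from diffrax):

```python
import dataclasses

@dataclasses.dataclass
class Base:
    def __post_init__(self):
        print("base validation runs")

@dataclasses.dataclass
class Child(Base):
    def __post_init__(self):
        # Without this explicit call, Base.__post_init__ would be skipped:
        # dataclasses only ever invoke the single most-derived __post_init__.
        super().__post_init__()
        print("child validation runs")

Child()  # prints both lines; drop the super() call and only the child's runs
```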
diff --git a/diffrax/integrate.py b/diffrax/integrate.py
index 814c90d9..d1a1b5a1 100644
--- a/diffrax/integrate.py
+++ b/diffrax/integrate.py
@@ -10,7 +10,7 @@ import jax.tree_util as jtu
 from jax.typing import ArrayLike
 
-from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
+from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
 from .custom_types import Array, Bool, Int, PyTree, Scalar
 from .event import AbstractDiscreteTerminatingEvent
 from .global_interpolation import DenseInterpolation
@@ -415,6 +415,11 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
         )
         new_state = eqx.tree_at(lambda s: s.result, new_state, result)
 
+        if not _filtering:
+            # This is only necessary for Equinox <0.11.1; from 0.11.1 onwards,
+            # the fix has been upstreamed into Equinox itself.
+            # TODO: remove once we make Equinox >=0.11.1 required.
+            new_state = jtu.tree_map(jnp.array, new_state)
         return new_state
 
     _filtering = True
@@ -633,22 +638,6 @@ def diffeqsolve(
             "An SDE should not be solved with adaptive step sizes with Euler's "
             "method, as it may not converge to the correct solution."
         )
-        # TODO: remove these lines.
-        #
-        # These are to work around an edge case: on the backward pass,
-        # RecursiveCheckpointAdjoint currently tries to differentiate the overall
-        # per-step function wrt all floating-point arrays. In particular this includes
-        # `state.tprev`, which feeds into the control, which feeds into
-        # VirtualBrownianTree, which can't be differentiated.
-        # We're waiting on JAX to offer a way of specifying which arguments to a
-        # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more
-        # precisely determine what to differentiate wrt.
-        #
-        # We don't replace this in the case of an unsafe SDE because
-        # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we
-        # should let the normal error be raised.
-        if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms):
-            adjoint = DirectAdjoint()
     if is_unsafe_sde(terms):
         if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
             raise ValueError(
diff --git a/diffrax/nonlinear_solver/newton.py b/diffrax/nonlinear_solver/newton.py
index 1458a8d0..c167bb1c 100644
--- a/diffrax/nonlinear_solver/newton.py
+++ b/diffrax/nonlinear_solver/newton.py
@@ -141,7 +141,7 @@ def body_fn(val):
             val = (flat, step + 1, diffsize, diffsize_prev)
             return val
 
-        val = (flat, 0, 0.0, 0.0)
+        val = (flat, 0, jnp.array(0.0), jnp.array(0.0))
         val = lax.while_loop(cond_fn, body_fn, val)
         flat, num_steps, diffsize, diffsize_prev = val
diff --git a/diffrax/solution.py b/diffrax/solution.py
index 8ae10f31..0d1eb485 100644
--- a/diffrax/solution.py
+++ b/diffrax/solution.py
@@ -13,12 +13,16 @@ class RESULTS(metaclass=eqxi.ContainerMeta):
     successful = ""
     discrete_terminating_event_occurred = (
-        "Terminating solve because a discrete event occurred."
+        "Terminating differential equation solve because a discrete terminating event "
+        "occurred."
     )
     max_steps_reached = (
-        "The maximum number of solver steps was reached. Try increasing `max_steps`."
+        "The maximum number of steps was reached in the differential equation solver. "
+        "Try increasing `diffrax.diffeqsolve(..., max_steps=...)`."
+    )
+    dt_min_reached = (
+        "The minimum step size was reached in the differential equation solver."
     )
-    dt_min_reached = "The minimum step size was reached."
     implicit_divergence = "Implicit method diverged."
     implicit_nonconvergence = (
         "Implicit method did not converge within the required number of iterations."
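The `newton.py` change above replaces Python-float placeholders in the `lax.while_loop` carry with concrete `jnp.array(0.0)` values. `lax.while_loop` requires the carry to have exactly the same pytree structure and types on every iteration, and Python scalars are weakly typed, so mixing them with strongly-typed arrays computed inside the body can trip that check. A standalone sketch of the pattern (not the diffrax code itself):

```python
import jax.lax as lax
import jax.numpy as jnp

# Starting the carry from concrete arrays pins down shape/dtype up front,
# so the body is guaranteed to return a carry of identical type.
init = (jnp.array(0), jnp.array(0.0))
cond = lambda carry: carry[0] < 5
body = lambda carry: (carry[0] + 1, carry[1] + jnp.float32(0.5))
print(lax.while_loop(cond, body, init))  # (Array(5), Array(2.5))
```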
diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py
index f794cfb6..71d2761d 100644
--- a/diffrax/step_size_controller/adaptive.py
+++ b/diffrax/step_size_controller/adaptive.py
@@ -423,8 +423,8 @@ def adapt_step_size(
         # h_n is the nth step size
         # ε_n = atol + norm(y) * rtol with y on the nth step
         # r_n = norm(y_error) with y_error on the nth step
-        # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol)) with y_error on the nth
-        # step and y on the mth step
+        # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
+        # step and y on the mth step
         # β_1 = pcoeff + icoeff + dcoeff
         # β_2 = -(pcoeff + 2 * dcoeff)
         # β_3 = dcoeff
diff --git a/diffrax/term.py b/diffrax/term.py
index ba8bd506..2136440b 100644
--- a/diffrax/term.py
+++ b/diffrax/term.py
@@ -409,7 +409,9 @@ def is_vf_expensive(
         y: Tuple[PyTree, PyTree, PyTree, PyTree],
         args: PyTree,
     ) -> bool:
-        return self.term.is_vf_expensive(t0, t1, y, args)
+        _t0 = jnp.where(self.direction == 1, t0, -t1)
+        _t1 = jnp.where(self.direction == 1, t1, -t0)
+        return self.term.is_vf_expensive(_t0, _t1, y, args)
 
 
 class AdjointTerm(AbstractTerm):
@@ -422,8 +424,8 @@ def is_vf_expensive(
         y: Tuple[PyTree, PyTree, PyTree, PyTree],
         args: PyTree,
     ) -> bool:
-        control = self.contr(t0, t1)
-        if sum(c.size for c in jtu.tree_leaves(control)) in (0, 1):
+        control_struct = jax.eval_shape(self.contr, t0, t1)
+        if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
             return False
         else:
             return True
diff --git a/docs/usage/getting-started.md b/docs/usage/getting-started.md
index c690c52e..f024b309 100644
--- a/docs/usage/getting-started.md
+++ b/docs/usage/getting-started.md
@@ -99,15 +99,6 @@ print(sol.evaluate(1.1))  # DeviceArray(0.89436394)
 
 As you can see, basically nothing has changed compared to the ODE example; all the same APIs are used. The only difference is that we created an SDE solver rather than an ODE solver.
 
-!!! info
-
-    If using some SDE-specific solvers, for example [`diffrax.ItoMilstein`][], then the solver makes a distinction between drift and diffusion. (In the previous example, the solver [`diffrax.Euler`][] is completely oblivious to this distinction.) In this case the drift and diffusion should be passed separately as a 2-tuple of terms, rather than wrapped into a single [`diffrax.MultiTerm`][]. This would involve changing the above example with:
-
-    ```python
-    terms = (ODETerm(drift), ControlTerm(diffusion, brownian_motion))
-    solver = ItoMilstein()
-    ```
-
 !!! info
 
     To do adaptive stepping with an SDE, then the typical approach is to wrap the solver like so -- and to use the following default step size controller:
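The `AdjointTerm.is_vf_expensive` change above queries the control's shape via `jax.eval_shape` instead of actually calling `self.contr(t0, t1)`. `jax.eval_shape` traces a function abstractly and returns `jax.ShapeDtypeStruct`s, so the control (which for a Brownian path may involve real work, e.g. sampling increments) is never materialised. A standalone sketch (the `expensive_control` function is hypothetical):

```python
import jax
import jax.numpy as jnp

def expensive_control(t0, t1):
    # Stands in for e.g. sampling a Brownian increment: real FLOPs and memory.
    return jnp.zeros((1000, 1000)) * (t1 - t0)

# Abstract evaluation: we learn the output's shape/dtype without allocating
# the (1000, 1000) array or running the computation.
struct = jax.eval_shape(expensive_control, 0.0, 1.0)
print(struct.shape, struct.dtype)  # (1000, 1000) float32
```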
diff --git a/docs/usage/manual-stepping.md b/docs/usage/manual-stepping.md
index 93b8f12d..68ff5878 100644
--- a/docs/usage/manual-stepping.md
+++ b/docs/usage/manual-stepping.md
@@ -11,6 +11,7 @@ In the following example, we solve an ODE using [`diffrax.Tsit5`][], and print o
 See the [Abstract solvers](../api/solvers/abstract_solvers.md) page for a reference on the solver methods (`init`, `step`) used here.
 
 ```python
+import jax.numpy as jnp
 from diffrax import ODETerm, Tsit5
 
 vector_field = lambda t, y, args: -y
@@ -20,7 +21,7 @@ solver = Tsit5()
 t0 = 0
 dt0 = 0.05
 t1 = 1
-y0 = 1
+y0 = jnp.array(1.0)
 args = None
 
 tprev = t0
diff --git a/examples/nonlinear_heat_pde.ipynb b/examples/nonlinear_heat_pde.ipynb
index c7fc00cd..eaf2c9e8 100644
--- a/examples/nonlinear_heat_pde.ipynb
+++ b/examples/nonlinear_heat_pde.ipynb
@@ -85,8 +85,8 @@
    "source": [
     "# Represents the interval [x0, x_final] discretised into n equally-spaced points.\n",
     "class SpatialDiscretisation(eqx.Module):\n",
-    "    x0: float = eqx.static_field()\n",
-    "    x_final: float = eqx.static_field()\n",
+    "    x0: float = eqx.field(static=True)\n",
+    "    x_final: float = eqx.field(static=True)\n",
     "    vals: Float[Array, \"n\"]\n",
     "\n",
     "    @classmethod\n",
diff --git a/setup.py b/setup.py
index 908c1d58..618c8a79 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@
 python_requires = "~=3.9"
 
-install_requires = ["jax>=0.4.13", "equinox>=0.10.11"]
+install_requires = ["jax>=0.4.13", "equinox>=0.11.1"]
 
 setuptools.setup(
     name=name,
diff --git a/test/test_event.py b/test/test_event.py
index 3129be6c..5ba8cc87 100644
--- a/test/test_event.py
+++ b/test/test_event.py
@@ -1,5 +1,7 @@
 import diffrax
+import jax
 import jax.numpy as jnp
+import pytest
 
 
 def test_discrete_terminate1():
@@ -9,7 +11,12 @@ def test_discrete_terminate1():
     t1 = jnp.inf
     dt0 = 1
     y0 = 1.0
-    event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.y > 10)
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.y > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
     sol = diffrax.diffeqsolve(
         term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
     )
@@ -23,11 +30,50 @@ def test_discrete_terminate2():
     t1 = jnp.inf
     dt0 = 1
     y0 = 1.0
-    event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.tprev > 10)
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.tprev > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
     sol = diffrax.diffeqsolve(
         term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
     )
     assert jnp.all(sol.ts > 10)
+
+
+def test_event_backsolve():
+    term = diffrax.ODETerm(lambda t, y, args: y)
+    solver = diffrax.Tsit5()
+    t0 = 0
+    t1 = jnp.inf
+    dt0 = 1
+    y0 = 1.0
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.tprev > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
+
+    @jax.jit
+    @jax.grad
+    def run(y0):
+        sol = diffrax.diffeqsolve(
+            term,
+            solver,
+            t0,
+            t1,
+            dt0,
+            y0,
+            discrete_terminating_event=event,
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    # Check this raises `NotImplementedError`, and in particular not some other error.
+    with pytest.raises(NotImplementedError):
+        run(y0)
 
 
 # diffrax.SteadyStateEvent tested as part of test_adjoint.py::test_implicit
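For reference, the `manual-stepping.md` hunks above slot into a loop along the following lines; this is a condensed sketch of that page's full example, assuming the `init`/`step` signatures documented on the Abstract solvers page:

```python
import jax.numpy as jnp
from diffrax import ODETerm, Tsit5

term = ODETerm(lambda t, y, args: -y)
solver = Tsit5()
t0, t1, dt0 = 0.0, 1.0, 0.05
args = None
y = jnp.array(1.0)  # array-valued y0, as the docs hunk above now uses

tprev, tnext = t0, t0 + dt0
state = solver.init(term, tprev, tnext, y, args)
while tprev < t1:
    # step returns (y1, local error estimate, dense-output info, state, result)
    y, _, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
    print(f"t={tnext:.2f}, y={y}")
    tprev = tnext
    tnext = min(tnext + dt0, t1)
```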
diff --git a/test/test_integrate.py b/test/test_integrate.py
index d2e9cd84..0df3bff6 100644
--- a/test/test_integrate.py
+++ b/test/test_integrate.py
@@ -418,3 +418,27 @@ def run(y0):
         assert sol.made_jump is False
 
     run(1)
+
+
+def test_no_jit():
+    # https://github.com/patrick-kidger/diffrax/issues/293
+    # https://github.com/patrick-kidger/diffrax/issues/321
+
+    # Test that this doesn't crash.
+    with jax.disable_jit():
+
+        def vector_field(t, y, args):
+            return jnp.zeros_like(y)
+
+        term = diffrax.ODETerm(vector_field)
+        y = jnp.zeros((1,))
+        stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
+        diffrax.diffeqsolve(
+            term,
+            diffrax.Kvaerno4(),
+            t0=0,
+            t1=1e-2,
+            dt0=1e-3,
+            stepsize_controller=stepsize_controller,
+            y0=y,
+        )
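As context for `test_no_jit` above: `jax.disable_jit()` makes `jit`-decorated functions run eagerly, which is the standard way to get concrete values (for `print`/`pdb`) out of otherwise-traced code, and it is the mode the two linked issues reported as broken. A minimal illustration, unrelated to diffrax itself:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # Under jit this prints an abstract tracer (once, at trace time);
    # under disable_jit it prints the concrete value on every call.
    print("x =", x)
    return x * 2

f(jnp.array(1.0))        # prints a Tracer
with jax.disable_jit():
    f(jnp.array(2.0))    # prints 2.0
```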