From c9a214dfa931efbda856d0bfc13f4e81c74c83a9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 5 Jan 2025 16:13:11 +0100 Subject: [PATCH] SRKs now support forward-mode autodiff. --- diffrax/_adjoint.py | 19 +++++++++++++++---- diffrax/_solver/srk.py | 6 ++++-- test/test_adjoint.py | 33 +++++++++++++++++++++++++++++---- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/diffrax/_adjoint.py b/diffrax/_adjoint.py index 769e49bd..db701bd2 100644 --- a/diffrax/_adjoint.py +++ b/diffrax/_adjoint.py @@ -16,7 +16,12 @@ from ._heuristics import is_sde, is_unsafe_sde from ._saveat import save_y, SaveAt, SubSaveAt -from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver +from ._solver import ( + AbstractItoSolver, + AbstractRungeKutta, + AbstractSRK, + AbstractStratonovichSolver, +) from ._term import AbstractTerm, AdjointTerm @@ -272,7 +277,7 @@ def loop( if is_unsafe_sde(terms): raise ValueError( "`adjoint=RecursiveCheckpointAdjoint()` does not support " - "`UnsafeBrownianPath`. Consider using `adjoint=DirectAdjoint()` " + "`UnsafeBrownianPath`. Consider using `adjoint=ForwardMode()` " "instead." ) if self.checkpoints is None and max_steps is None: @@ -376,7 +381,10 @@ def loop( msg = None # Support forward-mode autodiff. # TODO: remove this hack once we can JVP through custom_vjps. - if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: + if ( + isinstance(solver, (AbstractRungeKutta, AbstractSRK)) + and solver.scan_kind is None + ): solver = eqx.tree_at( lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none ) @@ -888,7 +896,10 @@ def loop( outer_while_loop = eqx.Partial(_outer_loop, kind="lax") # Support forward-mode autodiff. # TODO: remove this hack once we can JVP through custom_vjps. - if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None: + if ( + isinstance(solver, (AbstractRungeKutta, AbstractSRK)) + and solver.scan_kind is None + ): solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none) final_state = self._loop( solver=solver, diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index fba6120d..56be17ba 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -1,6 +1,6 @@ import abc from dataclasses import dataclass -from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union +from typing import Any, Generic, Literal, Optional, TYPE_CHECKING, TypeVar, Union from typing_extensions import TypeAlias import equinox as eqx @@ -255,6 +255,8 @@ class AbstractSRK(AbstractSolver[_SolverState]): as well as $b^H$, $a^H$, $b^K$, and $a^K$ if needed. """ + scan_kind: Union[None, Literal["lax", "checkpointed"]] = None + interpolation_cls = LocalLinearInterpolation term_compatible_contr_kwargs = (dict(), dict(use_levy=True)) tableau: AbstractClassVar[StochasticButcherTableau] @@ -583,7 +585,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in): scan_inputs, len(b_sol), buffers=lambda x: x, - kind="checkpointed", + kind="checkpointed" if self.scan_kind is None else self.scan_kind, checkpoints="all", ) diff --git a/test/test_adjoint.py b/test/test_adjoint.py index 12e6ee27..c45c6286 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -366,10 +366,7 @@ def run(model): run(mlp) -@pytest.mark.parametrize( - "diffusion_fn", - ["weak", "lineax"], -) +@pytest.mark.parametrize("diffusion_fn", ["weak", "lineax"]) def test_sde_against(diffusion_fn, getkey): def f(t, y, args): del t @@ -427,3 +424,31 @@ def test_implicit_runge_kutta_direct_adjoint(): adjoint=diffrax.DirectAdjoint(), stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6), ) + + +@pytest.mark.parametrize("solver", (diffrax.Tsit5(), diffrax.GeneralShARK())) +def test_forward_mode_runge_kutta(solver, getkey): + # Totally fine that we're using Tsit5 with an SDE, it should converge to the + # Stratonovich solution. + bm = diffrax.UnsafeBrownianPath((), getkey(), levy_area=diffrax.SpaceTimeLevyArea) + drift = diffrax.ODETerm(lambda t, y, args: -y) + diffusion = diffrax.ControlTerm(lambda t, y, args: 0.1 * y, bm) + terms = diffrax.MultiTerm(drift, diffusion) + + def run(y0): + sol = diffrax.diffeqsolve( + terms, + solver, + 0, + 1, + 0.01, + y0, + adjoint=diffrax.ForwardMode(), + ) + return sol.ys + + @jax.jit + def run_jvp(y0): + return jax.jvp(run, (y0,), (jnp.ones_like(y0),)) + + run_jvp(jnp.array(1.0))