Skip to content

Commit

Permalink
Merge branch 'main' into Owen/control_revamp
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo authored Jan 5, 2025
2 parents 919abf9 + 4a308b8 commit 12bcf5a
Show file tree
Hide file tree
Showing 10 changed files with 135 additions and 65 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,4 +11,4 @@ repos:
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions]
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typeguard==2.13.3, typing_extensions, wadler_lindig]
17 changes: 14 additions & 3 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@

from ._heuristics import is_sde, is_unsafe_sde
from ._saveat import save_y, SaveAt, SubSaveAt
from ._solver import AbstractItoSolver, AbstractRungeKutta, AbstractStratonovichSolver
from ._solver import (
AbstractItoSolver,
AbstractRungeKutta,
AbstractSRK,
AbstractStratonovichSolver,
)
from ._term import AbstractTerm, AdjointTerm


Expand Down Expand Up @@ -396,7 +401,10 @@ def loop(
msg = None
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
if (
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
and solver.scan_kind is None
):
solver = eqx.tree_at(
lambda s: s.scan_kind, solver, "bounded", is_leaf=_is_none
)
Expand Down Expand Up @@ -923,7 +931,10 @@ def loop(
outer_while_loop = eqx.Partial(_outer_loop, kind="lax")
# Support forward-mode autodiff.
# TODO: remove this hack once we can JVP through custom_vjps.
if isinstance(solver, AbstractRungeKutta) and solver.scan_kind is None:
if (
isinstance(solver, (AbstractRungeKutta, AbstractSRK))
and solver.scan_kind is None
):
solver = eqx.tree_at(lambda s: s.scan_kind, solver, "lax", is_leaf=_is_none)
final_state = self._loop(
solver=solver,
Expand Down
8 changes: 8 additions & 0 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,14 @@ class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea):
K: BM


AbstractBrownianIncrement.__module__ = "diffrax"
AbstractSpaceTimeLevyArea.__module__ = "diffrax"
AbstractSpaceTimeTimeLevyArea.__module__ = "diffrax"
BrownianIncrement.__module__ = "diffrax"
SpaceTimeLevyArea.__module__ = "diffrax"
SpaceTimeTimeLevyArea.__module__ = "diffrax"


def levy_tree_transpose(
tree_shape, tree: PyTree[AbstractBrownianIncrement]
) -> AbstractBrownianIncrement:
Expand Down
15 changes: 13 additions & 2 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import lineax.internal as lxi
import numpy as np
import optimistix as optx
import wadler_lindig as wl
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
Expand Down Expand Up @@ -192,7 +193,11 @@ def _check(term_cls, term, term_contr_kwargs, yi):
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError(f"Control term {term} is incompatible.")
raise ValueError(
"Control term is incompatible: the returned control (e.g. "
f"Brownian motion for an SDE) was {control_type}, but this "
f"solver expected {control_type_expected}."
)
path_type_compatible = eqx.filter_eval_shape(
better_isinstance, path_type, path_type_expected
)
Expand All @@ -207,7 +212,13 @@ def _check(term_cls, term, term_contr_kwargs, yi):
jtu.tree_map(_check, term_structure, terms, contr_kwargs, y)
except Exception as e:
# ValueError may also arise from mismatched tree structures
raise ValueError("Terms are not compatible with solver!") from e
pretty_term = wl.pformat(terms)
pretty_expected = wl.pformat(term_structure)
raise ValueError(
f"Terms are not compatible with solver! Got:\n{pretty_term}\nbut expected:"
f"\n{pretty_expected}\nNote that terms are checked recursively: if you "
"scroll up you may find a root-cause error that is more specific."
) from e


def _is_subsaveat(x: Any) -> bool:
Expand Down
6 changes: 4 additions & 2 deletions diffrax/_solver/srk.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import abc
from dataclasses import dataclass
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar, Union
from typing import Any, Generic, Literal, Optional, TYPE_CHECKING, TypeVar, Union
from typing_extensions import TypeAlias

import equinox as eqx
Expand Down Expand Up @@ -256,6 +256,8 @@ class AbstractSRK(AbstractSolver[_SolverState]):
as well as $b^H$, $a^H$, $b^K$, and $a^K$ if needed.
"""

scan_kind: Union[None, Literal["lax", "checkpointed"]] = None

interpolation_cls = LocalLinearInterpolation
term_compatible_contr_kwargs = (dict(), dict(use_levy=True))
tableau: AbstractClassVar[StochasticButcherTableau]
Expand Down Expand Up @@ -588,7 +590,7 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in):
scan_inputs,
len(b_sol),
buffers=lambda x: x,
kind="checkpointed",
kind="checkpointed" if self.scan_kind is None else self.scan_kind,
checkpoints="all",
)

Expand Down
9 changes: 9 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -1153,3 +1153,12 @@ def prod(
self, vf: UnderdampedLangevinTuple, control: RealScalarLike
) -> UnderdampedLangevinTuple:
return jtu.tree_map(lambda _vf: control * _vf, vf)


AbstractTerm.__module__ = "diffrax"
ODETerm.__module__ = "diffrax"
ControlTerm.__module__ = "diffrax"
WeaklyDiagonalControlTerm.__module__ = "diffrax"
MultiTerm.__module__ = "diffrax"
UnderdampedLangevinDriftTerm.__module__ = "diffrax"
UnderdampedLangevinDiffusionTerm.__module__ = "diffrax"
6 changes: 4 additions & 2 deletions examples/continuous_normalising_flow.ipynb

Large diffs are not rendered by default.

102 changes: 52 additions & 50 deletions examples/kalman_filter.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ classifiers = [
"Topic :: Scientific/Engineering :: Mathematics",
]
urls = {repository = "https://github.com/patrick-kidger/diffrax" }
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.9"]
dependencies = ["jax>=0.4.38", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.10", "lineax>=0.0.5", "optimistix>=0.0.10", "wadler_lindig>=0.1.1"]

[build-system]
requires = ["hatchling"]
Expand Down
33 changes: 29 additions & 4 deletions test/test_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,10 +506,7 @@ def run(model):
run(mlp)


@pytest.mark.parametrize(
"diffusion_fn",
["weak", "lineax"],
)
@pytest.mark.parametrize("diffusion_fn", ["weak", "lineax"])
def test_sde_against(diffusion_fn, getkey):
def f(t, y, args):
del t
Expand Down Expand Up @@ -567,3 +564,31 @@ def test_implicit_runge_kutta_direct_adjoint():
adjoint=diffrax.DirectAdjoint(),
stepsize_controller=diffrax.PIDController(rtol=1e-3, atol=1e-6),
)


@pytest.mark.parametrize("solver", (diffrax.Tsit5(), diffrax.GeneralShARK()))
def test_forward_mode_runge_kutta(solver, getkey):
# Totally fine that we're using Tsit5 with an SDE, it should converge to the
# Stratonovich solution.
bm = diffrax.UnsafeBrownianPath((), getkey(), levy_area=diffrax.SpaceTimeLevyArea)
drift = diffrax.ODETerm(lambda t, y, args: -y)
diffusion = diffrax.ControlTerm(lambda t, y, args: 0.1 * y, bm)
terms = diffrax.MultiTerm(drift, diffusion)

def run(y0):
sol = diffrax.diffeqsolve(
terms,
solver,
0,
1,
0.01,
y0,
adjoint=diffrax.ForwardMode(),
)
return sol.ys

@jax.jit
def run_jvp(y0):
return jax.jvp(run, (y0,), (jnp.ones_like(y0),))

run_jvp(jnp.array(1.0))

0 comments on commit 12bcf5a

Please sign in to comment.