diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 42073a10..c911b1f7 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -89,6 +89,7 @@ Dopri8 as Dopri8, Euler as Euler, EulerHeun as EulerHeun, + ExponentialEuler as ExponentialEuler, GeneralShARK as GeneralShARK, HalfSolver as HalfSolver, Heun as Heun, @@ -117,6 +118,7 @@ StochasticButcherTableau as StochasticButcherTableau, StratonovichMilstein as StratonovichMilstein, Tsit5 as Tsit5, + ) from ._step_size_controller import ( AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController, @@ -128,11 +130,13 @@ from ._term import ( AbstractTerm as AbstractTerm, ControlTerm as ControlTerm, + LinearODETerm as LinearODETerm, MultiTerm as MultiTerm, ODETerm as ODETerm, UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm, UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm, WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm, + ) diff --git a/diffrax/_solver/__init__.py b/diffrax/_solver/__init__.py index 0a840413..bf2ff90a 100644 --- a/diffrax/_solver/__init__.py +++ b/diffrax/_solver/__init__.py @@ -13,6 +13,7 @@ from .dopri8 import Dopri8 as Dopri8 from .euler import Euler as Euler from .euler_heun import EulerHeun as EulerHeun +from .exp_euler import ExponentialEuler as ExponentialEuler from .foster_langevin_srk import AbstractFosterLangevinSRK as AbstractFosterLangevinSRK from .heun import Heun as Heun from .implicit_euler import ImplicitEuler as ImplicitEuler diff --git a/diffrax/_solver/exp_euler.py b/diffrax/_solver/exp_euler.py new file mode 100644 index 00000000..155a368f --- /dev/null +++ b/diffrax/_solver/exp_euler.py @@ -0,0 +1,89 @@ +# Stuff for custom solver +from collections.abc import Callable +from typing import ClassVar + +import diffrax +from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike, Y +from diffrax._local_interpolation import LocalLinearInterpolation +from diffrax._solution import RESULTS +from 
from diffrax._term import AbstractTerm, MultiTerm, ODETerm, LinearODETerm
from typing_extensions import TypeAlias

import jax
import jax.numpy as jnp
import lineax as lx
from jax.scipy.linalg import expm

# `diffrax/__init__.py` imports this module while the top-level package is still
# being initialised, so the base class is imported from its defining module rather
# than looked up as an attribute of the (possibly incomplete) `diffrax` package.
from diffrax._solver.base import AbstractItoSolver

_ErrorEstimate: TypeAlias = None
_SolverState: TypeAlias = None


def _phi1(z):
    """Evaluate `phi_1(z) = (exp(z) - 1) / z` elementwise, with `phi_1(0) = 1`."""
    at_zero = z == 0
    # Swap in a harmless denominator *before* dividing, so that neither the value
    # nor its gradient becomes NaN where `z == 0`.
    safe_z = jnp.where(at_zero, 1.0, z)
    return jnp.where(at_zero, 1.0, jnp.expm1(safe_z) / safe_z)


def make_operators(operator, control):
    """Build the two operators needed for one exponential Euler step.

    **Arguments:**

    - `operator`: the linear operator `A`. A `lineax.DiagonalLinearOperator` takes
        an elementwise fast path; anything else is materialised with `as_matrix()`.
    - `control`: the scalar step `h` that multiplies `A`.

    **Returns:**

    A pair `(exp_Ah, phi)` of lineax operators representing `exp(h A)` and
    `phi_1(h A) = (h A)^{-1} (exp(h A) - I)`, where `phi_1` is defined by its power
    series so that it is also well defined when `A` is singular.
    """
    if isinstance(operator, lx.DiagonalLinearOperator):
        # For a diagonal operator both matrix functions act elementwise.
        z = control * operator.diagonal
        return lx.DiagonalLinearOperator(jnp.exp(z)), lx.DiagonalLinearOperator(
            _phi1(z)
        )

    # Dense case. Use the block identity
    #
    #     expm([[hA, I], [0, 0]]) = [[exp(hA), phi_1(hA)], [0, I]]
    #
    # which yields both matrices from a single matrix exponential. Compared with
    # solving `(hA) X = exp(hA) - I`, this needs no inverse of `hA` (so singular
    # operators work), avoids the cancellation in `exp(hA) - I` for small steps,
    # and cannot return the solution with rows and columns swapped.
    scaled = control * operator.as_matrix()
    size = scaled.shape[-1]
    identity = jnp.eye(size, dtype=scaled.dtype)
    zeros = jnp.zeros((size, size), dtype=scaled.dtype)
    exp_augmented = expm(jnp.block([[scaled, identity], [zeros, zeros]]))
    exp_Ah = exp_augmented[:size, :size]
    phi = exp_augmented[:size, size:]
    return lx.MatrixLinearOperator(exp_Ah), lx.MatrixLinearOperator(phi)


class ExponentialEuler(AbstractItoSolver):
    """Exponential Euler method for semi-linear problems.

    The terms must be given as a `MultiTerm` whose first entry carries a linear
    operator `A` (a `LinearODETerm`); all remaining entries are treated together
    as the non-linear drift plus optional noise. One step computes

        y1 = exp(h A) y0 + phi_1(h A) (increment of the remaining terms)

    so the linear part is propagated by the matrix exponential rather than being
    discretised. Reports order 1 and strong order 0.5, provides no error estimate
    and keeps no solver state. Uses a local linear interpolation for dense output.
    """

    term_structure: ClassVar = AbstractTerm
    interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
        LocalLinearInterpolation
    )

    def order(self, terms):
        return 1.0

    def strong_order(self, terms):
        return 0.5

    def init(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> _SolverState:
        # This solver is stateless.
        return None

    def step(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        t1: RealScalarLike,
        y0: Y,
        args: Args,
        solver_state: _SolverState,
        made_jump: BoolScalarLike,
    ) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
        """Take one exponential Euler step from `t0` to `t1`.

        Returns `(y1, None, dense_info, None, RESULTS.successful)`: there is no
        error estimate and no solver state, and `dense_info` holds `y0` and `y1`.

        Raises `ValueError` if the first term does not expose an `operator`.
        """
        del solver_state, made_jump

        # Split the terms into the linear one and everything else (non-linear
        # drift plus possible noise).
        linear_term, *other_terms = terms.terms
        # The first term may arrive wrapped, with the user-supplied term available
        # as `.term`; fall back to the term itself when it is not wrapped.
        inner_term = getattr(linear_term, "term", linear_term)
        operator = getattr(inner_term, "operator", None)
        if operator is None:
            raise ValueError(
                "`ExponentialEuler` requires the first term of the `MultiTerm` to be "
                "a `LinearODETerm`."
            )

        # Take the control from the term exactly as it was passed in, i.e. the same
        # kind of object that `contr` is called on for the remaining terms below.
        # NOTE(review): presumably the wrapper also accounts for the direction of
        # integration, which the unwrapped term would ignore -- confirm.
        exp_Ah, phi = make_operators(operator, linear_term.contr(t0, t1))

        y1 = exp_Ah.mv(y0)
        if other_terms:
            remainder = MultiTerm(*other_terms)
            increment = remainder.vf_prod(t0, y0, args, remainder.contr(t0, t1))
            y1 = y1 + phi.mv(increment)
        # With no remaining terms the problem is purely linear and the matrix
        # exponential alone gives the update.

        dense_info = dict(y0=y0, y1=y1)
        return y1, None, dense_info, None, RESULTS.successful

    def func(
        self,
        terms: AbstractTerm,
        t0: RealScalarLike,
        y0: Y,
        args: Args,
    ) -> VF:
        return terms.vf(t0, y0, args)


# ---------------------------------------------------------------------------
# diffrax/_term.py
# ---------------------------------------------------------------------------


class LinearODETerm(ODETerm):
    """An `ODETerm` whose vector field is `y -> operator.mv(y)`.

    The operator is additionally stored on the term as `operator`, so that a
    solver can read it directly instead of only being able to evaluate the
    vector field.

    **Arguments:**

    - `operator`: a `lineax.MatrixLinearOperator` or
        `lineax.DiagonalLinearOperator`.
    """

    operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator

    def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator):
        self.operator = operator

        # Close over `operator` rather than `self`, so the term does not hold a
        # reference to itself through its own vector field.
        # NOTE(review): the vector field keeps its own reference to the operator
        # given here; if the `operator` field is later replaced on a copy of this
        # term, the vector field presumably still uses the original -- confirm
        # whether a single source of truth is needed.
        def vector_field(t, y, args):
            return operator.mv(y)

        super().__init__(vector_field=vector_field)


# The Underdamped Langevin SDE trajectory consists of two components: the position
# `x` and the velocity `v`. Both of these have the same shape.
# So, by UnderdampedLangevinX we denote the shape of the x component, and by
# (the line above is trailing context from diffrax/_term.py; the contents of
# test/test_exp_euler.py start below)
"""Tests for the `ExponentialEuler` solver used together with `LinearODETerm`."""

import jax.numpy as jnp
import jax.random as jr
import lineax as lx
from diffrax import (
    ControlTerm,
    Dopri5,
    EulerHeun,
    ExponentialEuler,
    HalfSolver,
    MultiTerm,
    ODETerm,
    PIDController,
    SaveAt,
    VirtualBrownianTree,
    diffeqsolve,
    LinearODETerm,
)


def _zero_drift(t, y, args):
    """Non-linear part that contributes nothing, leaving a purely linear problem."""
    return jnp.zeros_like(y)


def _cos_cubed_drift(t, y, args):
    """Non-linear drift shared by the tests below."""
    return 2 * jnp.cos(y) ** 3


def _random_diagonal_problem():
    """Return a strictly negative random diagonal `A` and a random initial state."""
    A = -jnp.abs(jr.normal(jr.key(0), (10,)))
    y0 = jr.normal(jr.key(42), (10,))
    return A, y0


def test_linear():
    """Linear equation should be exact, even at large timesteps."""
    linear_operator = lx.DiagonalLinearOperator(-jnp.ones((1,)))
    term = MultiTerm(LinearODETerm(linear_operator), ODETerm(_zero_drift))

    saveat = SaveAt(ts=jnp.linspace(0, 3, 2))
    sol = diffeqsolve(
        term,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1.0,
        y0=jnp.ones((1,)),
        saveat=saveat,
    )
    assert jnp.allclose(sol.ys, jnp.array([[jnp.exp(0.0)], [jnp.exp(-3.0)]]))


def test_non_linear():
    """Non linear, comparison to Dopri5"""
    A, y0 = _random_diagonal_problem()
    term = MultiTerm(
        LinearODETerm(lx.DiagonalLinearOperator(A)), ODETerm(_cos_cubed_drift)
    )
    saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

    # Exponential solver
    sol_exp = diffeqsolve(
        term,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )

    # Baseline solver
    sol_baseline = diffeqsolve(
        term,
        Dopri5(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
        stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
    )
    assert jnp.allclose(sol_exp.ys, sol_baseline.ys, rtol=2e-3, atol=2e-3)


def test_error():
    """Testing if we can use the HalfSolver to get adaptive steps."""
    A, y0 = _random_diagonal_problem()
    term = MultiTerm(
        LinearODETerm(lx.DiagonalLinearOperator(A)), ODETerm(_cos_cubed_drift)
    )
    saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

    # Fixed-step reference with a small step size
    sol = diffeqsolve(
        term,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-4,
        y0=y0,
        saveat=saveat,
        max_steps=100000,
    )

    # Adaptive run: HalfSolver supplies the error estimate for the controller
    sol_adapt = diffeqsolve(
        term,
        HalfSolver(ExponentialEuler()),
        t0=0,
        t1=3,
        dt0=1.0,  # larger stepsize, should be adjusted
        y0=y0,
        saveat=saveat,
        stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
    )
    assert jnp.allclose(sol.ys, sol_adapt.ys, rtol=2e-3, atol=3e-3)


def test_sde():
    "Test using an SDE."
    A, y0 = _random_diagonal_problem()

    def diffusion(t, y, args):
        return lx.DiagonalLinearOperator(jnp.full((10,), 0.1))

    brownian_motion = VirtualBrownianTree(
        0, 3, tol=1e-3, shape=(10,), key=jr.PRNGKey(0)
    )
    saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

    # Exponential solver on the split (linear + non-linear + noise) problem
    term = MultiTerm(
        LinearODETerm(lx.DiagonalLinearOperator(A)),
        ODETerm(_cos_cubed_drift),
        ControlTerm(diffusion, brownian_motion),
    )
    sol_exp = diffeqsolve(
        term,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )

    # Reference: EulerHeun on the same equation with the drift left un-split
    def full_drift(t, y, args):
        return A * y + 2 * jnp.cos(y) ** 3

    term_reference = MultiTerm(
        ODETerm(full_drift),
        ControlTerm(diffusion, brownian_motion),
    )
    sol_euler = diffeqsolve(
        term_reference,
        EulerHeun(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )
    assert jnp.allclose(sol_exp.ys, sol_euler.ys, rtol=1e-3, atol=1e-3)


def test_diagonal_matrix_exponential():
    """Compare result of diagonal specialisation to full matrix exponential."""
    A, y0 = _random_diagonal_problem()

    term_diag = MultiTerm(
        LinearODETerm(lx.DiagonalLinearOperator(A)), ODETerm(_cos_cubed_drift)
    )
    term_full = MultiTerm(
        LinearODETerm(lx.MatrixLinearOperator(jnp.diag(A))), ODETerm(_cos_cubed_drift)
    )
    saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

    # Diagonal approach
    sol_diag = diffeqsolve(
        term_diag,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )
    # Full matrix exponential
    sol_full = diffeqsolve(
        term_full,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )
    assert jnp.allclose(sol_diag.ys, sol_full.ys, rtol=1e-4, atol=1e-4)


def test_matrix_exponential():
    """Dense, non-symmetric operator: a standard solver accepts the same term
    structure, and both solvers agree on the solution."""
    # A matrix with negative entries is not guaranteed to have eigenvalues with
    # negative real part, and on a growing solution the comparison below is not
    # meaningful. Build a non-symmetric, strictly diagonally dominant matrix
    # instead: by Gershgorin's theorem every eigenvalue has real part <= -1.
    coupling = 0.1 * jr.normal(jr.key(0), (10, 10))
    A = coupling - jnp.diag(jnp.sum(jnp.abs(coupling), axis=1) + 1.0)
    y0 = jr.normal(jr.key(42), (10,))

    term = MultiTerm(
        LinearODETerm(lx.MatrixLinearOperator(A)), ODETerm(_cos_cubed_drift)
    )
    saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

    # Exponential solver
    sol_exp = diffeqsolve(
        term,
        ExponentialEuler(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
    )

    # Baseline solver
    sol_dopri = diffeqsolve(
        term,
        Dopri5(),
        t0=0,
        t1=3,
        dt0=1e-3,
        y0=y0,
        saveat=saveat,
        stepsize_controller=PIDController(rtol=1e-5, atol=1e-5),
    )

    # The comparison must be asserted; a bare `jnp.allclose(...)` can never fail.
    assert jnp.allclose(sol_exp.ys, sol_dopri.ys, rtol=2e-3, atol=2e-3)