Skip to content

Commit

Permalink
Add exponential Euler.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gert-Jan Both committed Jan 13, 2025
1 parent 134a40a commit bf5d3ed
Show file tree
Hide file tree
Showing 5 changed files with 336 additions and 0 deletions.
4 changes: 4 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
ExponentialEuler as ExponentialEuler,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
Expand Down Expand Up @@ -117,6 +118,7 @@
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,

)
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
Expand All @@ -128,11 +130,13 @@
from ._term import (
AbstractTerm as AbstractTerm,
ControlTerm as ControlTerm,
LinearODETerm as LinearODETerm,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,

)


Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .dopri8 import Dopri8 as Dopri8
from .euler import Euler as Euler
from .euler_heun import EulerHeun as EulerHeun
from .exp_euler import ExponentialEuler as ExponentialEuler
from .foster_langevin_srk import AbstractFosterLangevinSRK as AbstractFosterLangevinSRK
from .heun import Heun as Heun
from .implicit_euler import ImplicitEuler as ImplicitEuler
Expand Down
89 changes: 89 additions & 0 deletions diffrax/_solver/exp_euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Stuff for custom solver
from collections.abc import Callable
from typing import ClassVar

import diffrax
from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike, Y
from diffrax._local_interpolation import LocalLinearInterpolation
from diffrax._solution import RESULTS
from diffrax._term import AbstractTerm, MultiTerm, ODETerm, LinearODETerm
from typing_extensions import TypeAlias

_ErrorEstimate: TypeAlias = None
_SolverState: TypeAlias = None
import jax
import jax.numpy as jnp
import lineax as lx
from jax.scipy.linalg import expm


def make_operators(operator, control):
match operator:
case lx.DiagonalLinearOperator():
# If the operator is diagonal things are fairly trivial
# Technically we could do the part below through lineax too
# but it just complicates things here as we return two different
# operators (linear vs full)
exp_Ah = jnp.exp(control * operator.diagonal)
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control)
return lx.DiagonalLinearOperator(exp_Ah), lx.DiagonalLinearOperator(phi)

case _:
# If the linear operator is not diagonal we need to calculate the matrix exponential
# and solve the linear problem
exp_Ah = expm(control * operator.as_matrix())
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1])
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value
return lx.MatrixLinearOperator(exp_Ah), lx.MatrixLinearOperator(phi)


class ExponentialEuler(diffrax.AbstractItoSolver):
term_structure: ClassVar = AbstractTerm
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
return 1.0

def strong_order(self, terms):
return 0.5

def init(
self,
terms: AbstractTerm,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
) -> _SolverState:
return None

def step(
self,
terms: AbstractTerm,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

# We split the terms into linear and the non-linear plus possible noise
linear, non_linear = terms.terms[0].term, MultiTerm(*terms.terms[1:])
exp_Ah, phi = make_operators(linear.operator, linear.contr(t0, t1))
Gh_sdW = non_linear.vf_prod(t0, y0, args, non_linear.contr(t0, t1))
y1 = exp_Ah.mv(y0) + phi.mv(Gh_sdW)
dense_info = dict(y0=y0, y1=y1)
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: AbstractTerm,
t0: RealScalarLike,
y0: Y,
args: Args,
) -> VF:
return terms.vf(t0, y0, args)
10 changes: 10 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,16 @@ def _to_vjp(_y, _diff_args, _diff_term):
return dy, da_y, da_diff_args, da_diff_term


class LinearODETerm(ODETerm):
operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator

def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator):
self.operator = operator
vf = lambda t, y, args: self.operator.mv(y)
super().__init__(vector_field=vf)



# The Underdamped Langevin SDE trajectory consists of two components: the position
# `x` and the velocity `v`. Both of these have the same shape.
# So, by UnderdampedLangevinX we denote the shape of the x component, and by
Expand Down
232 changes: 232 additions & 0 deletions test/test_exp_euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
from diffrax import (
ControlTerm,
Dopri5,
EulerHeun,
ExponentialEuler,
HalfSolver,
MultiTerm,
ODETerm,
PIDController,
SaveAt,
VirtualBrownianTree,
diffeqsolve,
LinearODETerm

)



def test_linear():
"""Linear equation should be exact, even at large timesteps."""
linear_term = lx.DiagonalLinearOperator(-jnp.ones((1,))) # noqa: E731
non_linear_term = lambda t, y, args: jnp.zeros_like(y) # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))

saveat = SaveAt(ts=jnp.linspace(0, 3, 2))
sol = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1.0,
y0=jnp.ones((1,)),
saveat=saveat,
)
assert jnp.allclose(sol.ys, jnp.array([[jnp.exp(0.0)], [jnp.exp(-3)]]))


def test_non_linear():
"""Non linear, comparison to Dopri5"""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))

saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

# Baseline solver
solver = Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol_baseline = diffeqsolve(
term,
solver,
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)
assert jnp.allclose(sol_exp.ys, sol_baseline.ys, rtol=2e-3, atol=2e-3)


def test_error():
"""Testing if we can use the HalfSolver to get adaptive steps."""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-4,
y0=y0,
saveat=saveat,
max_steps=100000,
)

# Exponential solver
sol_adapt = diffeqsolve(
term,
HalfSolver(ExponentialEuler()),
t0=0,
t1=3,
dt0=1.0, # larger stepsize, should be adjusted
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)
assert jnp.allclose(sol.ys, sol_adapt.ys, rtol=2e-3, atol=3e-3)


def test_sde():
"Test using an SDE."
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
diffusion_term = lambda t, y, args: lx.DiagonalLinearOperator(jnp.full((10,), 0.1)) # noqa: E731
brownian_motion = VirtualBrownianTree(
0, 3, tol=1e-3, shape=(10,), key=jr.PRNGKey(0)
)
term = MultiTerm(
LinearODETerm(linear_term),
ODETerm(non_linear_term),
ControlTerm(diffusion_term, brownian_motion),
)
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

# Shark solver
linear_term_exp = lambda t, y, args: A * y + 2 * jnp.cos(y) ** 3 # noqa: E731
term_exp = MultiTerm(
ODETerm(linear_term_exp),
ControlTerm(diffusion_term, brownian_motion),
)
sol_euler = diffeqsolve(
term_exp,
EulerHeun(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
assert jnp.allclose(sol_exp.ys, sol_euler.ys, rtol=1e-3, atol=1e-3)


def test_diagonal_matrix_exponential():
"""Compare result of diagonal specialisation to full matrix exponential."""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term_diag = lx.DiagonalLinearOperator(A)
linear_term_full = lx.MatrixLinearOperator(jnp.diag(A))
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term_diag = MultiTerm(LinearODETerm(linear_term_diag), ODETerm(non_linear_term))
term_full = MultiTerm(LinearODETerm(linear_term_full), ODETerm(non_linear_term))
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Diagonal approach
sol_diag = diffeqsolve(
term_diag,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
# Full matrix exponential
sol_full = diffeqsolve(
term_full,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
assert jnp.allclose(sol_diag.ys, sol_full.ys, rtol=1e-4, atol=1e-4)


def test_matrix_exponential():
"""Test if normal solver accept our term structure, and make sure results are the same."""
A = -jnp.abs(jr.normal(jr.key(0), (10, 10)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.MatrixLinearOperator(A)
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

solver = Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol_dopri = diffeqsolve(
term,
solver,
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)

jnp.allclose(sol_exp.ys, sol_dopri.ys, rtol=2e-3, atol=2e-3)

0 comments on commit bf5d3ed

Please sign in to comment.