Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exponential Euler. #567

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
ExponentialEuler as ExponentialEuler,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
Expand Down Expand Up @@ -117,6 +118,7 @@
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,

)
from ._step_size_controller import (
AbstractAdaptiveStepSizeController as AbstractAdaptiveStepSizeController,
Expand All @@ -128,11 +130,13 @@
from ._term import (
AbstractTerm as AbstractTerm,
ControlTerm as ControlTerm,
LinearODETerm as LinearODETerm,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
UnderdampedLangevinDiffusionTerm as UnderdampedLangevinDiffusionTerm,
UnderdampedLangevinDriftTerm as UnderdampedLangevinDriftTerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,

)


Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .dopri8 import Dopri8 as Dopri8
from .euler import Euler as Euler
from .euler_heun import EulerHeun as EulerHeun
from .exp_euler import ExponentialEuler as ExponentialEuler
from .foster_langevin_srk import AbstractFosterLangevinSRK as AbstractFosterLangevinSRK
from .heun import Heun as Heun
from .implicit_euler import ImplicitEuler as ImplicitEuler
Expand Down
89 changes: 89 additions & 0 deletions diffrax/_solver/exp_euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# Stuff for custom solver
from collections.abc import Callable
from typing import ClassVar

import diffrax
from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike, Y
from diffrax._local_interpolation import LocalLinearInterpolation
from diffrax._solution import RESULTS
from diffrax._term import AbstractTerm, MultiTerm, ODETerm, LinearODETerm
from typing_extensions import TypeAlias

_ErrorEstimate: TypeAlias = None
_SolverState: TypeAlias = None
import jax
import jax.numpy as jnp
import lineax as lx
from jax.scipy.linalg import expm


def make_operators(operator, control):
match operator:
case lx.DiagonalLinearOperator():
# If the operator is diagonal things are fairly trivial
# Technically we could do the part below through lineax too
# but it just complicates things here as we return two different
# operators (linear vs full)
exp_Ah = jnp.exp(control * operator.diagonal)
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is numerically stable if the denominator is small? (Maybe switch to a Taylor expansion below some cutoff?)

return lx.DiagonalLinearOperator(exp_Ah), lx.DiagonalLinearOperator(phi)

case _:
# If the linear operator is not diagonal we need to calculate the matrix exponential
# and solve the linear problem
exp_Ah = expm(control * operator.as_matrix())
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1])
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might be able to skip the vmap by doing the dot against Gh_sdW before the linear solve.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(+same comment about numerical stability)

return lx.MatrixLinearOperator(exp_Ah), lx.MatrixLinearOperator(phi)


class ExponentialEuler(diffrax.AbstractItoSolver):
term_structure: ClassVar = AbstractTerm
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might just be part of the cleanup you mentioned, but I think you want MultiTerm[tuple[LinearODETerm, AbstractTerm]] to make this well-defined with the unpacking of the term in .step. (Likewise for the annotations for the terms: arguments.)

interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = (
LocalLinearInterpolation
)

def order(self, terms):
return 1.0

def strong_order(self, terms):
return 0.5

def init(
self,
terms: AbstractTerm,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
) -> _SolverState:
return None

def step(
self,
terms: AbstractTerm,
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
args: Args,
solver_state: _SolverState,
made_jump: BoolScalarLike,
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]:
del solver_state, made_jump

# We split the terms into linear and the non-linear plus possible noise
linear, non_linear = terms.terms[0].term, MultiTerm(*terms.terms[1:])
exp_Ah, phi = make_operators(linear.operator, linear.contr(t0, t1))
Gh_sdW = non_linear.vf_prod(t0, y0, args, non_linear.contr(t0, t1))
y1 = exp_Ah.mv(y0) + phi.mv(Gh_sdW)
dense_info = dict(y0=y0, y1=y1)
return y1, None, dense_info, None, RESULTS.successful

def func(
self,
terms: AbstractTerm,
t0: RealScalarLike,
y0: Y,
args: Args,
) -> VF:
return terms.vf(t0, y0, args)
10 changes: 10 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,6 +788,16 @@ def _to_vjp(_y, _diff_args, _diff_term):
return dy, da_y, da_diff_args, da_diff_term


class LinearODETerm(ODETerm):
operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator

def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator):
self.operator = operator
vf = lambda t, y, args: self.operator.mv(y)
super().__init__(vector_field=vf)
Comment on lines +791 to +797
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, I try to avoid subclassing concrete classes -- c.f. https://docs.kidger.site/equinox/pattern/

Also, does this strictly need to be an ODE term? I think in other problems, we might be interested in expressing linear terms in this way (for any operator: lx.AbstractLinearOperator) regardless of the control.




# The Underdamped Langevin SDE trajectory consists of two components: the position
# `x` and the velocity `v`. Both of these have the same shape.
# So, by UnderdampedLangevinX we denote the shape of the x component, and by
Expand Down
232 changes: 232 additions & 0 deletions test/test_exp_euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
import jax.numpy as jnp
import jax.random as jr
import lineax as lx
from diffrax import (
ControlTerm,
Dopri5,
EulerHeun,
ExponentialEuler,
HalfSolver,
MultiTerm,
ODETerm,
PIDController,
SaveAt,
VirtualBrownianTree,
diffeqsolve,
LinearODETerm

)



def test_linear():
"""Linear equation should be exact, even at large timesteps."""
linear_term = lx.DiagonalLinearOperator(-jnp.ones((1,))) # noqa: E731
non_linear_term = lambda t, y, args: jnp.zeros_like(y) # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))

saveat = SaveAt(ts=jnp.linspace(0, 3, 2))
sol = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1.0,
y0=jnp.ones((1,)),
saveat=saveat,
)
assert jnp.allclose(sol.ys, jnp.array([[jnp.exp(0.0)], [jnp.exp(-3)]]))


def test_non_linear():
"""Non linear, comparison to Dopri5"""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))

saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

# Baseline solver
solver = Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol_baseline = diffeqsolve(
term,
solver,
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)
assert jnp.allclose(sol_exp.ys, sol_baseline.ys, rtol=2e-3, atol=2e-3)


def test_error():
"""Testing if we can use the HalfSolver to get adaptive steps."""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-4,
y0=y0,
saveat=saveat,
max_steps=100000,
)

# Exponential solver
sol_adapt = diffeqsolve(
term,
HalfSolver(ExponentialEuler()),
t0=0,
t1=3,
dt0=1.0, # larger stepsize, should be adjusted
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)
assert jnp.allclose(sol.ys, sol_adapt.ys, rtol=2e-3, atol=3e-3)


def test_sde():
"Test using an SDE."
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.DiagonalLinearOperator(A) # noqa: E731
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
diffusion_term = lambda t, y, args: lx.DiagonalLinearOperator(jnp.full((10,), 0.1)) # noqa: E731
brownian_motion = VirtualBrownianTree(
0, 3, tol=1e-3, shape=(10,), key=jr.PRNGKey(0)
)
term = MultiTerm(
LinearODETerm(linear_term),
ODETerm(non_linear_term),
ControlTerm(diffusion_term, brownian_motion),
)
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

# Shark solver
linear_term_exp = lambda t, y, args: A * y + 2 * jnp.cos(y) ** 3 # noqa: E731
term_exp = MultiTerm(
ODETerm(linear_term_exp),
ControlTerm(diffusion_term, brownian_motion),
)
sol_euler = diffeqsolve(
term_exp,
EulerHeun(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
assert jnp.allclose(sol_exp.ys, sol_euler.ys, rtol=1e-3, atol=1e-3)


def test_diagonal_matrix_exponential():
"""Compare result of diagonal specialisation to full matrix exponential."""
A = -jnp.abs(jr.normal(jr.key(0), (10,)))
y0 = jr.normal(jr.key(42), (10,))

linear_term_diag = lx.DiagonalLinearOperator(A)
linear_term_full = lx.MatrixLinearOperator(jnp.diag(A))
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term_diag = MultiTerm(LinearODETerm(linear_term_diag), ODETerm(non_linear_term))
term_full = MultiTerm(LinearODETerm(linear_term_full), ODETerm(non_linear_term))
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Diagonal approach
sol_diag = diffeqsolve(
term_diag,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
# Full matrix exponential
sol_full = diffeqsolve(
term_full,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)
assert jnp.allclose(sol_diag.ys, sol_full.ys, rtol=1e-4, atol=1e-4)


def test_matrix_exponential():
"""Test if normal solver accept our term structure, and make sure results are the same."""
A = -jnp.abs(jr.normal(jr.key(0), (10, 10)))
y0 = jr.normal(jr.key(42), (10,))

linear_term = lx.MatrixLinearOperator(A)
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term))
saveat = SaveAt(ts=jnp.linspace(0, 3, 100))

# Exponential solver
sol_exp = diffeqsolve(
term,
ExponentialEuler(),
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
)

solver = Dopri5()
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5)
sol_dopri = diffeqsolve(
term,
solver,
t0=0,
t1=3,
dt0=1e-3,
y0=y0,
saveat=saveat,
stepsize_controller=stepsize_controller,
)

jnp.allclose(sol_exp.ys, sol_dopri.ys, rtol=2e-3, atol=2e-3)
Loading