-
-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add exponential Euler. #567
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Stuff for custom solver | ||
from collections.abc import Callable | ||
from typing import ClassVar | ||
|
||
import diffrax | ||
from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike, Y | ||
from diffrax._local_interpolation import LocalLinearInterpolation | ||
from diffrax._solution import RESULTS | ||
from diffrax._term import AbstractTerm, MultiTerm, ODETerm, LinearODETerm | ||
from typing_extensions import TypeAlias | ||
|
||
_ErrorEstimate: TypeAlias = None | ||
_SolverState: TypeAlias = None | ||
import jax | ||
import jax.numpy as jnp | ||
import lineax as lx | ||
from jax.scipy.linalg import expm | ||
|
||
|
||
def make_operators(operator, control): | ||
match operator: | ||
case lx.DiagonalLinearOperator(): | ||
# If the operator is diagonal things are fairly trivial | ||
# Technically we could do the part below through lineax too | ||
# but it just complicates things here as we return two different | ||
# operators (linear vs full) | ||
exp_Ah = jnp.exp(control * operator.diagonal) | ||
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control) | ||
return lx.DiagonalLinearOperator(exp_Ah), lx.DiagonalLinearOperator(phi) | ||
|
||
case _: | ||
# If the linear operator is not diagonal we need to calculate the matrix exponential | ||
# and solve the linear problem | ||
exp_Ah = expm(control * operator.as_matrix()) | ||
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1]) | ||
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you might be able to skip the vmap by doing the dot against There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (+same comment about numerical stability) |
||
return lx.MatrixLinearOperator(exp_Ah), lx.MatrixLinearOperator(phi) | ||
|
||
|
||
class ExponentialEuler(diffrax.AbstractItoSolver): | ||
term_structure: ClassVar = AbstractTerm | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This might just be part of the cleanup you mentioned, but I think you want |
||
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( | ||
LocalLinearInterpolation | ||
) | ||
|
||
def order(self, terms): | ||
return 1.0 | ||
|
||
def strong_order(self, terms): | ||
return 0.5 | ||
|
||
def init( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
t1: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
) -> _SolverState: | ||
return None | ||
|
||
def step( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
t1: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
solver_state: _SolverState, | ||
made_jump: BoolScalarLike, | ||
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: | ||
del solver_state, made_jump | ||
|
||
# We split the terms into linear and the non-linear plus possible noise | ||
linear, non_linear = terms.terms[0].term, MultiTerm(*terms.terms[1:]) | ||
exp_Ah, phi = make_operators(linear.operator, linear.contr(t0, t1)) | ||
Gh_sdW = non_linear.vf_prod(t0, y0, args, non_linear.contr(t0, t1)) | ||
y1 = exp_Ah.mv(y0) + phi.mv(Gh_sdW) | ||
dense_info = dict(y0=y0, y1=y1) | ||
return y1, None, dense_info, None, RESULTS.successful | ||
|
||
def func( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
) -> VF: | ||
return terms.vf(t0, y0, args) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -788,6 +788,16 @@ def _to_vjp(_y, _diff_args, _diff_term): | |
return dy, da_y, da_diff_args, da_diff_term | ||
|
||
|
||
class LinearODETerm(ODETerm): | ||
operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator | ||
|
||
def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator): | ||
self.operator = operator | ||
vf = lambda t, y, args: self.operator.mv(y) | ||
super().__init__(vector_field=vf) | ||
Comment on lines
+791
to
+797
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Btw, I try to avoid subclassing concrete classes -- c.f. https://docs.kidger.site/equinox/pattern/ Also, does this strictly need to be an ODE term? I think in other problems, we might be interested in expressing linear terms in this way (for any |
||
|
||
|
||
|
||
# The Underdamped Langevin SDE trajectory consists of two components: the position | ||
# `x` and the velocity `v`. Both of these have the same shape. | ||
# So, by UnderdampedLangevinX we denote the shape of the x component, and by | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
import jax.numpy as jnp | ||
import jax.random as jr | ||
import lineax as lx | ||
from diffrax import ( | ||
ControlTerm, | ||
Dopri5, | ||
EulerHeun, | ||
ExponentialEuler, | ||
HalfSolver, | ||
MultiTerm, | ||
ODETerm, | ||
PIDController, | ||
SaveAt, | ||
VirtualBrownianTree, | ||
diffeqsolve, | ||
LinearODETerm | ||
|
||
) | ||
|
||
|
||
|
||
def test_linear(): | ||
"""Linear equation should be exact, even at large timesteps.""" | ||
linear_term = lx.DiagonalLinearOperator(-jnp.ones((1,))) # noqa: E731 | ||
non_linear_term = lambda t, y, args: jnp.zeros_like(y) # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
|
||
saveat = SaveAt(ts=jnp.linspace(0, 3, 2)) | ||
sol = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1.0, | ||
y0=jnp.ones((1,)), | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol.ys, jnp.array([[jnp.exp(0.0)], [jnp.exp(-3)]])) | ||
|
||
|
||
def test_non_linear(): | ||
"""Non linear, comparison to Dopri5""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
|
||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
# Baseline solver | ||
solver = Dopri5() | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
sol_baseline = diffeqsolve( | ||
term, | ||
solver, | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
assert jnp.allclose(sol_exp.ys, sol_baseline.ys, rtol=2e-3, atol=2e-3) | ||
|
||
|
||
def test_error(): | ||
"""Testing if we can use the HalfSolver to get adaptive steps.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-4, | ||
y0=y0, | ||
saveat=saveat, | ||
max_steps=100000, | ||
) | ||
|
||
# Exponential solver | ||
sol_adapt = diffeqsolve( | ||
term, | ||
HalfSolver(ExponentialEuler()), | ||
t0=0, | ||
t1=3, | ||
dt0=1.0, # larger stepsize, should be adjusted | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
assert jnp.allclose(sol.ys, sol_adapt.ys, rtol=2e-3, atol=3e-3) | ||
|
||
|
||
def test_sde(): | ||
"Test using an SDE." | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
diffusion_term = lambda t, y, args: lx.DiagonalLinearOperator(jnp.full((10,), 0.1)) # noqa: E731 | ||
brownian_motion = VirtualBrownianTree( | ||
0, 3, tol=1e-3, shape=(10,), key=jr.PRNGKey(0) | ||
) | ||
term = MultiTerm( | ||
LinearODETerm(linear_term), | ||
ODETerm(non_linear_term), | ||
ControlTerm(diffusion_term, brownian_motion), | ||
) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
# Shark solver | ||
linear_term_exp = lambda t, y, args: A * y + 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term_exp = MultiTerm( | ||
ODETerm(linear_term_exp), | ||
ControlTerm(diffusion_term, brownian_motion), | ||
) | ||
sol_euler = diffeqsolve( | ||
term_exp, | ||
EulerHeun(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol_exp.ys, sol_euler.ys, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_diagonal_matrix_exponential(): | ||
"""Compare result of diagonal specialisation to full matrix exponential.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term_diag = lx.DiagonalLinearOperator(A) | ||
linear_term_full = lx.MatrixLinearOperator(jnp.diag(A)) | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term_diag = MultiTerm(LinearODETerm(linear_term_diag), ODETerm(non_linear_term)) | ||
term_full = MultiTerm(LinearODETerm(linear_term_full), ODETerm(non_linear_term)) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Diagonal approach | ||
sol_diag = diffeqsolve( | ||
term_diag, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
# Full matrix exponential | ||
sol_full = diffeqsolve( | ||
term_full, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol_diag.ys, sol_full.ys, rtol=1e-4, atol=1e-4) | ||
|
||
|
||
def test_matrix_exponential(): | ||
"""Test if normal solver accept our term structure, and make sure results are the same.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10, 10))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.MatrixLinearOperator(A) | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
solver = Dopri5() | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
sol_dopri = diffeqsolve( | ||
term, | ||
solver, | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
|
||
jnp.allclose(sol_exp.ys, sol_dopri.ys, rtol=2e-3, atol=2e-3) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is numerically stable if the denominator is small? (Maybe switch to a Taylor expansion below some cutoff?)