-
-
Notifications
You must be signed in to change notification settings - Fork 140
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Gert-Jan Both
committed
Jan 13, 2025
1 parent
134a40a
commit bf5d3ed
Showing
5 changed files
with
336 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# Stuff for custom solver | ||
from collections.abc import Callable | ||
from typing import ClassVar | ||
|
||
import diffrax | ||
from diffrax._custom_types import VF, Args, BoolScalarLike, DenseInfo, RealScalarLike, Y | ||
from diffrax._local_interpolation import LocalLinearInterpolation | ||
from diffrax._solution import RESULTS | ||
from diffrax._term import AbstractTerm, MultiTerm, ODETerm, LinearODETerm | ||
from typing_extensions import TypeAlias | ||
|
||
_ErrorEstimate: TypeAlias = None | ||
_SolverState: TypeAlias = None | ||
import jax | ||
import jax.numpy as jnp | ||
import lineax as lx | ||
from jax.scipy.linalg import expm | ||
|
||
|
||
def make_operators(operator, control): | ||
match operator: | ||
case lx.DiagonalLinearOperator(): | ||
# If the operator is diagonal things are fairly trivial | ||
# Technically we could do the part below through lineax too | ||
# but it just complicates things here as we return two different | ||
# operators (linear vs full) | ||
exp_Ah = jnp.exp(control * operator.diagonal) | ||
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control) | ||
return lx.DiagonalLinearOperator(exp_Ah), lx.DiagonalLinearOperator(phi) | ||
|
||
case _: | ||
# If the linear operator is not diagonal we need to calculate the matrix exponential | ||
# and solve the linear problem | ||
exp_Ah = expm(control * operator.as_matrix()) | ||
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1]) | ||
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value | ||
return lx.MatrixLinearOperator(exp_Ah), lx.MatrixLinearOperator(phi) | ||
|
||
|
||
class ExponentialEuler(diffrax.AbstractItoSolver): | ||
term_structure: ClassVar = AbstractTerm | ||
interpolation_cls: ClassVar[Callable[..., LocalLinearInterpolation]] = ( | ||
LocalLinearInterpolation | ||
) | ||
|
||
def order(self, terms): | ||
return 1.0 | ||
|
||
def strong_order(self, terms): | ||
return 0.5 | ||
|
||
def init( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
t1: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
) -> _SolverState: | ||
return None | ||
|
||
def step( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
t1: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
solver_state: _SolverState, | ||
made_jump: BoolScalarLike, | ||
) -> tuple[Y, _ErrorEstimate, DenseInfo, _SolverState, RESULTS]: | ||
del solver_state, made_jump | ||
|
||
# We split the terms into linear and the non-linear plus possible noise | ||
linear, non_linear = terms.terms[0].term, MultiTerm(*terms.terms[1:]) | ||
exp_Ah, phi = make_operators(linear.operator, linear.contr(t0, t1)) | ||
Gh_sdW = non_linear.vf_prod(t0, y0, args, non_linear.contr(t0, t1)) | ||
y1 = exp_Ah.mv(y0) + phi.mv(Gh_sdW) | ||
dense_info = dict(y0=y0, y1=y1) | ||
return y1, None, dense_info, None, RESULTS.successful | ||
|
||
def func( | ||
self, | ||
terms: AbstractTerm, | ||
t0: RealScalarLike, | ||
y0: Y, | ||
args: Args, | ||
) -> VF: | ||
return terms.vf(t0, y0, args) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,232 @@ | ||
import jax.numpy as jnp | ||
import jax.random as jr | ||
import lineax as lx | ||
from diffrax import ( | ||
ControlTerm, | ||
Dopri5, | ||
EulerHeun, | ||
ExponentialEuler, | ||
HalfSolver, | ||
MultiTerm, | ||
ODETerm, | ||
PIDController, | ||
SaveAt, | ||
VirtualBrownianTree, | ||
diffeqsolve, | ||
LinearODETerm | ||
|
||
) | ||
|
||
|
||
|
||
def test_linear(): | ||
"""Linear equation should be exact, even at large timesteps.""" | ||
linear_term = lx.DiagonalLinearOperator(-jnp.ones((1,))) # noqa: E731 | ||
non_linear_term = lambda t, y, args: jnp.zeros_like(y) # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
|
||
saveat = SaveAt(ts=jnp.linspace(0, 3, 2)) | ||
sol = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1.0, | ||
y0=jnp.ones((1,)), | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol.ys, jnp.array([[jnp.exp(0.0)], [jnp.exp(-3)]])) | ||
|
||
|
||
def test_non_linear(): | ||
"""Non linear, comparison to Dopri5""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
|
||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
# Baseline solver | ||
solver = Dopri5() | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
sol_baseline = diffeqsolve( | ||
term, | ||
solver, | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
assert jnp.allclose(sol_exp.ys, sol_baseline.ys, rtol=2e-3, atol=2e-3) | ||
|
||
|
||
def test_error(): | ||
"""Testing if we can use the HalfSolver to get adaptive steps.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-4, | ||
y0=y0, | ||
saveat=saveat, | ||
max_steps=100000, | ||
) | ||
|
||
# Exponential solver | ||
sol_adapt = diffeqsolve( | ||
term, | ||
HalfSolver(ExponentialEuler()), | ||
t0=0, | ||
t1=3, | ||
dt0=1.0, # larger stepsize, should be adjusted | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
assert jnp.allclose(sol.ys, sol_adapt.ys, rtol=2e-3, atol=3e-3) | ||
|
||
|
||
def test_sde(): | ||
"Test using an SDE." | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.DiagonalLinearOperator(A) # noqa: E731 | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
diffusion_term = lambda t, y, args: lx.DiagonalLinearOperator(jnp.full((10,), 0.1)) # noqa: E731 | ||
brownian_motion = VirtualBrownianTree( | ||
0, 3, tol=1e-3, shape=(10,), key=jr.PRNGKey(0) | ||
) | ||
term = MultiTerm( | ||
LinearODETerm(linear_term), | ||
ODETerm(non_linear_term), | ||
ControlTerm(diffusion_term, brownian_motion), | ||
) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
# Shark solver | ||
linear_term_exp = lambda t, y, args: A * y + 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term_exp = MultiTerm( | ||
ODETerm(linear_term_exp), | ||
ControlTerm(diffusion_term, brownian_motion), | ||
) | ||
sol_euler = diffeqsolve( | ||
term_exp, | ||
EulerHeun(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol_exp.ys, sol_euler.ys, rtol=1e-3, atol=1e-3) | ||
|
||
|
||
def test_diagonal_matrix_exponential(): | ||
"""Compare result of diagonal specialisation to full matrix exponential.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10,))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term_diag = lx.DiagonalLinearOperator(A) | ||
linear_term_full = lx.MatrixLinearOperator(jnp.diag(A)) | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term_diag = MultiTerm(LinearODETerm(linear_term_diag), ODETerm(non_linear_term)) | ||
term_full = MultiTerm(LinearODETerm(linear_term_full), ODETerm(non_linear_term)) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Diagonal approach | ||
sol_diag = diffeqsolve( | ||
term_diag, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
# Full matrix exponential | ||
sol_full = diffeqsolve( | ||
term_full, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
assert jnp.allclose(sol_diag.ys, sol_full.ys, rtol=1e-4, atol=1e-4) | ||
|
||
|
||
def test_matrix_exponential(): | ||
"""Test if normal solver accept our term structure, and make sure results are the same.""" | ||
A = -jnp.abs(jr.normal(jr.key(0), (10, 10))) | ||
y0 = jr.normal(jr.key(42), (10,)) | ||
|
||
linear_term = lx.MatrixLinearOperator(A) | ||
non_linear_term = lambda t, y, args: 2 * jnp.cos(y) ** 3 # noqa: E731 | ||
term = MultiTerm(LinearODETerm(linear_term), ODETerm(non_linear_term)) | ||
saveat = SaveAt(ts=jnp.linspace(0, 3, 100)) | ||
|
||
# Exponential solver | ||
sol_exp = diffeqsolve( | ||
term, | ||
ExponentialEuler(), | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
) | ||
|
||
solver = Dopri5() | ||
stepsize_controller = PIDController(rtol=1e-5, atol=1e-5) | ||
sol_dopri = diffeqsolve( | ||
term, | ||
solver, | ||
t0=0, | ||
t1=3, | ||
dt0=1e-3, | ||
y0=y0, | ||
saveat=saveat, | ||
stepsize_controller=stepsize_controller, | ||
) | ||
|
||
jnp.allclose(sol_exp.ys, sol_dopri.ys, rtol=2e-3, atol=2e-3) |