Skip to content

Commit

Permalink
Added tests for new SRKs into test_integrate and improved srk_example…
Browse files Browse the repository at this point in the history
….ipynb
  • Loading branch information
andyElking committed Dec 31, 2023
1 parent 18d6c58 commit c56dcf4
Show file tree
Hide file tree
Showing 12 changed files with 600 additions and 227 deletions.
2 changes: 1 addition & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import (
levy_tree_transpose as levy_tree_transpose,
LevyVal as LevyVal,
)
from ._event import (
Expand Down Expand Up @@ -72,6 +71,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
FosterSRK as FosterSRK,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
Expand Down
15 changes: 0 additions & 15 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,21 +425,6 @@ def _brownian_arch(
there for the sake of a future extension with "space-time-time" Levy area
and should be None for now.
??? cite "Reference"
Based on section 6.1 of
```bibtex
@phdthesis{foster2020a,
publisher = {University of Oxford},
school = {University of Oxford},
title = {Numerical approximations for stochastic differential equations},
author = {Foster, James M.},
year = {2020}
}
In particular see Theorem 6.1.6.
```
**Arguments:**
- `_state`: The state of the Brownian tree
Expand Down
1 change: 1 addition & 0 deletions diffrax/_solver/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .dopri8 import Dopri8 as Dopri8
from .euler import Euler as Euler
from .euler_heun import EulerHeun as EulerHeun
from .foster_srk import FosterSRK as FosterSRK
from .heun import Heun as Heun
from .implicit_euler import ImplicitEuler as ImplicitEuler
from .kencarp3 import KenCarp3 as KenCarp3
Expand Down
4 changes: 3 additions & 1 deletion diffrax/_solver/sea.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class SEA(AbstractSRK):
r"""Shifted Euler method for SDEs with additive noise.
It has a local error of $O(h^2)$ compared to
standard Euler-Maruyama, which has $O(h^{1.5})$.
Uses one evaluation of the vector field per step and
has order 1 for additive noise SDEs.
Based on equation $(5.8)$ in
??? cite "Reference"
Expand All @@ -44,4 +46,4 @@ def order(self, terms):
return 1

def strong_order(self, terms):
return 0.5
return 1
3 changes: 2 additions & 1 deletion diffrax/_solver/shark.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@

class ShARK(AbstractSRK):
r"""Shifted Additive-noise Runge-Kutta method for SDEs by James Foster.
Applied to SDEs with additive noise, it converges strongly with order 1.5.
Applied to SDEs with additive noise, it has strong order 1.5.
Uses two evaluations of the vector field per step.
Based on equation $(6.1)$ in
Expand Down
2 changes: 2 additions & 0 deletions diffrax/_solver/sra1.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

class SRA1(AbstractSRK):
r"""Based on the SRA1 method by Andreas Rößler.
Works only for SDEs with additive noise, applied to which it has strong order 1.5.
Uses two evaluations of the vector field per step.
??? cite "Reference"
Expand Down
45 changes: 29 additions & 16 deletions diffrax/_solver/srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .._custom_types import (
BoolScalarLike,
DenseInfo,
IntScalarLike,
LevyVal,
RealScalarLike,
VF,
Expand All @@ -24,10 +25,19 @@
from .base import AbstractStratonovichSolver


_ErrorEstimate: TypeAlias = None
_ErrorEstimate: TypeAlias = Optional[Y]
_SolverState: TypeAlias = None
_CarryType: TypeAlias = tuple[int, PyTree[Array], PyTree[Array], PyTree[Array]]
_LA: TypeAlias = Literal["", "space-time", "space-time-time"]
_CarryType: TypeAlias = tuple[
IntScalarLike, PyTree[Array], PyTree[Array], PyTree[Array]
]

# To Patrick:
# This is on purpose distinct from LevyArea in diffrax/_custom_types.py
# because I think it is cleaner if this SRK is "perpendicular" to _brownian
# in the sense that this can utilize sttla (which it can) even if _brownian
# cannot generate it so far. And this allows me to later update _brownian
# without having to update this SRK.
MinimalLevyArea: TypeAlias = Literal["", "space-time", "space-time-time"]


@dataclass(frozen=True)
Expand Down Expand Up @@ -240,15 +250,15 @@ class AbstractSRK(AbstractStratonovichSolver):
interpolation_cls = LocalLinearInterpolation
tableau: StochasticButcherTableau

minimal_levy_area: _LA
minimal_levy_area: MinimalLevyArea

def __init__(self):
if self.tableau.bK is not None:
self.minimal_levy_area: _LA = "space-time-time"
self.minimal_levy_area: MinimalLevyArea = "space-time-time"
elif self.tableau.bH is not None:
self.minimal_levy_area: _LA = "space-time"
self.minimal_levy_area: MinimalLevyArea = "space-time"
else:
self.minimal_levy_area: _LA = ""
self.minimal_levy_area: MinimalLevyArea = ""

def init(
self,
Expand Down Expand Up @@ -283,7 +293,7 @@ def init(
)

if self.tableau.additive_noise:
# check that the vector field of the diffusion term is constant
# check that the vector field of the diffusion term does not depend on y
ones_like_y0 = jtu.tree_map(lambda x: jnp.ones_like(x), y0)
_, y_sigma = eqx.filter_jvp(
lambda y: diffusion.vf(t0, y, args), (y0,), (ones_like_y0,)
Expand Down Expand Up @@ -323,7 +333,7 @@ def step(
# time increment
h = t1 - t0

# First all the drift related stuff
# First the drift related stuff
a = self._embed_a_lower(self.tableau.a, dtype)
c = jnp.asarray(np.insert(self.tableau.c, 0, 0.0), dtype=dtype)
b_sol = jnp.asarray(self.tableau.b_sol, dtype=dtype)
Expand Down Expand Up @@ -385,11 +395,13 @@ def aux_add_levy(w_leaf, *levy_leaves):

return aux_add_levy

a_levy = [] # will contain cH if additive_noise=True, aH otherwise
# later other kinds of Levy area will be added to this list
a_levy = [] # if noise is additive this is [cH, cK] (if those entries exist)
# otherwise this is [aH, aK] (if those entries exist)

levy_gs_list = [] # will contain levy * g(t0 + c_j * h, z_j) for each stage j
# and for each type of levy area (e.g. H, K, etc.)
# where levy is either H or K (if those entries exist)
# this is similar to tfs or wgs, but for the Levy area(s)

if additive_noise:
# compute g once since it is constant
g0 = diffusion.vf(t0, y0, args)
Expand All @@ -410,7 +422,7 @@ def aux_add_levy(w_leaf, *levy_leaves):
# Since the carry of lax.scan needs to have constant shape,
# we initialise a list of zeros of the same shape as y0, which will get
# filled with the values of W * g(t0 + c_j * h, z_j) at each stage
wg_list = make_zeros()
wgs = make_zeros()
# do the same for each type of Levy area
if stla:
levy_gs_list.append(make_zeros())
Expand All @@ -419,7 +431,7 @@ def aux_add_levy(w_leaf, *levy_leaves):
levy_gs_list.append(make_zeros())
a_levy.append(self._embed_a_lower(self.tableau.aK, dtype))

carry: _CarryType = (0, tfs, wg_list, levy_gs_list)
carry: _CarryType = (0, tfs, wgs, levy_gs_list)

a_w = self._embed_a_lower(self.tableau.aW, dtype)

Expand All @@ -435,6 +447,7 @@ def stage(
x: tuple[Array, Array, Array, list[Array]],
):
# Represents the jth stage of the SRK.

a_j, c_j, a_w_j, a_levy_list_j = x
# a_levy_list_j = [aH_j, aK_j] (if those entries exist) where
# aH_j is the row in the aH matrix corresponding to stage j
Expand All @@ -443,8 +456,8 @@ def stage(

if additive_noise:
# carry = (j, _tfs) where
# _tfs = Array[hk1, hk2, ..., hk_{j-1}, 0, 0, ..., 0]
# hki = drift.vf_prod(t0 + c_i*h, y_i, args, h) (i.e. hki = h * k_i)
# _tfs = Array[tf_1, tf_2, ..., tf_{j-1}, 0, 0, ..., 0]
# tf_i = drift.vf_prod(t0 + c_i*h, y_i, args, h) (i.e. tf_i = h * f_i)
assert _wgs is None and _levy_gs_list is None
assert isinstance(levy_gs_list, list)
_diffusion_result = jtu.tree_map(
Expand Down
28 changes: 14 additions & 14 deletions examples/langevin.ipynb

Large diffs are not rendered by default.

598 changes: 461 additions & 137 deletions examples/srk_example.ipynb

Large diffs are not rendered by default.

31 changes: 19 additions & 12 deletions test/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,9 @@
UnsafeBrownianPath,
VirtualBrownianTree,
)
from diffrax._custom_types import _LA, RealScalarLike

# I'm not sure if this is the right way to import these types
from diffrax._custom_types import LevyArea, RealScalarLike
from jax import numpy as jnp
from jaxtyping import PyTree

Expand Down Expand Up @@ -97,7 +99,7 @@ def tree_allclose(x, y, *, rtol=1e-5, atol=1e-8, equal_nan=False):
def path_l2_dist(ys1: PyTree[jax.Array], ys2: PyTree[jax.Array]):
# first compute the square of the difference and sum over
# all but the first two axes (which represent the number of samples
# and the length of saveat). Also sum all the PyTree leaves
# and the length of saveat). Also sum all the PyTree leaves.
def sum_square_diff(y1, y2):
square_diff = jnp.square(y1 - y2)
# sum all but the first two axes
Expand Down Expand Up @@ -134,7 +136,7 @@ def get_dtype(self):
def get_bm(
self,
key,
levy_area: _LA = "space-time",
levy_area: LevyArea = "space-time",
use_tree=True,
tol=2**-14,
):
Expand All @@ -157,10 +159,12 @@ def _batch_sde_solve(
sde: SDE,
dt0,
solver,
levy_area: _LA,
levy_area: LevyArea,
):
_saveat = SaveAt(ts=[sde.t1])

@jax.jit
@jax.vmap
def end_value(key):
path = sde.get_bm(key, levy_area=levy_area, use_tree=True)
terms = sde.get_terms(path)
Expand All @@ -177,18 +181,18 @@ def end_value(key):
)
return sol.ys

return jax.vmap(end_value)(keys)
return end_value(keys)


def sde_solver_order(keys, sde: SDE, solver, ref_solver, dt_precise, dts):
# TODO: remove this once we have a better way to handle this
def sde_solver_strong_order(keys, sde: SDE, solver, ref_solver, dt_precise, dts):
if hasattr(solver, "minimal_levy_area"):
levy_area = solver.minimal_levy_area
else:
levy_area = ""

# Stricter levy_area requirements are a longer string, so only override
# solver's levy_area if the ref_solver requires more levy area
# TODO: this is a bit hacky, but I'm not sure how else to do it
if hasattr(ref_solver, "minimal_levy_area") and len(
ref_solver.minimal_levy_area
) > len(levy_area):
Expand Down Expand Up @@ -223,21 +227,24 @@ def diffusion(t, y, args):
return 0.25 * mlp(y).reshape(3, noise_dim)


def get_mlp_sde(t0=0.3, t1=15.0, dtype=jnp.float32, key=jr.PRNGKey(0), noise_dim=1):
def get_mlp_sde(t0, t1, dtype, key, noise_dim):
driftkey, diffusionkey, ykey = jr.split(key, 3)
# To Patrick: I had to increase the depth of these MLPs, otherwise many SDE
# solvers had order ~0.72 which is more than 0.5 + 0.2, which is the maximal
# tolerated order. I think the issue was that it was too linear and too easy.
drift_mlp = eqx.nn.MLP(
in_size=3,
out_size=3,
width_size=8,
depth=1,
depth=2,
activation=_squareplus,
key=driftkey,
)
diffusion_mlp = eqx.nn.MLP(
in_size=3,
out_size=3 * noise_dim,
width_size=8,
depth=1,
depth=2,
activation=_squareplus,
final_activation=jnp.tanh,
key=diffusionkey,
Expand All @@ -251,13 +258,13 @@ def get_terms(bm):
return SDE(get_terms, args, y0, t0, t1, (noise_dim,))


def get_time_sde(t0=0.3, t1=15.0, dtype=jnp.float32, key=jr.PRNGKey(5678), noise_dim=1):
def get_time_sde(t0, t1, dtype, key, noise_dim):
driftkey, diffusionkey, ykey = jr.split(key, 3)
drift_mlp = eqx.nn.MLP(
in_size=3,
out_size=3,
width_size=8,
depth=1,
depth=2,
activation=_squareplus,
key=driftkey,
)
Expand Down
5 changes: 2 additions & 3 deletions test/test_brownian.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
import jax.tree_util as jtu
import pytest
import scipy.stats as stats
from diffrax._custom_types import LevyArea


_Spline: TypeAlias = Literal["quad", "sqrt", "zero"]
Expand All @@ -33,7 +32,7 @@ def _make_struct(shape, dtype):
)
@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_shape_and_dtype(ctr, levy_area: LevyArea, use_levy, getkey):
def test_shape_and_dtype(ctr, levy_area, use_levy, getkey):
t0 = 0
t1 = 2

Expand Down Expand Up @@ -115,7 +114,7 @@ def is_tuple_of_ints(obj):
)
@pytest.mark.parametrize("levy_area", ["", "space-time"])
@pytest.mark.parametrize("use_levy", (False, True))
def test_statistics(ctr, levy_area: LevyArea, use_levy):
def test_statistics(ctr, levy_area, use_levy):
# Deterministic key for this test; not using getkey()
key = jr.PRNGKey(5678)
keys = jr.split(key, 10000)
Expand Down
Loading

0 comments on commit c56dcf4

Please sign in to comment.