Skip to content

Commit

Permalink
Everything SRK related squashed on top of diffrax/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Apr 23, 2024
1 parent 41c58a6 commit e645e24
Show file tree
Hide file tree
Showing 28 changed files with 2,574 additions and 163 deletions.
13 changes: 12 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import (
AbstractBrownianReturn as AbstractBrownianReturn,
AbstractBrownianIncrement as AbstractBrownianIncrement,
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
Expand Down Expand Up @@ -69,6 +72,7 @@
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Bosh3 as Bosh3,
Expand All @@ -78,6 +82,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
Expand All @@ -93,8 +98,14 @@
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
SRA1 as SRA1,
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
Expand Down
25 changes: 24 additions & 1 deletion diffrax/_autocitation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,23 @@
from ._saveat import SubSaveAt
from ._solver import (
AbstractImplicitSolver,
AbstractItoSolver,
AbstractSRK,
AbstractStratonovichSolver,
Dopri5,
Dopri8,
GeneralShARK,
Kvaerno3,
Kvaerno4,
Kvaerno5,
LeapfrogMidpoint,
ReversibleHeun,
SEA,
SemiImplicitEuler,
ShARK,
SlowRK,
SPaRK,
SRA1,
Tsit5,
)
from ._step_size_controller import PIDController
Expand Down Expand Up @@ -374,7 +383,15 @@ def _backsolve_rms_norm(adjoint):

@citation_rules.append
def _explicit_solver(solver, terms=None):
if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms):
if not isinstance(
solver,
(
AbstractImplicitSolver,
AbstractSRK,
AbstractItoSolver,
AbstractStratonovichSolver,
),
) and not is_sde(terms):
return r"""
% You are using an explicit solver, and may wish to cite the standard textbook:
@book{hairer2008solving-i,
Expand Down Expand Up @@ -467,6 +484,12 @@ def _solvers(solver, saveat=None):
Kvaerno5,
ReversibleHeun,
LeapfrogMidpoint,
ShARK,
SRA1,
SlowRK,
GeneralShARK,
SPaRK,
SEA,
):
return (
r"""
Expand Down
11 changes: 8 additions & 3 deletions diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@
from equinox.internal import AbstractVar
from jaxtyping import Array, PyTree

from .._custom_types import AbstractBrownianReturn, RealScalarLike
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
RealScalarLike,
SpaceTimeLevyArea,
)
from .._path import AbstractPath


_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianReturn])
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianIncrement])


class AbstractBrownianPath(AbstractPath[_Control]):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[type[AbstractBrownianReturn]]
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]]

@abc.abstractmethod
def evaluate(
Expand Down
7 changes: 4 additions & 3 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
levy_tree_transpose,
RealScalarLike,
Expand Down Expand Up @@ -87,7 +88,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]:
) -> Union[PyTree[Array], AbstractBrownianIncrement]:
del left
if t1 is None:
dtype = jnp.result_type(t0)
Expand Down Expand Up @@ -129,13 +130,13 @@ def _evaluate_leaf(
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
w = jr.normal(key, shape.shape, shape.dtype) * w_std
dt = t1 - t0
dt = jnp.asarray(t1 - t0, dtype=shape.dtype)

if levy_area is SpaceTimeLevyArea:
key, key_hh = jr.split(key, 2)
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None)
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh)
elif levy_area is BrownianIncrement:
levy_val = BrownianIncrement(dt=dt, W=w)
else:
Expand Down
64 changes: 36 additions & 28 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from .._custom_types import (
AbstractBrownianReturn,
AbstractBrownianIncrement,
BoolScalarLike,
BrownianIncrement,
IntScalarLike,
Expand Down Expand Up @@ -59,7 +59,7 @@
Float[Array, " *shape"], Float[Array, " *shape"], Float[Array, " *shape"]
]
_Spline: TypeAlias = Literal["sqrt", "quad", "zero"]
_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianReturn)
_BrownianReturn = TypeVar("_BrownianReturn", bound=AbstractBrownianIncrement)


class _State(eqx.Module):
Expand All @@ -71,7 +71,7 @@ class _State(eqx.Module):
bkk_s_u_su: Optional[FloatTriple] # \bar{K}_s, _u, _{s,u}


def _levy_diff(_, x0: tuple, x1: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]:
def _levy_diff(_, x0: tuple, x1: tuple) -> AbstractBrownianIncrement:
r"""Computes $(W_{s,u}, H_{s,u})$ from $(W_s, \bar{H}_{s,u})$ and
$(W_u, \bar{H}_u)$, where $\bar{H}_u = u * H_u$.
Expand Down Expand Up @@ -105,18 +105,18 @@ def _levy_diff(_, x0: tuple, x1: tuple) -> Union[BrownianIncrement, SpaceTimeLev
u_bb_s = dt1 * w0 - dt0 * w1
bhh_su = bhh1 - bhh0 - 0.5 * u_bb_s # bhh_su = H_{s,u} * (u-s)
hh_su = inverse_su * bhh_su
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su, K=None)
return SpaceTimeLevyArea(dt=su, W=w_su, H=hh_su)
else:
assert False


def _make_levy_val(_, x: tuple) -> Union[BrownianIncrement, SpaceTimeLevyArea]:
def _make_levy_val(_, x: tuple) -> AbstractBrownianIncrement:
if len(x) == 2:
dt, w = x
return BrownianIncrement(dt=dt, W=w)
elif len(x) == 4:
dt, w, hh, bhh = x
return SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None)
return SpaceTimeLevyArea(dt=dt, W=w, H=hh)
else:
assert False

Expand Down Expand Up @@ -243,7 +243,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]:
) -> Union[PyTree[Array], AbstractBrownianIncrement]:
t0 = eqxi.nondifferentiable(t0, name="t0")
# map the interval [self.t0, self.t1] onto [0,1]
t0 = linear_rescale(self.t0, t0, self.t1)
Expand Down Expand Up @@ -294,11 +294,14 @@ def _evaluate_leaf(
bhh = (bhh_0, bhh_1, bhh_1)
bkk = None

else:
elif self.levy_area is BrownianIncrement:
state_key, init_key_w = jr.split(key, 2)
bhh = None
bkk = None

else:
assert False

w_0 = jnp.zeros(shape, dtype)
w_1 = jr.normal(init_key_w, shape, dtype)
w = (w_0, w_1, w_1)
Expand Down Expand Up @@ -334,11 +337,13 @@ def _body_fun(_state: _State):

_w = _split_interval(_cond, _w_stu, _w_inc)
_bkk = None
if self.levy_area is not BrownianIncrement:
if self.levy_area is SpaceTimeLevyArea:
assert _bhh_stu is not None and _bhh_st_tu is not None
_bhh = _split_interval(_cond, _bhh_stu, _bhh_st_tu)
else:
elif self.levy_area is BrownianIncrement:
_bhh = None
else:
assert False

return _State(
level=_level,
Expand All @@ -359,23 +364,7 @@ def _body_fun(_state: _State):
ru = jax.nn.relu(su - sr)

w_s, w_u, w_su = final_state.w_s_u_su

if self.levy_area is BrownianIncrement:
w_mean = w_s + sr / su * w_su
if self._spline == "sqrt":
z = jr.normal(final_state.key, shape, dtype)
bb = jnp.sqrt(sr * ru / su) * z
elif self._spline == "quad":
z = jr.normal(final_state.key, shape, dtype)
bb = (sr * ru / su) * z
elif self._spline == "zero":
bb = jnp.zeros(shape, dtype)
else:
assert False
w_r = w_mean + bb
return r, w_r

elif self.levy_area is SpaceTimeLevyArea:
if self.levy_area is SpaceTimeLevyArea:
# This is based on Theorem 6.1.4 of Foster's thesis (see above).

assert final_state.bhh_s_u_su is not None
Expand Down Expand Up @@ -414,6 +403,21 @@ def _body_fun(_state: _State):
inverse_r = 1 / jnp.where(jnp.abs(r) < jnp.finfo(r).eps, jnp.inf, r)
hh_r = inverse_r * bhh_r

elif self.levy_area is BrownianIncrement:
w_mean = w_s + sr / su * w_su
if self._spline == "sqrt":
z = jr.normal(final_state.key, shape, dtype)
bb = jnp.sqrt(sr * ru / su) * z
elif self._spline == "quad":
z = jr.normal(final_state.key, shape, dtype)
bb = (sr * ru / su) * z
elif self._spline == "zero":
bb = jnp.zeros(shape, dtype)
else:
assert False
w_r = w_mean + bb
return r, w_r

else:
assert False

Expand Down Expand Up @@ -499,7 +503,7 @@ def _brownian_arch(
bkk_stu = None
bkk_st_tu = None

else:
elif self.levy_area is BrownianIncrement:
assert _state.bhh_s_u_su is None
assert _state.bkk_s_u_su is None
mean = 0.5 * w_su
Expand All @@ -510,4 +514,8 @@ def _brownian_arch(
w_t = w_s + w_st
w_stu = (w_s, w_t, w_u)
bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu = None, None, None, None

else:
assert False

return t, w_stu, w_st_tu, keys, bhh_stu, bhh_st_tu, bkk_stu, bkk_st_tu
52 changes: 33 additions & 19 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import typing
from typing import Any, Optional, TYPE_CHECKING, Union
from typing import Any, TYPE_CHECKING, Union

import equinox as eqx
import equinox.internal as eqxi
Expand Down Expand Up @@ -53,40 +53,54 @@
sentinel: Any = eqxi.doc_repr(object(), "sentinel")


class AbstractBrownianReturn(eqx.Module):
dt: eqx.AbstractVar[PyTree]
W: eqx.AbstractVar[PyTree]
class AbstractBrownianIncrement(eqx.Module):
dt: eqx.AbstractVar[PyTree[FloatScalarLike]]
W: eqx.AbstractVar[PyTree[Array]]


class BrownianIncrement(AbstractBrownianReturn):
dt: PyTree
W: PyTree
class AbstractSpaceTimeLevyArea(AbstractBrownianIncrement):
H: eqx.AbstractVar[PyTree[Array]]


class SpaceTimeLevyArea(AbstractBrownianReturn):
dt: PyTree
W: PyTree
H: Optional[PyTree]
K: Optional[PyTree]
class AbstractSpaceTimeTimeLevyArea(AbstractSpaceTimeLevyArea):
K: eqx.AbstractVar[PyTree[Array]]


class BrownianIncrement(AbstractBrownianIncrement):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]


class SpaceTimeLevyArea(AbstractSpaceTimeLevyArea):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]
H: PyTree[Array]


class SpaceTimeTimeLevyArea(AbstractSpaceTimeTimeLevyArea):
dt: PyTree[FloatScalarLike]
W: PyTree[Array]
H: PyTree[Array]
K: PyTree[Array]


def levy_tree_transpose(
tree_shape, tree: PyTree[AbstractBrownianReturn]
) -> AbstractBrownianReturn:
"""Helper that takes a PyTree of AbstractBrownianReturn and transposes
into an AbstractBrownianReturn of PyTrees.
tree_shape, tree: PyTree[AbstractBrownianIncrement]
) -> AbstractBrownianIncrement:
"""Helper that takes a `PyTree `of `AbstractBrownianIncrement`s and transposes
into an `AbstractBrownianIncrement` of `PyTree`s.
**Arguments:**
- `tree_shape`: Corresponds to `outer_treedef` in `jax.tree_transpose`.
- `tree`: the PyTree of AbstractBrownianReturn to transpose.
- `tree`: the `PyTree` of `AbstractBrownianIncrement`s to transpose.
**Returns:**
An `AbstractBrownianReturn` of PyTrees.
An `AbstractBrownianIncrement` of `PyTree`s.
"""
inner_tree = jtu.tree_leaves(
tree, is_leaf=lambda x: isinstance(x, AbstractBrownianReturn)
tree, is_leaf=lambda x: isinstance(x, AbstractBrownianIncrement)
)[0]
inner_tree_shape = jtu.tree_structure(inner_tree)
return jtu.tree_transpose(
Expand Down
Loading

0 comments on commit e645e24

Please sign in to comment.