Skip to content

Commit

Permalink
Fixed LangevinTerm YAAAYYYYY
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Aug 9, 2024
1 parent 34aa6cb commit 3e3656b
Show file tree
Hide file tree
Showing 12 changed files with 231 additions and 179 deletions.
2 changes: 1 addition & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
from ._term import (
AbstractTerm as AbstractTerm,
ControlTerm as ControlTerm,
LangevinTerm as LangevinTerm,
make_langevin_term as make_langevin_term,
MultiTerm as MultiTerm,
ODETerm as ODETerm,
WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm,
Expand Down
26 changes: 9 additions & 17 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
ConstantStepSize,
StepTo,
)
from ._term import AbstractTerm, LangevinTerm, MultiTerm, ODETerm, WrapTerm
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
from ._typing import better_isinstance, get_args_of, get_origin_no_specials


Expand Down Expand Up @@ -162,15 +162,12 @@ def _check(term_cls, term, term_contr_kwargs, yi):
pass
elif n_term_args == 2:
vf_type_expected, control_type_expected = term_args
if not isinstance(term, LangevinTerm):
# TODO: The line below causes problems with LangevinTerm
# Please help me fix this
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
vf_type_compatible = eqx.filter_eval_shape(
better_isinstance, vf_type, vf_type_expected
)
if not vf_type_compatible:
raise ValueError
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
vf_type_compatible = eqx.filter_eval_shape(
better_isinstance, vf_type, vf_type_expected
)
if not vf_type_compatible:
raise ValueError

contr = ft.partial(term.contr, **term_contr_kwargs)
# Work around https://github.com/google/jax/issues/21825
Expand Down Expand Up @@ -1008,10 +1005,6 @@ def _promote(yi):
y0 = jtu.tree_map(_promote, y0)
del timelikes

# Langevin terms must be unwrapped unless `term_structure=LangevinTerm`
if isinstance(terms, LangevinTerm) and solver.term_structure != LangevinTerm:
terms = terms.term

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
Expand Down Expand Up @@ -1070,14 +1063,13 @@ def _promote(yi):

def _wrap(term):
assert isinstance(term, AbstractTerm)
assert not isinstance(term, (MultiTerm, LangevinTerm))
assert not isinstance(term, MultiTerm)
return WrapTerm(term, direction)

terms = jtu.tree_map(
_wrap,
terms,
is_leaf=lambda x: isinstance(x, AbstractTerm)
and not isinstance(x, (MultiTerm, LangevinTerm)),
is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm),
)

if isinstance(solver, AbstractImplicitSolver):
Expand Down
22 changes: 13 additions & 9 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,13 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import _LangevinArgs, LangevinTerm, LangevinTuple, LangevinX
from .langevin_srk import AbstractCoeffs, AbstractLangevinSRK, SolverState
from .._term import _LangevinTuple, _LangevinX
from .langevin_srk import (
_LangevinArgs,
AbstractCoeffs,
AbstractLangevinSRK,
SolverState,
)


# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1,
Expand All @@ -36,7 +41,7 @@ def __init__(self, beta, a1, b1, aa, chh):
self.dtype = jnp.result_type(*all_leaves)


_ErrorEstimate = LangevinTuple
_ErrorEstimate = _LangevinTuple


class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
Expand All @@ -52,7 +57,6 @@ class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
$W$ is a Brownian motion.
"""

term_structure = LangevinTerm
interpolation_cls = LocalLinearInterpolation
taylor_threshold: RealScalarLike = eqx.field(static=True)
_coeffs_structure = jtu.tree_structure(
Expand Down Expand Up @@ -131,15 +135,15 @@ def _tay_coeffs_single(self, c: Array) -> _ALIGNCoeffs:
def _compute_step(
h: RealScalarLike,
levy: AbstractSpaceTimeLevyArea,
x0: LangevinX,
v0: LangevinX,
x0: _LangevinX,
v0: _LangevinX,
langevin_args: _LangevinArgs,
coeffs: _ALIGNCoeffs,
st: SolverState,
) -> tuple[LangevinX, LangevinX, LangevinX, LangevinTuple]:
) -> tuple[_LangevinX, _LangevinX, _LangevinX, _LangevinTuple]:
dtypes = jtu.tree_map(jnp.dtype, x0)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

gamma, u, f = langevin_args

Expand Down
12 changes: 7 additions & 5 deletions diffrax/_solver/euler_heun.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from collections.abc import Callable
from typing import ClassVar
from typing import Any, ClassVar
from typing_extensions import TypeAlias

from equinox.internal import ω

from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y
from .._local_interpolation import LocalLinearInterpolation
from .._solution import RESULTS
from .._term import AbstractTerm, MultiTerm, ODETerm
from .._term import AbstractTerm, MultiTerm
from .base import AbstractStratonovichSolver


Expand All @@ -26,7 +26,9 @@ class EulerHeun(AbstractStratonovichSolver):
Used to solve SDEs, and converges to the Stratonovich solution.
"""

term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]]
term_structure: ClassVar = MultiTerm[
tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]
]
interpolation_cls: ClassVar[
Callable[..., LocalLinearInterpolation]
] = LocalLinearInterpolation
Expand All @@ -39,7 +41,7 @@ def strong_order(self, terms):

def init(
self,
terms: MultiTerm[tuple[ODETerm, AbstractTerm]],
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
Expand All @@ -49,7 +51,7 @@ def init(

def step(
self,
terms: MultiTerm[tuple[ODETerm, AbstractTerm]],
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: Y,
Expand Down
84 changes: 56 additions & 28 deletions diffrax/_solver/langevin_srk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import abc
from typing import Generic, TypeVar
from typing import Any, Callable, Generic, TypeVar

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -19,11 +19,42 @@
)
from .._local_interpolation import LocalLinearInterpolation
from .._solution import RESULTS
from .._term import _LangevinArgs, LangevinTerm, LangevinTuple, LangevinX
from .._term import (
_LangevinDiffusionTerm,
_LangevinDriftTerm,
_LangevinTuple,
_LangevinX,
AbstractTerm,
MultiTerm,
WrapTerm,
)
from .base import AbstractItoSolver, AbstractStratonovichSolver


_ErrorEstimate = TypeVar("_ErrorEstimate", None, LangevinTuple)
_ErrorEstimate = TypeVar("_ErrorEstimate", None, _LangevinTuple)
_LangevinArgs = tuple[_LangevinX, _LangevinX, Callable[[_LangevinX], _LangevinX]]


def get_args_from_terms(
    terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
) -> _LangevinArgs:
    """Extract the Langevin arguments `(gamma, u, grad_f)` from a pair of terms.

    `terms.terms` is unpacked as `(drift, diffusion)`. If the drift term is a
    `WrapTerm`, the diffusion term must be one as well, and both are unwrapped
    via their `.term` attribute before being checked.

    **Arguments:**

    - `terms`: a `MultiTerm` whose `.terms` is a `(drift, diffusion)` pair, either
        both wrapped in `WrapTerm` or both bare.

    **Returns:**

    The tuple `(drift.gamma, drift.u, drift.grad_f)` read from the (unwrapped)
    drift term. The diffusion term is only type-checked.

    **Raises:**

    `AssertionError` if the (unwrapped) drift is not a `_LangevinDriftTerm`, if the
    (unwrapped) diffusion is not a `_LangevinDiffusionTerm`, or if only the drift
    is wrapped in a `WrapTerm`.
    """
    drift_term, diffusion_term = terms.terms

    # Peel off the `WrapTerm` layer; the two terms are expected to be wrapped
    # (or not wrapped) together.
    if isinstance(drift_term, WrapTerm):
        assert isinstance(diffusion_term, WrapTerm)
        drift_term, diffusion_term = drift_term.term, diffusion_term.term

    assert isinstance(drift_term, _LangevinDriftTerm)
    assert isinstance(diffusion_term, _LangevinDiffusionTerm)

    return drift_term.gamma, drift_term.u, drift_term.grad_f


class AbstractCoeffs(eqx.Module):
Expand All @@ -38,10 +69,10 @@ class AbstractCoeffs(eqx.Module):
# How should I work around this?
class SolverState(eqx.Module, Generic[_Coeffs]):
h: RealScalarLike
taylor_coeffs: PyTree[_Coeffs, "LangevinX"]
taylor_coeffs: PyTree[_Coeffs, "_LangevinX"]
coeffs: _Coeffs
rho: LangevinX
prev_f: LangevinX
rho: _LangevinX
prev_f: _LangevinX


# CONCERNING COEFFICIENTS:
Expand Down Expand Up @@ -77,7 +108,7 @@ class AbstractLangevinSRK(
QUIC_SORT.
"""

term_structure = LangevinTerm
term_structure = MultiTerm[tuple[_LangevinDriftTerm, _LangevinDiffusionTerm]]
interpolation_cls = LocalLinearInterpolation
taylor_threshold: RealScalarLike = eqx.field(static=True)
_coeffs_structure: eqx.AbstractClassVar[PyTreeDef]
Expand Down Expand Up @@ -119,7 +150,7 @@ def _eval_taylor(h, tay_coeffs: _Coeffs) -> _Coeffs:
)

def _recompute_coeffs(
self, h, gamma: LangevinX, tay_coeffs: PyTree[_Coeffs], state_h
self, h, gamma: _LangevinX, tay_coeffs: PyTree[_Coeffs], state_h
) -> _Coeffs:
def recompute_coeffs_leaf(c: ArrayLike, _tay_coeffs: _Coeffs):
# Used when the step-size h changes and coefficients need to be recomputed
Expand Down Expand Up @@ -167,30 +198,28 @@ def _choose(tay_leaf, direct_leaf):

def init(
self,
terms: LangevinTerm,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: LangevinTuple,
y0: _LangevinTuple,
args: PyTree,
) -> SolverState:
"""Precompute _SolverState which carries the Taylor coefficients and the
SRK coefficients (which can be computed from h and the Taylor coeffs).
Some solvers of this type are FSAL, so _SolverState also carries the previous
evaluation of grad_f.
"""
assert isinstance(terms, LangevinTerm)
gamma, u, f = terms.args # f is in fact grad(f)
drift, diffusion = terms.terms
gamma, u, f = get_args_from_terms(terms)

h = drift.contr(t0, t1)
x0, v0 = y0

def _check_shapes(_c, _u, _x, _v):
# assert _x.ndim in [0, 1]
assert _c.shape == _u.shape == _x.shape == _v.shape

assert jtu.tree_all(jtu.tree_map(_check_shapes, gamma, u, x0, v0))

h = t1 - t0

tay_coeffs = jtu.tree_map(self._comp_taylor_coeffs_leaf, gamma)
# tay_coeffs have the same tree structure as gamma, with each leaf being a
# _Coeffs and the arrays have an extra trailing dimension of 6
Expand All @@ -213,30 +242,30 @@ def _check_shapes(_c, _u, _x, _v):
def _compute_step(
h: RealScalarLike,
levy,
x0: LangevinX,
v0: LangevinX,
x0: _LangevinX,
v0: _LangevinX,
langevin_args: _LangevinArgs,
coeffs: _Coeffs,
st: SolverState,
) -> tuple[LangevinX, LangevinX, LangevinX, _ErrorEstimate]:
) -> tuple[_LangevinX, _LangevinX, _LangevinX, _ErrorEstimate]:
raise NotImplementedError

def step(
self,
terms: LangevinTerm,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: LangevinTuple,
y0: _LangevinTuple,
args: PyTree,
solver_state: SolverState,
made_jump: BoolScalarLike,
) -> tuple[LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]:
) -> tuple[_LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]:
del made_jump, args
st = solver_state
h = t1 - t0
assert isinstance(terms, LangevinTerm)
gamma, u, f = terms.args
drift, diffusion = terms.terms
gamma, u, f = get_args_from_terms(terms)

h = drift.contr(t0, t1)
h_state = st.h
tay: PyTree[_Coeffs] = st.taylor_coeffs
coeffs: _Coeffs = st.coeffs
Expand All @@ -250,7 +279,6 @@ def step(
coeffs,
)

drift, diffusion = terms.term.terms
# compute the Brownian increment and space-time Levy area
levy = diffusion.contr(t0, t1, use_levy=True)
assert isinstance(levy, self.minimal_levy_area), (
Expand All @@ -260,7 +288,7 @@ def step(

x0, v0 = y0
x_out, v_out, f_fsal, error = self._compute_step(
h, levy, x0, v0, terms.args, coeffs, st
h, levy, x0, v0, (gamma, u, f), coeffs, st
)

def check_shapes_dtypes(_x, _v, _f, _x0):
Expand Down Expand Up @@ -289,9 +317,9 @@ def check_shapes_dtypes(_x, _v, _f, _x0):

def func(
self,
terms: LangevinTerm,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
y0: LangevinTuple,
y0: _LangevinTuple,
args: PyTree,
):
return terms.vf(t0, y0, args)
Loading

0 comments on commit 3e3656b

Please sign in to comment.