Skip to content

Commit

Permalink
Nits
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Aug 17, 2024
1 parent 3e3656b commit 2b80e3b
Show file tree
Hide file tree
Showing 7 changed files with 82 additions and 93 deletions.
14 changes: 7 additions & 7 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import _LangevinTuple, _LangevinX
from .._term import LangevinTuple, LangevinX
from .langevin_srk import (
_LangevinArgs,
AbstractCoeffs,
Expand Down Expand Up @@ -41,7 +41,7 @@ def __init__(self, beta, a1, b1, aa, chh):
self.dtype = jnp.result_type(*all_leaves)


_ErrorEstimate = _LangevinTuple
_ErrorEstimate = LangevinTuple


class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
Expand Down Expand Up @@ -135,15 +135,15 @@ def _tay_coeffs_single(self, c: Array) -> _ALIGNCoeffs:
def _compute_step(
h: RealScalarLike,
levy: AbstractSpaceTimeLevyArea,
x0: _LangevinX,
v0: _LangevinX,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
coeffs: _ALIGNCoeffs,
st: SolverState,
) -> tuple[_LangevinX, _LangevinX, _LangevinX, _LangevinTuple]:
) -> tuple[LangevinX, LangevinX, LangevinX, LangevinTuple]:
dtypes = jtu.tree_map(jnp.dtype, x0)
w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

gamma, u, f = langevin_args

Expand Down
69 changes: 31 additions & 38 deletions diffrax/_solver/langevin_srk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import abc
from typing import Any, Callable, Generic, TypeVar
from collections.abc import Callable
from typing import Any, Generic, TypeVar

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -8,7 +9,7 @@
import jax.tree_util as jtu
from equinox import AbstractVar
from jax import vmap
from jax._src.tree_util import PyTreeDef
from jax.tree_util import PyTreeDef
from jaxtyping import Array, ArrayLike, PyTree

from .._custom_types import (
Expand All @@ -20,40 +21,35 @@
from .._local_interpolation import LocalLinearInterpolation
from .._solution import RESULTS
from .._term import (
_LangevinDiffusionTerm,
_LangevinDriftTerm,
_LangevinTuple,
_LangevinX,
AbstractTerm,
LangevinDiffusionTerm,
LangevinDriftTerm,
LangevinTuple,
LangevinX,
MultiTerm,
WrapTerm,
)
from .base import AbstractItoSolver, AbstractStratonovichSolver


_ErrorEstimate = TypeVar("_ErrorEstimate", None, _LangevinTuple)
_LangevinArgs = tuple[_LangevinX, _LangevinX, Callable[[_LangevinX], _LangevinX]]
_ErrorEstimate = TypeVar("_ErrorEstimate", None, LangevinTuple)
_LangevinArgs = tuple[LangevinX, LangevinX, Callable[[LangevinX], LangevinX]]


def get_args_from_terms(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
) -> _LangevinArgs:
drift, diffusion = terms.terms
if isinstance(drift, WrapTerm):
unwrapped_drift = drift.term
assert isinstance(diffusion, WrapTerm)
unwrapped_diffusion = diffusion.term
assert isinstance(unwrapped_drift, _LangevinDriftTerm)
assert isinstance(unwrapped_diffusion, _LangevinDiffusionTerm)
gamma = unwrapped_drift.gamma
u = unwrapped_drift.u
f = unwrapped_drift.grad_f
else:
assert isinstance(drift, _LangevinDriftTerm)
assert isinstance(diffusion, _LangevinDiffusionTerm)
gamma = drift.gamma
u = drift.u
f = drift.grad_f
drift = drift.term
diffusion = diffusion.term

assert isinstance(drift, LangevinDriftTerm)
assert isinstance(diffusion, LangevinDiffusionTerm)
gamma = drift.gamma
u = drift.u
f = drift.grad_f
return gamma, u, f


Expand All @@ -64,15 +60,12 @@ class AbstractCoeffs(eqx.Module):
_Coeffs = TypeVar("_Coeffs", bound=AbstractCoeffs)


# TODO: I'm not sure if I can use the _Coeffs type here,
# given that I do not use Generic[_Coeffs] in the class definition.
# How should I work around this?
class SolverState(eqx.Module, Generic[_Coeffs]):
h: RealScalarLike
taylor_coeffs: PyTree[_Coeffs, "_LangevinX"]
taylor_coeffs: PyTree[_Coeffs, "LangevinX"]
coeffs: _Coeffs
rho: _LangevinX
prev_f: _LangevinX
rho: LangevinX
prev_f: LangevinX


# CONCERNING COEFFICIENTS:
Expand Down Expand Up @@ -104,11 +97,11 @@ class AbstractLangevinSRK(
where $v$ is the velocity, $f$ is the potential, $\gamma$ and $u$ are the
friction and momentum parameters, and $W$ is a Brownian motion.
Solvers which inherit from this class include ALIGN, SORT, ShOULD, and
QUIC_SORT.
Solvers which inherit from this class include [`diffrax.ALIGN`][],
[`diffrax.ShOULD`][], and [`diffrax.QUIC_SORT`][].
"""

term_structure = MultiTerm[tuple[_LangevinDriftTerm, _LangevinDiffusionTerm]]
term_structure = MultiTerm[tuple[LangevinDriftTerm, LangevinDiffusionTerm]]
interpolation_cls = LocalLinearInterpolation
taylor_threshold: RealScalarLike = eqx.field(static=True)
_coeffs_structure: eqx.AbstractClassVar[PyTreeDef]
Expand Down Expand Up @@ -150,7 +143,7 @@ def _eval_taylor(h, tay_coeffs: _Coeffs) -> _Coeffs:
)

def _recompute_coeffs(
self, h, gamma: _LangevinX, tay_coeffs: PyTree[_Coeffs], state_h
self, h, gamma: LangevinX, tay_coeffs: PyTree[_Coeffs], state_h
) -> _Coeffs:
def recompute_coeffs_leaf(c: ArrayLike, _tay_coeffs: _Coeffs):
# Used when the step-size h changes and coefficients need to be recomputed
Expand Down Expand Up @@ -201,7 +194,7 @@ def init(
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: _LangevinTuple,
y0: LangevinTuple,
args: PyTree,
) -> SolverState:
"""Precompute the `SolverState`, which carries the Taylor coefficients and the
Expand Down Expand Up @@ -242,24 +235,24 @@ def _check_shapes(_c, _u, _x, _v):
def _compute_step(
h: RealScalarLike,
levy,
x0: _LangevinX,
v0: _LangevinX,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
coeffs: _Coeffs,
st: SolverState,
) -> tuple[_LangevinX, _LangevinX, _LangevinX, _ErrorEstimate]:
) -> tuple[LangevinX, LangevinX, LangevinX, _ErrorEstimate]:
raise NotImplementedError

def step(
self,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
t1: RealScalarLike,
y0: _LangevinTuple,
y0: LangevinTuple,
args: PyTree,
solver_state: SolverState,
made_jump: BoolScalarLike,
) -> tuple[_LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]:
) -> tuple[LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]:
del made_jump, args
st = solver_state
drift, diffusion = terms.terms
Expand Down Expand Up @@ -319,7 +312,7 @@ def func(
self,
terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]],
t0: RealScalarLike,
y0: _LangevinTuple,
y0: LangevinTuple,
args: PyTree,
):
return terms.vf(t0, y0, args)
14 changes: 7 additions & 7 deletions diffrax/_solver/quicsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import _LangevinX
from .._term import LangevinX
from .langevin_srk import (
_LangevinArgs,
AbstractCoeffs,
Expand Down Expand Up @@ -186,16 +186,16 @@ def _tay_coeffs_single(self, c: Array) -> _QUICSORTCoeffs:
def _compute_step(
h: RealScalarLike,
levy: AbstractSpaceTimeTimeLevyArea,
x0: _LangevinX,
v0: _LangevinX,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
coeffs: _QUICSORTCoeffs,
st: SolverState,
) -> tuple[_LangevinX, _LangevinX, _LangevinX, None]:
) -> tuple[LangevinX, LangevinX, LangevinX, None]:
dtypes = jtu.tree_map(jnp.dtype, x0)
w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: _LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

gamma, u, f = langevin_args

Expand Down
14 changes: 7 additions & 7 deletions diffrax/_solver/should.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
RealScalarLike,
)
from .._local_interpolation import LocalLinearInterpolation
from .._term import _LangevinX
from .._term import LangevinX
from .langevin_srk import (
_LangevinArgs,
AbstractCoeffs,
Expand Down Expand Up @@ -184,16 +184,16 @@ def _tay_coeffs_single(self, c: Array) -> _ShOULDCoeffs:
def _compute_step(
h: RealScalarLike,
levy: AbstractSpaceTimeTimeLevyArea,
x0: _LangevinX,
v0: _LangevinX,
x0: LangevinX,
v0: LangevinX,
langevin_args: _LangevinArgs,
coeffs: _ShOULDCoeffs,
st: SolverState,
) -> tuple[_LangevinX, _LangevinX, _LangevinX, None]:
) -> tuple[LangevinX, LangevinX, LangevinX, None]:
dtypes = jtu.tree_map(jnp.dtype, x0)
w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: _LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)
w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

gamma, u, f = langevin_args

Expand Down
49 changes: 26 additions & 23 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,12 +791,13 @@ def _to_vjp(_y, _diff_args, _diff_term):
# `x` and the velocity `v`. Both of these have the same shape. So, by LangevinX we
# denote the shape of the x component, and by LangevinTuple we denote the shape of
# the tuple (x, v).
_LangevinX = PyTree[Shaped[Array, "?*langevin"], "LangevinX"]
_LangevinTuple = tuple[_LangevinX, _LangevinX]
_LangevinBM = TypeVar("_LangevinBM", bound=Union[_LangevinX, AbstractBrownianIncrement])
LangevinX = PyTree[Shaped[Array, "?*langevin"], "LangevinX"]
LangevinTuple = tuple[LangevinX, LangevinX]


class _LangevinDiffusionTerm(AbstractTerm[_LangevinX, _LangevinBM]):
class LangevinDiffusionTerm(
AbstractTerm[LangevinX, Union[LangevinX, AbstractBrownianIncrement]]
):
r"""Represents the diffusion term in the Langevin SDE:
$d \mathbf{x}_t = \mathbf{v}_t dt$
Expand All @@ -805,11 +806,11 @@ class _LangevinDiffusionTerm(AbstractTerm[_LangevinX, _LangevinBM]):
\nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$
"""

gamma: _LangevinX
u: _LangevinX
gamma: LangevinX
u: LangevinX
control: AbstractBrownianPath

def vf(self, t: RealScalarLike, y: _LangevinTuple, args: Args) -> _LangevinX:
def vf(self, t: RealScalarLike, y: LangevinTuple, args: Args) -> LangevinX:
x, v = y

def _fun(_gamma, _u, _v):
Expand All @@ -818,13 +819,12 @@ def _fun(_gamma, _u, _v):
d_v = jtu.tree_map(_fun, self.gamma, self.u, v)
return d_v

def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _LangevinBM:
def contr(
self, t0: RealScalarLike, t1: RealScalarLike, **kwargs
) -> Union[LangevinX, AbstractBrownianIncrement]:
return self.control.evaluate(t0, t1, **kwargs)

def prod(self, vf: _LangevinX, control: _LangevinX) -> _LangevinTuple:
if isinstance(control, AbstractBrownianIncrement):
control = control.W

def prod(self, vf: LangevinX, control: LangevinX) -> LangevinTuple:
dv = vf
dw = control
v_out = jtu.tree_map(operator.mul, dv, dw)
Expand All @@ -843,12 +843,12 @@ def _inner_broadcast(_src_arr, _inner_tree_shape):
return jtu.tree_map(_inner_broadcast, source, target_tree_shape)


class _LangevinDriftTerm(AbstractTerm):
gamma: _LangevinX
u: _LangevinX
grad_f: Callable[[_LangevinX], _LangevinX]
class LangevinDriftTerm(AbstractTerm):
gamma: LangevinX
u: LangevinX
grad_f: Callable[[LangevinX], LangevinX]

def vf(self, t: RealScalarLike, y: _LangevinTuple, args: Args) -> _LangevinTuple:
def vf(self, t: RealScalarLike, y: LangevinTuple, args: Args) -> LangevinTuple:
x, v = y
f_x = self.grad_f(x)
d_x = v
Expand All @@ -865,7 +865,7 @@ def vf(self, t: RealScalarLike, y: _LangevinTuple, args: Args) -> _LangevinTuple
def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> RealScalarLike:
return t1 - t0

def prod(self, vf: _LangevinTuple, control: RealScalarLike) -> _LangevinTuple:
def prod(self, vf: LangevinTuple, control: RealScalarLike) -> LangevinTuple:
return jtu.tree_map(lambda _vf: control * _vf, vf)


Expand All @@ -874,16 +874,19 @@ def make_langevin_term(
u: PyTree[ArrayLike],
grad_f: Callable,
bm: AbstractBrownianPath,
x0: _LangevinX,
) -> MultiTerm[tuple[_LangevinDriftTerm, _LangevinDiffusionTerm]]:
x0: LangevinX,
) -> MultiTerm[tuple[LangevinDriftTerm, LangevinDiffusionTerm]]:
r"""Creates a term that represents the Underdamped Langevin Diffusion, given by:
$d \mathbf{x}_t = \mathbf{v}_t dt$
$d \mathbf{v}_t = - \gamma \mathbf{v}_t dt - u
\nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$
where $\mathbf{x}_t, \mathbf{v}_t \in \mathbb{R}^d$ represent the position
and velocity, $W$ is a Brownian motion in $\mathbb{R}^d$,
$f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and
$ \gamma,u\in\mathbb{R}^{d\times d}$ are diagonal matrices governing
$\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing
the friction and the damping of the system.
**Arguments:**
Expand Down Expand Up @@ -915,6 +918,6 @@ def _shape_check_fun(_x, _g, _u, _fx):
jtu.tree_map(_shape_check_fun, x0, gamma, u, grad_f_shape)
), "The shapes of gamma, u, and grad_f(x0) must be the same as x0."

drift = _LangevinDriftTerm(gamma, u, grad_f)
diffusion = _LangevinDiffusionTerm(gamma, u, bm)
drift = LangevinDriftTerm(gamma, u, grad_f)
diffusion = LangevinDiffusionTerm(gamma, u, bm)
return MultiTerm(drift, diffusion)
2 changes: 2 additions & 0 deletions docs/api/terms.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,5 @@ Some example term structures include:
selection:
members:
- __init__

::: diffrax.make_langevin_term
Loading

0 comments on commit 2b80e3b

Please sign in to comment.