Skip to content

Commit

Permalink
Added Langevin docs, a Langevin example and backwards in time test
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Aug 19, 2024
1 parent 2b80e3b commit e38a932
Show file tree
Hide file tree
Showing 8 changed files with 242 additions and 43 deletions.
16 changes: 3 additions & 13 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
)


# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1,
# so we need 3 versions of each coefficient


# For an explanation of the coefficients, see langevin_srk.py
class _ALIGNCoeffs(AbstractCoeffs):
beta: PyTree[ArrayLike]
a1: PyTree[ArrayLike]
Expand All @@ -46,15 +43,8 @@ def __init__(self, beta, a1, b1, aa, chh):

class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]):
r"""The Adaptive Langevin via Interpolated Gradients and Noise method
designed by James Foster. Only works for Underdamped Langevin Diffusion
of the form
$$d x_t = v_t dt$$
$$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$
where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and
$W$ is a Brownian motion.
designed by James Foster.
Accepts only terms given by [`diffrax.make_langevin_term`][].
"""

interpolation_cls = LocalLinearInterpolation
Expand Down
10 changes: 6 additions & 4 deletions diffrax/_solver/langevin_srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def init(
args: PyTree,
) -> SolverState:
"""Precompute _SolverState which carries the Taylor coefficients and the
SRK coefficients (which can be computed from h and the Taylor coeffs).
SRK coefficients (which can be computed from h and the Taylor coefficients).
Some solvers of this type are FSAL, so _SolverState also carries the previous
evaluation of grad_f.
"""
Expand Down Expand Up @@ -259,16 +259,18 @@ def step(
gamma, u, f = get_args_from_terms(terms)

h = drift.contr(t0, t1)
h_state = st.h
h_prev = st.h
tay: PyTree[_Coeffs] = st.taylor_coeffs
coeffs: _Coeffs = st.coeffs

# If h changed recompute coefficients
cond = jnp.isclose(h_state, h, rtol=1e-10, atol=1e-12)
# Even when using constant step sizes, h can fluctuate by small amounts,
# so we use `jnp.isclose` for comparison
cond = jnp.isclose(h_prev, h, rtol=1e-10, atol=1e-12)
coeffs = lax.cond(
cond,
lambda x: x,
lambda _: self._recompute_coeffs(h, gamma, tay, h_state),
lambda _: self._recompute_coeffs(h, gamma, tay, h_prev),
coeffs,
)

Expand Down
17 changes: 4 additions & 13 deletions diffrax/_solver/quicsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,9 @@
)


# For an explanation of the coefficients, see langevin_srk.py
# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1,
# so we need 3 versions of each coefficient


class _QUICSORTCoeffs(AbstractCoeffs):
beta_lr1: PyTree[ArrayLike] # (gamma, 3, *taylor)
a_lr1: PyTree[ArrayLike] # (gamma, 3, *taylor)
Expand Down Expand Up @@ -66,14 +65,7 @@ class QUICSORT(AbstractLangevinSRK[_QUICSORTCoeffs, None]):
}
```
Works for underdamped Langevin SDEs of the form
$$d x_t = v_t dt$$
$$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$
where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and
$W$ is a Brownian motion.
Accepts only terms given by [`diffrax.make_langevin_term`][].
"""

interpolation_cls = LocalLinearInterpolation
Expand Down Expand Up @@ -246,9 +238,8 @@ def _one(coeff):
).ω
v_out = (v_out_tilde**ω - st.rho**ω * (hh**ω - 6 * kk**ω)).ω

f_fsal = (
st.prev_f
) # this method is not FSAL, but this is for compatibility with the base class
# this method is not FSAL, but for compatibility with the base class we set
f_fsal = st.prev_f

# TODO: compute error estimate
return x_out, v_out, f_fsal, None
16 changes: 3 additions & 13 deletions diffrax/_solver/should.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,7 @@
)


# UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1,
# so we need 3 versions of each coefficient


# For an explanation of the coefficients, see langevin_srk.py
class _ShOULDCoeffs(AbstractCoeffs):
beta_half: PyTree[ArrayLike]
a_half: PyTree[ArrayLike]
Expand Down Expand Up @@ -63,15 +60,8 @@ def __init__(self, beta_half, a_half, b_half, beta1, a1, b1, aa, chh, ckk):

class ShOULD(AbstractLangevinSRK[_ShOULDCoeffs, None]):
r"""The Shifted-ODE Runge-Kutta Three method
designed by James Foster. Only works for Underdamped Langevin Diffusion
of the form
$$d x_t = v_t dt$$
$$d v_t = - gamma v_t dt - u ∇f(x_t) dt + (2gammau)^(1/2) dW_t$$
where $v$ is the velocity, $f$ is the potential, $gamma$ is the friction, and
$W$ is a Brownian motion.
designed by James Foster.
Accepts only terms given by [`diffrax.make_langevin_term`][].
"""

interpolation_cls = LocalLinearInterpolation
Expand Down
37 changes: 37 additions & 0 deletions docs/api/solvers/sde_solvers.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,3 +113,40 @@ These are reversible in the same way as when applied to ODEs. [See here.](./ode_
selection:
members:
- __init__


---

### Underdamped Langevin solvers

These solvers are specifically designed for the Underdamped Langevin diffusion (ULD),
which takes the form

$d \mathbf{x}_t = \mathbf{v}_t dt$

$d \mathbf{v}_t = - \gamma \mathbf{v}_t dt - u
\nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$

where $\mathbf{x}_t, \mathbf{v}_t \in \mathbb{R}^d$ represent the position
and velocity, $W$ is a Brownian motion in $\mathbb{R}^d$,
$f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and
$\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing
the friction and the dampening of the system.

They are more precise for this diffusion than the general-purpose solvers above, but
cannot be used for any other SDEs. They only accept terms generated by the
[`diffrax.make_langevin_term`][] function. They all have the same `__init__` signature.
For an example of their usage, see the [Langevin example](../../examples/langevin_example.ipynb).

::: diffrax.ALIGN
selection:
members:
- __init__

::: diffrax.ShOULD
selection:
members: false

::: diffrax.QUICSORT
selection:
members: false
151 changes: 151 additions & 0 deletions examples/langevin_example.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ nav:
- Kalman filter: 'examples/kalman_filter.ipynb'
- Second-order sensitivities: 'examples/hessian.ipynb'
- Nonlinear heat PDE: 'examples/nonlinear_heat_pde.ipynb'
- Langevin diffusion: 'examples/langevin_example.ipynb'
- Basic API:
- 'api/diffeqsolve.md'
- Solvers:
Expand Down
37 changes: 37 additions & 0 deletions test/test_langevin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from .helpers import (
get_bqp,
get_harmonic_oscillator,
path_l2_dist,
SDE,
simple_batch_sde_solve,
simple_sde_order,
Expand Down Expand Up @@ -194,3 +195,39 @@ def get_dt_and_controller(level):
assert (
-0.2 < order - theoretical_order < 0.25
), f"order={order}, theoretical_order={theoretical_order}"


@pytest.mark.parametrize("solver_cls", _only_langevin_solvers_cls())
def test_reverse_solve(solver_cls):
t0, t1 = 0.7, -1.2
dt0 = -0.01
saveat = SaveAt(ts=jnp.linspace(t0, t1, 20, endpoint=True))

gamma = jnp.array([2, 0.5], dtype=jnp.float64)
u = jnp.array([0.5, 2], dtype=jnp.float64)
x0 = jnp.zeros((2,), dtype=jnp.float64)
v0 = jnp.zeros((2,), dtype=jnp.float64)
y0 = (x0, v0)

bm = diffrax.VirtualBrownianTree(
t1,
t0,
tol=0.005,
shape=(2,),
key=jr.key(0),
levy_area=diffrax.SpaceTimeTimeLevyArea,
)
terms = diffrax.make_langevin_term(gamma, u, lambda x: 2 * x, bm, x0)

solver = solver_cls(0.01)
sol = diffeqsolve(terms, solver, t0, t1, dt0=dt0, y0=y0, args=None, saveat=saveat)

ref_solver = diffrax.Heun()
ref_sol = diffeqsolve(
terms, ref_solver, t0, t1, dt0=dt0, y0=y0, args=None, saveat=saveat
)

# print(jtu.tree_map(lambda x: x.shape, sol.ys))
# print(jtu.tree_map(lambda x: x.shape, ref_sol.ys))
error = path_l2_dist(sol.ys, ref_sol.ys)
assert error < 0.1

0 comments on commit e38a932

Please sign in to comment.