Skip to content

Commit

Permalink
added scan_trick in QUICSORT and ShOULD
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Sep 1, 2024
1 parent 8e8e454 commit aa5cbd0
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 23 deletions.
4 changes: 2 additions & 2 deletions diffrax/_solver/align.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _compute_step(
levy: AbstractSpaceTimeLevyArea,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: UnderdampedLangevinArgs,
underdamped_langevin_args: UnderdampedLangevinArgs,
coeffs: _ALIGNCoeffs,
rho: UnderdampedLangevinX,
prev_f: UnderdampedLangevinX,
Expand All @@ -163,7 +163,7 @@ def _compute_step(
w: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes)
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)

gamma, u, f = uld_args
gamma, u, f = underdamped_langevin_args

uh = (u**ω * h).ω
f0 = prev_f
Expand Down
2 changes: 1 addition & 1 deletion diffrax/_solver/foster_langevin_srk.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ def _compute_step(
levy,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: UnderdampedLangevinArgs,
underdamped_langevin_args: UnderdampedLangevinArgs,
coeffs: _Coeffs,
rho: UnderdampedLangevinX,
prev_f: Optional[UnderdampedLangevinX],
Expand Down
28 changes: 20 additions & 8 deletions diffrax/_solver/quicsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from equinox.internal import scan_trick, ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
Expand Down Expand Up @@ -193,7 +193,7 @@ def _compute_step(
levy: AbstractSpaceTimeTimeLevyArea,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: UnderdampedLangevinArgs,
underdamped_langevin_args: UnderdampedLangevinArgs,
coeffs: _QUICSORTCoeffs,
rho: UnderdampedLangevinX,
prev_f: Optional[UnderdampedLangevinX],
Expand All @@ -204,7 +204,7 @@ def _compute_step(
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

gamma, u, f = uld_args
gamma, u, f = underdamped_langevin_args

def _extract_coeffs(coeff, index):
return jtu.tree_map(lambda arr: arr[..., index], coeff)
Expand All @@ -226,12 +226,24 @@ def _extract_coeffs(coeff, index):
v_tilde = (v0**ω + rho**ω * (hh**ω + 6 * kk**ω)).ω

x1 = (x0**ω + a_l**ω * v_tilde**ω + b_l**ω * rho_w_k**ω).ω
f1uh = (f(x1) ** ω * uh**ω).ω

x2 = (
x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1uh**ω
).ω
f2uh = (f(x2) ** ω * uh**ω).ω
# Use eqinox.internal.scan_trick to compute f1, x2 and f2 in one go
# carry = x, f1, f2. We use x0 as the initial value for f1 and f2
init = x1, x0, x0

def fn(carry):
x, _f, _ = carry
fx_uh = (f(x) ** ω * uh**ω).ω
return x, _f, fx_uh

def compute_x2(carry):
_, _, f1 = carry
x = (
x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1**ω
).ω
return x, f1, f1

x2, f1uh, f2uh = scan_trick(fn, [compute_x2], init)

x_out = (
x0**ω
Expand Down
37 changes: 25 additions & 12 deletions diffrax/_solver/should.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from equinox.internal import scan_trick, ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
Expand Down Expand Up @@ -193,7 +193,7 @@ def _compute_step(
levy: AbstractSpaceTimeTimeLevyArea,
x0: UnderdampedLangevinX,
v0: UnderdampedLangevinX,
uld_args: UnderdampedLangevinArgs,
underdamped_langevin_args: UnderdampedLangevinArgs,
coeffs: _ShOULDCoeffs,
rho: UnderdampedLangevinX,
prev_f: UnderdampedLangevinX,
Expand All @@ -203,7 +203,9 @@ def _compute_step(
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

gamma, u, f = uld_args
chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω

gamma, u, f = underdamped_langevin_args

rho_w_k = (rho**ω * (w**ω - 12 * kk**ω)).ω
uh = (u**ω * h).ω
Expand All @@ -215,17 +217,28 @@ def _compute_step(
+ coeffs.a_half**ω * v1**ω
+ coeffs.b_half**ω * (-(uh**ω) * f0**ω + rho_w_k**ω)
).ω
f1 = f(x1)

chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω
# Use equinox.internal.scan_trick to compute f1, x_out and f_out in one go
# carry = x, f1, f2. We use x0 as the initial value for f1 and f2
init = x1, x0, x0

def fn(carry):
x, _f, _ = carry
fx = f(x)
return x, _f, fx

def compute_x2(carry):
_, _, _f1 = carry
x = (
x0**ω
+ coeffs.a1**ω * v0**ω
- uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * _f1**ω)
+ rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω)
).ω
return x, _f1, _f1

x_out, f1, f_out = scan_trick(fn, [compute_x2], init)

x_out = (
x0**ω
+ coeffs.a1**ω * v0**ω
- uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * f1**ω)
+ rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω)
).ω
f_out = f(x_out)
v_out = (
coeffs.beta1**ω * v0**ω
- uh**ω
Expand Down

0 comments on commit aa5cbd0

Please sign in to comment.