Skip to content

Commit

Permalink
added scan_trick in QUICSORT and ShOULD
Browse files Browse the repository at this point in the history
  • Loading branch information
andyElking committed Sep 1, 2024
1 parent 8e8e454 commit 2eafdbd
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 16 deletions.
24 changes: 18 additions & 6 deletions diffrax/_solver/quicsort.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from equinox.internal import scan_trick, ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
Expand Down Expand Up @@ -226,12 +226,24 @@ def _extract_coeffs(coeff, index):
v_tilde = (v0**ω + rho**ω * (hh**ω + 6 * kk**ω)).ω

x1 = (x0**ω + a_l**ω * v_tilde**ω + b_l**ω * rho_w_k**ω).ω
f1uh = (f(x1) ** ω * uh**ω).ω

x2 = (
x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1uh**ω
).ω
f2uh = (f(x2) ** ω * uh**ω).ω
# Use eqinox.internal.scan_trick to compute f1, x2 and f2 in one go
# carry = x, f1, f2. We use x0 as the initial value for f1 and f2
init = x1, x0, x0

def fn(carry):
x, _f, _ = carry
fx_uh = (f(x) ** ω * uh**ω).ω
return x, _f, fx_uh

def compute_x2(carry):
_, _, f1 = carry
x = (
x0**ω + a_r**ω * v_tilde**ω + b_r**ω * rho_w_k**ω - a_third**ω * f1**ω
).ω
return x, f1, f1

x2, f1uh, f2uh = scan_trick(fn, [compute_x2], init)

x_out = (
x0**ω
Expand Down
33 changes: 23 additions & 10 deletions diffrax/_solver/should.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from equinox.internal import ω
from equinox.internal import scan_trick, ω
from jaxtyping import ArrayLike, PyTree

from .._custom_types import (
Expand Down Expand Up @@ -203,6 +203,8 @@ def _compute_step(
hh: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes)
kk: UnderdampedLangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes)

chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω

gamma, u, f = uld_args

rho_w_k = (rho**ω * (w**ω - 12 * kk**ω)).ω
Expand All @@ -215,17 +217,28 @@ def _compute_step(
+ coeffs.a_half**ω * v1**ω
+ coeffs.b_half**ω * (-(uh**ω) * f0**ω + rho_w_k**ω)
).ω
f1 = f(x1)

chh_hh_plus_ckk_kk = (coeffs.chh**ω * hh**ω + coeffs.ckk**ω * kk**ω).ω
# Use equinox.internal.scan_trick to compute f1, x_out and f_out in one go
# carry = x, f1, f2. We use x0 as the initial value for f1 and f2
init = x1, x0, x0

def fn(carry):
x, _f, _ = carry
fx = f(x)
return x, _f, fx

def compute_x2(carry):
_, _, _f1 = carry
x = (
x0**ω
+ coeffs.a1**ω * v0**ω
- uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * _f1**ω)
+ rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω)
).ω
return x, _f1, _f1

x_out, f1, f_out = scan_trick(fn, [compute_x2], init)

x_out = (
x0**ω
+ coeffs.a1**ω * v0**ω
- uh**ω * coeffs.b1**ω * (1 / 3 * f0**ω + 2 / 3 * f1**ω)
+ rho**ω * (coeffs.b1**ω * w**ω + chh_hh_plus_ckk_kk**ω)
).ω
f_out = f(x_out)
v_out = (
coeffs.beta1**ω * v0**ω
- uh**ω
Expand Down

0 comments on commit 2eafdbd

Please sign in to comment.