Skip to content

Commit

Permalink
Save fix for t0==t1 (#494)
Browse files Browse the repository at this point in the history
* Langevin PR (#453)

* Langevin PR

* Minor fixes

* removed the SORT solver (superseded by QUICSORT)

* made LangevinTerm.term a static field

* temporary fix for _term_compatible and LangevinTerm

* Fixed LangevinTerm YAAAYYYYY

* Nits

* Added Langevin docs, a Langevin example and backwards in time test

* Fixed Patrick's comments

* langevin -> underdamped_langevin

* round of small fixes

* check langevin drift term and diffusion term have same args

* added scan_trick in QUICSORT and ShOULD

* using RuntimeError for when ULD args have wrong structure

* small fixes

* tidy-ups

* Split SDE tests in half, to try and avoid GitHub runner issues?

* Added effects_barrier to fix test issue with JAX 0.4.33+

* small fix of docs in all three and a return type in quicsort

* bump doc building pipeline

* Compatibility with JAX 0.4.36, which removes ConcreteArray

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* fix t1 out of bounds issue

* fix for steps: don't want to update those values if t0==t1 since we didn't take any steps.

Added test

---------

Co-authored-by: Andraž Jelinčič <[email protected]>
Co-authored-by: Patrick Kidger <[email protected]>
Co-authored-by: andyElking <[email protected]>
  • Loading branch information
4 people authored Dec 6, 2024
1 parent ba09fba commit dc68957
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
47 changes: 47 additions & 0 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,9 +775,56 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
out_size = 1 if subsaveat.t0 else 0
out_size += 1 if subsaveat.t1 and not subsaveat.steps else 0
out_size += len(subsaveat.ts)
ys = jtu.tree_map(
lambda y: jnp.stack([y] * out_size),
subsaveat.fn(t0, yfinal, args),
)
ts = jnp.full(out_size, t0)
if subsaveat.steps:
ysteps = jtu.tree_map(
lambda y: jnp.stack([y] * max_steps),
subsaveat.fn(t0, jnp.full_like(yfinal, jnp.inf), args),
)
ys = jtu.tree_map(
lambda _ys, _ysteps: jnp.concatenate([_ys, _ysteps], axis=0),
ys,
ysteps,
)
ts = jnp.concatenate((ts, jnp.full(max_steps, jnp.inf)))
save_state = SaveState(
saveat_ts_index=out_size,
ts=ts,
ys=ys,
save_index=out_size,
)
return save_state

save_state = jtu.tree_map(
_save_t1, saveat.subs, final_state.save_state, is_leaf=_is_subsaveat
)

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
eqxi.unvmap_any(t0 == t1),
lambda __save_state: jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
__save_state,
),
lambda __save_state: __save_state,
save_state,
)

final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
)
Expand Down
85 changes: 85 additions & 0 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,91 @@ def test_saveat_solution():
assert sol.result == diffrax.RESULTS.successful


@pytest.mark.parametrize("subs", [True, False])
def test_t0_eq_t1(subs):
y0 = jnp.array([2.0])
ts = jnp.linspace(1.0, 1.0, 3)
max_steps = 256
if subs:
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
get2 = diffrax.SubSaveAt(
t0=True,
ts=ts,
steps=True,
)
subs = (get0, get1, get2)
saveat = diffrax.SaveAt(subs=subs)
else:
saveat = diffrax.SaveAt(t0=True, t1=True, ts=ts)
term = diffrax.ODETerm(lambda t, y, args: y)
sol = diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
max_steps=max_steps,
)
if subs:
compare = jnp.full((len(ts) + 1, *y0.shape), y0)
compare_2 = jnp.concatenate(
(compare, jnp.full((max_steps, *y0.shape), jnp.inf))
)
assert tree_allclose(sol.ys[0], compare) # pyright: ignore
assert tree_allclose(sol.ys[1], compare) # pyright: ignore
assert tree_allclose(sol.ys[2], compare_2) # pyright: ignore
else:
compare = jnp.full((len(ts) + 2, *y0.shape), y0)
assert tree_allclose(sol.ys, compare)


@pytest.mark.parametrize("subs", [True, False])
def test_vmap_t0_eq_t1(subs):
ntsave = 4
y0 = jnp.array([2.0])
term = diffrax.ODETerm(lambda t, y, args: y)

def _solve(tf):
ts = jnp.linspace(0.0, tf, ntsave)
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
subs = (get0, get1)
saveat = diffrax.SaveAt(subs=subs)
return diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
)

compare = jnp.full((ntsave + 1, *y0.shape), y0)
sol = jax.vmap(_solve)(jnp.array([0.0, 1.0]))
assert tree_allclose(sol.ys[0][0], compare) # pyright: ignore
assert tree_allclose(sol.ys[1][0], compare) # pyright: ignore

regular_solve = _solve(1.0)
assert tree_allclose(sol.ys[0][1], regular_solve.ys[0]) # pyright: ignore
assert tree_allclose(sol.ys[1][1], regular_solve.ys[1]) # pyright: ignore


def test_trivial_dense():
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
y0 = jnp.array([2.1])
Expand Down

0 comments on commit dc68957

Please sign in to comment.