Skip to content

Commit

Permalink
added case for saving t0 data, which was also not getting updated.
Browse files Browse the repository at this point in the history
Added a test
  • Loading branch information
dkweiss31 authored and patrick-kidger committed Nov 29, 2024
1 parent 666948c commit 1bd4e08
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
25 changes: 19 additions & 6 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,14 +775,23 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
ts = save_state.ts.at[0].set(t0)
ys = save_state.ys.at[0].set(subsaveat.fn(t0, yfinal, args))
save_state = SaveState(
saveat_ts_index=1,
ts=ts,
ys=ys,
save_index=1,
)
if subsaveat.ts is not None:

def _body_fun(idx, _save_state):
ts = _save_state.ts.at[idx].set(t1)
ts = _save_state.ts.at[idx].set(t0)
ys = jtu.tree_map(
lambda _y, _ys: _ys.at[idx].set(_y),
subsaveat.fn(t1, yfinal, args),
subsaveat.fn(t0, yfinal, args),
_save_state.ys,
)
return SaveState(
Expand All @@ -792,15 +801,19 @@ def _body_fun(idx, _save_state):
save_index=idx + 1,
)

save_state = jax.lax.fori_loop(0, len(subsaveat.ts), _body_fun, save_state)
save_idx = save_state.save_index
save_state = jax.lax.fori_loop(
save_idx, save_idx + len(subsaveat.ts), _body_fun, save_state
)
return save_state

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved values if saveat.subs.ts is not None
# manually update the saved ts and ys if we want to save at t0 or at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
_save_t0_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
final_state.save_state,
Expand Down
36 changes: 36 additions & 0 deletions test/test_saveat_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,42 @@ def test_saveat_solution():
assert sol.result == diffrax.RESULTS.successful


@pytest.mark.parametrize("subs", [True, False])
def test_t0_eq_t1(subs):
y0 = jnp.array([2.0])
ts = jnp.linspace(1.0, 1.0, 3)
if subs:
get0 = diffrax.SubSaveAt(
ts=ts,
t1=True,
)
get1 = diffrax.SubSaveAt(
t0=True,
ts=ts,
)
subs = (get0, get1)
saveat = diffrax.SaveAt(subs=subs)
else:
saveat = diffrax.SaveAt(t0=True, t1=True, ts=ts)
term = diffrax.ODETerm(lambda t, y, args: y)
sol = diffrax.diffeqsolve(
term,
t0=ts[0],
t1=ts[-1],
y0=y0,
dt0=0.1,
solver=diffrax.Dopri5(),
saveat=saveat,
)
if subs:
compare = jnp.full((len(ts) + 1, *y0.shape), y0)
assert tree_allclose(sol.ys[0], compare) # pyright: ignore
assert tree_allclose(sol.ys[1], compare) # pyright: ignore
else:
compare = jnp.full((len(ts) + 2, *y0.shape), y0)
assert tree_allclose(sol.ys, compare)


def test_trivial_dense():
term = diffrax.ODETerm(lambda t, y, args: -0.5 * y)
y0 = jnp.array([2.1])
Expand Down

0 comments on commit 1bd4e08

Please sign in to comment.