Skip to content

Commit

Permalink
simplified logic for saving, no loop necessary
Browse files Browse the repository at this point in the history
  • Loading branch information
dkweiss31 authored and patrick-kidger committed Nov 29, 2024
1 parent 0e00411 commit 065fe11
Showing 1 changed file with 18 additions and 31 deletions.
49 changes: 18 additions & 31 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,50 +775,37 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_ts_impl(ts, fn, _save_state):
def _cond_fun(__save_state):
return __save_state.saveat_ts_index < len(_save_state.ts)

def _body_fun(__save_state):
idx = __save_state.save_index
ts = __save_state.ts.at[idx].set(t0)
def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
ys = jtu.tree_map(
lambda _y, _ys: _ys.at[idx].set(_y),
fn(t0, yfinal, args),
__save_state.ys,
lambda y: jnp.stack([y] * len(save_state.ts)),
subsaveat.fn(t0, yfinal, args),
)
return SaveState(
saveat_ts_index=idx + 1,
ts=ts,
save_state = SaveState(
saveat_ts_index=len(save_state.ts),
ts=save_state.ts,
ys=ys,
save_index=idx + 1,
save_index=len(save_state.ts),
)

return inner_while_loop(
_cond_fun,
_body_fun,
_save_state,
max_steps=len(ts),
buffers=_inner_buffers,
checkpoints=len(ts),
)

def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
return save_state

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
eqxi.unvmap_any(t0 == t1),
lambda __save_state: jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
__save_state,
),
lambda _save_state: _save_state,
lambda __save_state: __save_state,
final_state.save_state,
)

save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
Expand Down

0 comments on commit 065fe11

Please sign in to comment.