From 065fe118e61262b92e5422cb5b05e9aae044168e Mon Sep 17 00:00:00 2001 From: Danny Date: Wed, 27 Nov 2024 13:28:52 -0500 Subject: [PATCH] simplified logic for saving, no loop necessary --- diffrax/_integrate.py | 49 ++++++++++++++++--------------------------- 1 file changed, 18 insertions(+), 31 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index d5cb90c1..255bf86d 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -775,50 +775,37 @@ def _save_t1(subsaveat, save_state): save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state) return save_state - def _save_ts_impl(ts, fn, _save_state): - def _cond_fun(__save_state): - return __save_state.saveat_ts_index < len(_save_state.ts) - - def _body_fun(__save_state): - idx = __save_state.save_index - ts = __save_state.ts.at[idx].set(t0) + def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: + if subsaveat.ts is not None: ys = jtu.tree_map( - lambda _y, _ys: _ys.at[idx].set(_y), - fn(t0, yfinal, args), - __save_state.ys, + lambda y: jnp.stack([y] * len(save_state.ts)), + subsaveat.fn(t0, yfinal, args), ) - return SaveState( - saveat_ts_index=idx + 1, - ts=ts, + save_state = SaveState( + saveat_ts_index=len(save_state.ts), + ts=save_state.ts, ys=ys, - save_index=idx + 1, + save_index=len(save_state.ts), ) - - return inner_while_loop( - _cond_fun, - _body_fun, - _save_state, - max_steps=len(ts), - buffers=_inner_buffers, - checkpoints=len(ts), - ) - - def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: - if subsaveat.ts is not None: - save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state) return save_state # if t0 == t1 then we don't enter the integration loop. In this case we have to # manually update the saved ts and ys if we want to save at "intermediate" # times specified by saveat.subs.ts save_state = jax.lax.cond( - t0 == t1, - lambda _save_state: jtu.tree_map( - _save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat + eqxi.unvmap_any(t0 == t1), + lambda __save_state: jax.lax.cond( + t0 == t1, + lambda _save_state: jtu.tree_map( + _save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat + ), + lambda _save_state: _save_state, + __save_state, ), - lambda _save_state: _save_state, + lambda __save_state: __save_state, final_state.save_state, ) + save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat) final_state = eqx.tree_at( lambda s: s.save_state, final_state, save_state, is_leaf=_is_none