diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 6a8d8190..d5cb90c1 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -777,7 +777,7 @@ def _save_t1(subsaveat, save_state): def _save_ts_impl(ts, fn, _save_state): def _cond_fun(__save_state): - return __save_state.saveat_ts_index < len(ts) + return __save_state.saveat_ts_index < len(_save_state.ts) def _body_fun(__save_state): idx = __save_state.save_index @@ -803,32 +803,22 @@ def _body_fun(__save_state): checkpoints=len(ts), ) - def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: - if subsaveat.t0: - ts = save_state.ts.at[0].set(t0) - ys = save_state.ys.at[0].set(subsaveat.fn(t0, yfinal, args)) - save_state = SaveState( - saveat_ts_index=1, - ts=ts, - ys=ys, - save_index=1, - ) + def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.ts is not None: save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state) return save_state # if t0 == t1 then we don't enter the integration loop. In this case we have to - # manually update the saved ts and ys if we want to save at t0 or at "intermediate" + # manually update the saved ts and ys if we want to save at "intermediate" # times specified by saveat.subs.ts save_state = jax.lax.cond( t0 == t1, lambda _save_state: jtu.tree_map( - _save_t0_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat + _save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat ), lambda _save_state: _save_state, final_state.save_state, ) - save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat) final_state = eqx.tree_at( lambda s: s.save_state, final_state, save_state, is_leaf=_is_none