diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index dfcdff7e..6a8d8190 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -775,6 +775,34 @@ def _save_t1(subsaveat, save_state): save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state) return save_state + def _save_ts_impl(ts, fn, _save_state): + def _cond_fun(__save_state): + return __save_state.saveat_ts_index < len(ts) + + def _body_fun(__save_state): + idx = __save_state.save_index + ts = __save_state.ts.at[idx].set(t0) + ys = jtu.tree_map( + lambda _y, _ys: _ys.at[idx].set(_y), + fn(t0, yfinal, args), + __save_state.ys, + ) + return SaveState( + saveat_ts_index=idx + 1, + ts=ts, + ys=ys, + save_index=idx + 1, + ) + + return inner_while_loop( + _cond_fun, + _body_fun, + _save_state, + max_steps=len(ts), + buffers=_inner_buffers, + checkpoints=len(ts), + ) + def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: if subsaveat.t0: ts = save_state.ts.at[0].set(t0) @@ -786,25 +814,7 @@ def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: save_index=1, ) if subsaveat.ts is not None: - - def _body_fun(idx, _save_state): - ts = _save_state.ts.at[idx].set(t0) - ys = jtu.tree_map( - lambda _y, _ys: _ys.at[idx].set(_y), - subsaveat.fn(t0, yfinal, args), - _save_state.ys, - ) - return SaveState( - saveat_ts_index=idx + 1, - ts=ts, - ys=ys, - save_index=idx + 1, - ) - - save_idx = save_state.save_index - save_state = jax.lax.fori_loop( - save_idx, save_idx + len(subsaveat.ts), _body_fun, save_state - ) + save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state) return save_state # if t0 == t1 then we don't enter the integration loop. In this case we have to