Skip to content

Commit

Permalink
using while_loop, ran into issues with reverse-mode diff using the fo…
Browse files Browse the repository at this point in the history
…ri_loop
  • Loading branch information
dkweiss31 committed Aug 21, 2024
1 parent 3e3ceb8 commit fa009f0
Showing 1 changed file with 29 additions and 19 deletions.
48 changes: 29 additions & 19 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,34 @@ def _save_t1(subsaveat, save_state):
save_state = _save(tfinal, yfinal, args, subsaveat.fn, save_state)
return save_state

def _save_ts_impl(ts, fn, _save_state):
def _cond_fun(__save_state):
return __save_state.saveat_ts_index < len(ts)

def _body_fun(__save_state):
idx = __save_state.save_index
ts = __save_state.ts.at[idx].set(t0)
ys = jtu.tree_map(
lambda _y, _ys: _ys.at[idx].set(_y),
fn(t0, yfinal, args),
__save_state.ys,
)
return SaveState(
saveat_ts_index=idx + 1,
ts=ts,
ys=ys,
save_index=idx + 1,
)

return inner_while_loop(
_cond_fun,
_body_fun,
_save_state,
max_steps=len(ts),
buffers=_inner_buffers,
checkpoints=len(ts),
)

def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
ts = save_state.ts.at[0].set(t0)
Expand All @@ -781,25 +809,7 @@ def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
save_index=1,
)
if subsaveat.ts is not None:

def _body_fun(idx, _save_state):
ts = _save_state.ts.at[idx].set(t0)
ys = jtu.tree_map(
lambda _y, _ys: _ys.at[idx].set(_y),
subsaveat.fn(t0, yfinal, args),
_save_state.ys,
)
return SaveState(
saveat_ts_index=idx + 1,
ts=ts,
ys=ys,
save_index=idx + 1,
)

save_idx = save_state.save_index
save_state = jax.lax.fori_loop(
save_idx, save_idx + len(subsaveat.ts), _body_fun, save_state
)
save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
return save_state

# if t0 == t1 then we don't enter the integration loop. In this case we have to
Expand Down

0 comments on commit fa009f0

Please sign in to comment.