Skip to content

Commit

Permalink
bug fix for cases when t0=True
Browse files Browse the repository at this point in the history
  • Loading branch information
dkweiss31 committed Nov 18, 2024
1 parent 439887c commit dc0dba4
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,7 +777,7 @@ def _save_t1(subsaveat, save_state):

def _save_ts_impl(ts, fn, _save_state):
def _cond_fun(__save_state):
return __save_state.saveat_ts_index < len(ts)
return __save_state.saveat_ts_index < len(_save_state.ts)

def _body_fun(__save_state):
idx = __save_state.save_index
Expand All @@ -803,32 +803,22 @@ def _body_fun(__save_state):
checkpoints=len(ts),
)

def _save_t0_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.t0:
ts = save_state.ts.at[0].set(t0)
ys = save_state.ys.at[0].set(subsaveat.fn(t0, yfinal, args))
save_state = SaveState(
saveat_ts_index=1,
ts=ts,
ys=ys,
save_index=1,
)
def _save_ts(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
if subsaveat.ts is not None:
save_state = _save_ts_impl(subsaveat.ts, subsaveat.fn, save_state)
return save_state

# if t0 == t1 then we don't enter the integration loop. In this case we have to
# manually update the saved ts and ys if we want to save at t0 or at "intermediate"
# manually update the saved ts and ys if we want to save at "intermediate"
# times specified by saveat.subs.ts
save_state = jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_t0_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
final_state.save_state,
)

save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
final_state = eqx.tree_at(
lambda s: s.save_state, final_state, save_state, is_leaf=_is_none
Expand Down

0 comments on commit dc0dba4

Please sign in to comment.