Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Save fix for t0==t1 #494

Merged
merged 22 commits into from
Dec 6, 2024
Merged

Save fix for t0==t1 #494

merged 22 commits into from
Dec 6, 2024

Conversation

dkweiss31
Copy link
Contributor

Addresses edge case raised in #488 when t0 == t1 and saveat.ts is not None. Additionally if saveat.t0 is True then those values were not updated either, which should be addressed by this PR. I've additionally included a test for this case.

WRT the implementation: while a loop is not very nice since everything could in principle be done in parallel, the below did not work for the ts part due to dynamic slicing errors. Let me know if there is a nicer workaround I could try :)

if subsaveat.ts is not None:
    _ts = subsaveat.ts
    save_idx = save_state.save_index
    ts = save_state.ts.at[save_idx: save_idx + len(_ts)].set(_ts)
    _ys = [subsaveat.fn(t1, yfinal, args)] * len(_ts)
    ys = save_state.ys.at[save_idx: save_idx + len(_ts)].set(_ys)
    save_state = SaveState(
         saveat_ts_index=save_idx + len(_ts),
         ts=ts,
         ys=ys,
         save_index=save_idx + len(_ts),
     )

@dkweiss31
Copy link
Contributor Author

To address some failing tests re reverse mode differentiation I converted it to a while_loop, but I'm still seeing some failed tests. Converting this to a draft for now

@dkweiss31 dkweiss31 marked this pull request as draft August 21, 2024 19:18
andyElking and others added 6 commits September 1, 2024 21:29
* Langevin PR

* Minor fixes

* removed the SORT solver (superseded by QUICSORT)

* made LangevinTerm.term a static field

* temporary fix for _term_compatible and LangevinTerm

* Fixed LangevinTerm YAAAYYYYY

* Nits

* Added Langevin docs, a Langevin example and backwards in time test

* Fixed Patrick's comments

* langevin -> underdamped_langevin

* round of small fixes

* check langevin drift term and diffusion term have same args

* added scan_trick in QUICSORT and ShOULD

* using RuntimeError for when ULD args have wrong structure

* small fixes
@dkweiss31 dkweiss31 marked this pull request as ready for review November 13, 2024 13:15
@dkweiss31
Copy link
Contributor Author

@patrick-kidger sorry for the long delay! I think the PR is ready for review now. All tests pass except for one of the tqdm progress bar tests involving jit: I'm not at all sure what is going on there?

Additionally I wanted to draw your attention to the line I wrote on line 773:

def _save_ts_impl(ts, fn, _save_state):
    def _cond_fun(__save_state):
        return __save_state.saveat_ts_index < len(_save_state.ts)

where I had to use _save_state.ts instead of ts in the conditional check because saveat_ts_index can already be 1 if _save_state.t0==True. So if I used ts, then the last entry doesn't get updated. This doesn't mirror exactly what's happening on lines 421-427, so I just wanted to briefly mention it.

@dkweiss31 dkweiss31 changed the title Save fix for to==t1 Save fix for t0==t1 Nov 13, 2024
@patrick-kidger
Copy link
Owner

patrick-kidger commented Nov 17, 2024

Awesome! Can you rebase on top of dev (+make that the PR target branch) and I'll do a review? :) (I think this should also fix the tqdm-jit test.)

@dkweiss31 dkweiss31 changed the base branch from main to dev November 18, 2024 23:48
@dkweiss31
Copy link
Contributor Author

Ok, done I think!

Comment on lines 816 to 804
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this t0 == t1 branch can be made much more efficient: I imagine you should be able to do the tree-mapped equivalent of ys = jnp.where(ts == t0, y0, jnp.inf). Does something go wrong with this approach that I'm missing?

(Regardless of the above, do also note that you are doing a tree-map-of-a-loop. If you can, it is usually much more efficient to do a loop-of-a-tree-map.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the note about tree-map-of-a-loop vs. loop-of-a-tree-map! This was a gotcha I was unaware of.

I am probably missing something here but it seems to me at least naively that things could get screwed up based on the way data is saved in SaveState. In my test, we would have for saveat

SaveAt(
  subs=(
    SubSaveAt(t0=False, t1=True, ts=f32[3], steps=False, fn=<function save_y>),
    SubSaveAt(t0=True, t1=False, ts=f32[3], steps=False, fn=<function save_y>)
  ),
  dense=False,
  solver_state=False,
  controller_state=False,
  made_jump=False
)

and for final_state.save_state we have

(SaveState(saveat_ts_index=i32[], ts=f32[4], ys=f32[4,1], save_index=i32[]), SaveState(saveat_ts_index=i32[], ts=f32[4], ys=f32[4,1], save_index=i32[]))

The issue is that for final_state.save_state[0], we want to fill in ys and ts starting at index 0, whereas for final_state.save_state[1] we want to start filling in at index 1 since saveat.subs[1].t0==True. Its this dependence on the value of saveat_ts_index or save_index that made me implement this in a loop and makes it non-obvious to me how a tree_map would be able to handle that.

If you agree, then I will try to refactor this as a loop-of-a-tree-map vs. the way I have it now. Very possible I am missing something simple though :)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I think the only cases that affect this are t0 and t1? If we're saving neither of them then we should fill in ts == t0-many elements of the output. If we're saving just one of them then we want to save one more value. If we're saving both of them then we want to save two more values.

So untested but something like this:

mask = ts == t0
if t0 or t1:
    if t0 and t1:
        mask = lax.dynamic_update_slice_in_dim(mask, jnp.array([True, True]), jnp.argmin(mask), axis=0)
    else:
        mask = lax.dynamic_update_slice_in_dim(mask, jnp.array([True]), jnp.argmin(mask), axis=0)
ys = jnp.where(mask, y0, jnp.inf)

?

Comment on lines 814 to 807
save_state = jax.lax.cond(
t0 == t1,
lambda _save_state: jtu.tree_map(
_save_ts, saveat.subs, _save_state, is_leaf=_is_subsaveat
),
lambda _save_state: _save_state,
final_state.save_state,
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you wrap this whole cond in another cond, which has predicate eqxi.unvmap_any(t0 == t1)?

This is to avoid the vmap-of-cond-becomes-select issue: under vmap, then both branches will unconditionally trigger, which would really slow things down!

Comment on lines +150 to +151
@pytest.mark.parametrize("subs", [True, False])
def test_t0_eq_t1(subs):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test LGTM! Can you also add a variant in which the computation is vmapd? (So that e.g. one batch element has t0 == t1 and another one doesn't, or multiple batch elements have t0 == t1 but with different such values.)

This test will be sure that the eqxi.unvmap_any trick I mention above is handled correctly.

@dkweiss31
Copy link
Contributor Author

Hi @patrick-kidger ! I've tried again and as you suspected no loop is necessary. I went in a slightly different direction from what you suggested, let me know your thoughts :). I've also added the vmap test, lmk if this is what you had in mind or if it was something different

)

save_state = jtu.tree_map(_save_t1, saveat.subs, save_state, is_leaf=_is_subsaveat)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So this pretty much all LGTM! I can see that you're using the invariant t0 <= ts < t1. That's much neater than the unnecessary where check that I was suggesting!

I think my only concerns now are how this edge case interacts with a couple of other edge cases:

  1. if both t0 == t1 and the conditions of _save_t1 are triggered, then I think the latter will actually attempt to write out-of-bounds?
  2. if we have a boolean-returning event that triggers immediately (which is another reason to never enter the integration loop), then do we do the right thing? (Whether t0 == t1 or t0 != t1.)

I think (1) at least might be solved by putting the new _save_ts after _save_t1, so that we essentially just overwrite the latter if we hit this case -- WDYT?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Oh also you'll see a force-push on this branch -- I updated dev and this was what gave a readable diff for me to review here. I realised after-the-fact that this means you'll now need to force-pull your own local branch... sorry about that!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're absolutely right about point 1: fixed in most recent commit.

I'm having a hard time seeing what might go wrong with such a boolean event that immediately triggers. In this case, we should still just fill in y0 as we do now if subsaveat.ts==True and t0==t1. If t0!=t1 then the cond doesn't get triggered anyways and this discussion is moot. Do you agree?

I'm also wondering about what we should be doing if t0==t1 and subsaveat.steps==True. If it is, but subsaveat.ts is None, then we don't fill in y0 for any of the steps. This feels like the right behavior. On the other hand, if subsaveat.steps==True and subsaveat.ts is not None, then we fill in y0 for all values in save_state, including those for the "steps". Do you think this is the right behavior?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Also sorry for the long delay, been a crazy week!)

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

including those for the "steps". Do you think this is the right behavior?

Oh, good observation! You're right. Ever-so-technically we shouldn't fill it in for the steps, I think.

(Also sorry for the long delay, been a crazy week!)

Oh no worries at all, I sympathize entirely :D

I'm having a hard time seeing what might go wrong with such a boolean event that immediately triggers. In this case, we should still just fill in y0 as we do now if subsaveat.ts==True and t0==t1. If t0!=t1 then the cond doesn't get triggered anyways and this discussion is moot. Do you agree?

I think you're correct, it's just quite a tricky thing to thread the needle on the logic. I'll probably just add a test for this once we have this PR in :)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, updated to not update steps when t0==t1 and subsaveat.steps==True. Edge cases FTW!

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also added a test for this case

@patrick-kidger patrick-kidger merged commit dc68957 into patrick-kidger:dev Dec 6, 2024
@patrick-kidger
Copy link
Owner

Awesome! This LGTM. Thank you for so carefully catching the edge-cases and the edge-cases-of-edge-cases. I'll be doing a new release of Diffrax shortly, which will include this fix :)

patrick-kidger added a commit that referenced this pull request Dec 9, 2024
* Langevin PR (#453)

* Langevin PR

* Minor fixes

* removed the SORT solver (superseded by QUICSORT)

* made LangevinTerm.term a static field

* temporary fix for _term_compatible and LangevinTerm

* Fixed LangevinTerm YAAAYYYYY

* Nits

* Added Langevin docs, a Langevin example and backwards in time test

* Fixed Patrick's comments

* langevin -> underdamped_langevin

* round of small fixes

* check langevin drift term and diffusion term have same args

* added scan_trick in QUICSORT and ShOULD

* using RuntimeError for when ULD args have wrong structure

* small fixes

* tidy-ups

* Split SDE tests in half, to try and avoid GitHub runner issues?

* Added effects_barrier to fix test issue with JAX 0.4.33+

* small fix of docs in all three and a return type in quicsort

* bump doc building pipeline

* Compatibility with JAX 0.4.36, which removes ConcreteArray

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* using a fori_loop to save states in edge case t0==t1

* added case for saving t0 data, which was also not getting updated.

Added a test

* using while_loop, ran into issues with reverse-mode diff using the fori_loop

* bug fix for cases when t0=True

* simplified logic for saving, no loop necessary

* added vmap test

* fix t1 out of bounds issue

* fix for steps: don't want to update those values if t0==t1 since we didn't take any steps.

Added test

---------

Co-authored-by: Andraž Jelinčič <[email protected]>
Co-authored-by: Patrick Kidger <[email protected]>
Co-authored-by: andyElking <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants