Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SRK #344

Merged
merged 1 commit into from
Apr 27, 2024
Merged

SRK #344

merged 1 commit into from
Apr 27, 2024

Conversation

andyElking
Copy link
Contributor

This PR adds all the shiny new SRK methods, while leaving out the Langevin-only methods (ALIGN and SORT).
A few things to keep in mind:

  • I wrote some comments in test/helpers.py, test_integrate.py, and _solvers/srk.py that are there for you to read and will be removed later. They are mostly questions about somewhat shaky parts of my code that I want to hear your opinions on.
  • I included langevin.ipynb in the examples folder. Because we are not merging in the LangevinTerm in _terms.py, I just copied the code for that into the notebook. If you want, I can remove this whole notebook until the next PR, but I do think it is very informative as to how the new stochastic solvers work. Furthermore, it gives you a showcase of what is to come later, which might make it clearer why I made some design decisions, as they make adding the Langevin things more seamless later on.
  • I added the new solvers to test_integrate.test_sde_strong_order. There I added a new "additive noise" mode in addition to the "commutative" and "non-commutative" noise options. Later on, once we add LangevinTerm, I also have a new set of tests called test_langevin.py prepared, which further test these methods (although they are primarily aimed and ALIGN and SORT, as those don't work with non-Langevin SDEs).
  • Apart from SRA1, I currently don't have any of Rossler's other solvers. Adding these will amount to adding just a few tableaus, which can easily be done later down the line (i.e. after you review the existing code). To be fair these aren't really my priority right now, but I will try to add them before we merge this whole PR.
  • I will probably be quite busy over the next two weeks, so I might not be able to make many edits at this point. I'm mostly making this PR now to give you ample time to read all the code 😊.

@andyElking andyElking force-pushed the srk_pr branch 4 times, most recently from 9872349 to b58343b Compare December 31, 2023 10:16
@patrick-kidger patrick-kidger deleted the branch patrick-kidger:dev January 8, 2024 22:27
@patrick-kidger patrick-kidger reopened this Jan 8, 2024
@patrick-kidger patrick-kidger changed the base branch from dev to main January 8, 2024 22:28
@patrick-kidger
Copy link
Owner

patrick-kidger commented Jan 8, 2024

Whoops, looks like GitHub auto-closed this due to merging the dev branch.

If you can rebase on top of main then I'd be happy to review this now. :) (Getting this in is my next priority.)

@andyElking
Copy link
Contributor Author

No worries, I rebased it :)

@lockwo
Copy link
Contributor

lockwo commented Jan 25, 2024

I think there is an error with the SRK notebook, I see these when I look at it on the branch or in the diff
Screenshot 2024-01-25 at 2 08 32 PM
Screenshot 2024-01-25 at 2 08 48 PM

@andyElking
Copy link
Contributor Author

andyElking commented Jan 25, 2024 via email

@lockwo
Copy link
Contributor

lockwo commented Jan 26, 2024

Awesome, thanks for that! Seems like a local comment also got left on ;)
image

Excited to see these methods make it to main!

@andyElking
Copy link
Contributor Author

andyElking commented Jan 26, 2024 via email

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, sorry for taking so long to get around to reviewing this; I've been on a long holiday. :)

I still need to review the mathematics of the SRK step itself, but I wanted to start by leaving some first-pass comments on everything else. (Also, because notebooks don't support commenting yet -- note that langevin.ipynb seems to import from the private diffrax._custom_types module.)

Overall, my impression is that this looks pretty good! It's also really highlighting how we really need a way to say something about the control in the term (see also #359), so that we can write something like

term_structure: MultiTerm[tuple[ODETerm, AbstractTerm[AbstractBrownianMotion[SpaceTimeLevyArea]]]]

and try to avoid the current heuristic Levy area checks happening in AbstractSRK.init.

No worries about doing this here, but if you'd find it interesting, then I'd welcome any thoughts on the cleanest way to do that. (Maybe in the future we find ourselves wanting to say something else about our controls as well, e.g. that they return terms from the log-signature of the path?)

diffrax/_solver/foster_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/foster_srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
docs/api/solvers/sde_solvers.md Outdated Show resolved Hide resolved
@andyElking
Copy link
Contributor Author

Thanks for these comments, I made the appropriate corrections.

Overall, my impression is that this looks pretty good! It's also really highlighting how we really need a way to say something about the control in the term (see also #359), so that we can write something like

term_structure: MultiTerm[tuple[ODETerm, AbstractTerm[AbstractBrownianMotion[SpaceTimeLevyArea]]]]

and try to avoid the current heuristic Levy area checks happening in AbstractSRK.init.

Yes, I will think about that. So I don't think the type of Levy area generated is important enough to make its way into the type annotation, and could be left at the level of an attribute.

Probably the least invasive way would be to generally just determine compatibility with the control based on the PyTree signature of it output (e.g. LevyVal), like I do in srk.py:

if sttla:
     assert bm_inc.K is not None

If the control doesn't have the right shape we throw a well-explained error. In either case it is up to the user to make sure the BM has the correct shape (to match the diffusion VF), so maybe Levy area doesn't need any more special treatment than that.

I'm not really sure what is "nice enough" though. You are certainly much more of an expert on how to write good code.

@andyElking
Copy link
Contributor Author

Also, regarding langevin.ipynb:
The reason I import private things is that I moved the LangevinTerm from _term.py into that notebook, just to give you some context on what is coming later.
I intend to remove the notebook from this PR before it gets merged, but I can remove it now already if you wish.

@andyElking andyElking force-pushed the srk_pr branch 2 times, most recently from 457a7f2 to 30dbeab Compare January 31, 2024 15:26
@andyElking andyElking force-pushed the srk_pr branch 2 times, most recently from 6f5903e to eecef67 Compare February 6, 2024 15:36
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, just done a review. I think the implementation looks really clean. I want to scrutinise the mathematics more carefully, but I think this now looks to be most of the way there!

(And yup, let's go ahead and remove langevin.ipynb for now.)

diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
diffrax/_solver/srk.py Outdated Show resolved Hide resolved
docs/api/solvers/abstract_solvers.md Show resolved Hide resolved
@patrick-kidger
Copy link
Owner

(let me know when you next want my input on this PR btw)

@andyElking
Copy link
Contributor Author

(let me know when you next want my input on this PR btw)

So I made the changes you suggested (except eqxi.scan) on the pr_correction branch. I also made some changes to test_sde_strong_order there, because they were quite sorely needed. I explained why in the code comments. But I'll probably merge all of that into this branch once I figure out the eqxi.scan, which might be on Monday or Tuesday. I'll let you know.

@andyElking
Copy link
Contributor Author

andyElking commented Feb 12, 2024

Hi Patrick, I made the edits you suggested.

In addition I also reformed convergence testing. This includes comparing the values of solutions at several times (not only at t1) as well as addressing the concern about Euler and Heun being way too imprecise to be used in establishing the order of high-order solvers. I explain more in this comment:
https://github.com/andyElking/diffrax_STLA/blob/c70e33bf0a18bd19ac7a3b0ed7074c6bafccd79b/test/test_integrate.py#L229-L233

@lockwo
Copy link
Contributor

lockwo commented Apr 2, 2024

Or is this related to Patricks PR which does "I realised that it's not sufficient to have term_compatible_contr_kwargs have a single thing on the solver -- we need something with essentially the same tree structure as term_structure, so that we can map the kwargs to each term. So I reworked this."

@lockwo
Copy link
Contributor

lockwo commented Apr 2, 2024

Can confirm, it runs off Patricks branch

@andyElking
Copy link
Contributor Author

Can confirm, it runs off Patricks branch

Hi @lockwo!

Yes indeed any descendant of AbstractSRK should work with any _AbstractControlTerm(AbstractTerm[_VF, _Control]) as long as term.control is a descendant of AbstractBrownianPath. And yes, the issue is just that we are trying to integrate SRKs with the new term structure checks from #364 and it seems I haven't done that correctly. Sorry about this.

I'll soon merge in Patrick's branch, so it won't stay broken for long.

@lockwo
Copy link
Contributor

lockwo commented Apr 2, 2024

Ok, sounds good. Maybe a weakly controlled srk could also be a unit test? (Whenever something doesn’t work my first impulse is to make it a test so I know it’s good from then on)

@andyElking
Copy link
Contributor Author

Ok, sounds good. Maybe a weakly controlled srk could also be a unit test? (Whenever something doesn’t work my first impulse is to make it a test so I know it’s good from then on)

The intention was to add that with the next PR, when Langevin terms are introduced (which use a different prod than ControlTerm). But I suppose I could add a quick test for this now.

@andyElking
Copy link
Contributor Author

@patrick-kidger I merged in your changes and added some fixes of my own on top of that. All of these additional fixes are explained in my comments on your PR.

@lockwo
Copy link
Contributor

lockwo commented Apr 2, 2024

Ok, sounds good. Maybe a weakly controlled srk could also be a unit test? (Whenever something doesn’t work my first impulse is to make it a test so I know it’s good from then on)

The intention was to add that with the next PR, when Langevin terms are introduced (which use a different prod than ControlTerm). But I suppose I could add a quick test for this now.

I defer to Patrick and you on this, I just encounter things and want to make sure they don't slip by on other PRs, I don't have a strong preference on which PR adds the tests.

@andyElking
Copy link
Contributor Author

andyElking commented Apr 2, 2024

@patrick-kidger I am very confused about the tests failing eariler. I reran on my laptop the tests that failed here, and everything passed. And the stack-trace is very confusing. Something about DenseInterpolation and times having shape 4097 instead of 4096. Do you know what could be the issue?

I looked into it further and I suspect it has to do with lines 960/961 of _integrate.py where clearly the length of dense_ts is 1 greater than the length of dense_infos, despite them having the same leading dimension called times (line 51 of _custom_types.py). Surprising this error hasn't been caught before though...

@andyElking
Copy link
Contributor Author

@lockwo I added a test for the SRKs which involves a weakly diagonal control term. Thanks again for bringing this up!

@andyElking
Copy link
Contributor Author

@patrick-kidger Another update: I found that when using an SRK wrapped in a HalfSolver, the previous error message still appears:

ValueError: `terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with structure diffrax._term.MultiTerm[tuple[diffrax._term.ODETerm, diffrax._term.AbstractTerm[typing.Any, diffrax._custom_types.AbstractSpaceTimeLevyArea]]]

I will try to investigate and will let you know if I find a fix.

@andyElking
Copy link
Contributor Author

andyElking commented Apr 10, 2024

I fixed the HalfSolver issue in commit ac6f3ce. I also added a quick test for this particular issue to test_sde.py. I made some other changes there to make all of the tests about 4x faster.

I don't know what to do to stop the very weird error one of the other tests keeps throwing. It seems entirely unrelated to my work. This first appeared when I merged your srk-tweaks branch (before I added edits of my own) and it kept appearing since. I think I know what is the cause, but it is in a part of Diffrax that I'd rather not meddle in.

@andyElking
Copy link
Contributor Author

I made another commit (7ff1029), which I think fixes the issue that kept appearing. I strongly think it wasn't caused by any of my changes, probably just some typechecking module got updated and caught this??

Anyway, this is just a quick temporary fix, let me know if you'd like something more principled.

@andyElking
Copy link
Contributor Author

More things are broken apparently. It seems to me that EqxRuntimeError no longer inherits from RuntimeError, causing issues with some exception catching. I did a very temporary patch, which I expect you will want to handle differently, but I just want the tests relevant to my PR to pass. Sorry, I really tried to figure out how these errors inherit from each other, but it all gets muddled at jaxlib.xla_extension.XlaRuntimeError and I have no clue how to fix this properly.

@patrick-kidger
Copy link
Owner

Awesome stuff! I'm glad things are passing. I'll try to look at what you've done this weekend.

FWIW I don't think EqxRuntimeError is ever surfaced to a user. I think what you're probably seeing here is an inconsistency I've observed in JAX before, in which runtime errors seem to be variably be surfaced as either an XlaRuntimeError or as a ValueError. Is that what you're seeing?

@patrick-kidger patrick-kidger mentioned this pull request Apr 10, 2024
@andyElking
Copy link
Contributor Author

For some reason it got floated to the user in this case. You can look at the log from the latest failed test run and search for:
equinox._errors.EqxRuntimeError: Must have `t0==ts[0]` and `t1==ts[-1]`.

In general I think something broke with the latest version of Equinox or JAX and now some tests in Diffrax are spuriously failing. Try to run tests on the main branch of Diffrax, I think they might fail, because these failures don't seem to be related to the changes we made.

@patrick-kidger
Copy link
Owner

patrick-kidger commented Apr 18, 2024

Try updating to the latest Equinox :) Something on the JAX side changed without how they output errors. Equinox has since been updated to work in both regimes.

@patrick-kidger patrick-kidger force-pushed the dev branch 2 times, most recently from 76a9441 to 34cbe5c Compare April 20, 2024 09:27
@patrick-kidger
Copy link
Owner

patrick-kidger commented Apr 21, 2024

Btw it looks like there are now many merge conflicts here. I think these are probably spurious, e.g. I can see that your git diff now seems to include the changes to _progress_bar.py which have happened on unrelated branches. (FWIW I think we are probably now at the point where I'd like to just hit the "merge" button! Best to get this in and if we need any small tweaks then we can do them in a subsequent PR, as I don't want to leave this one outstanding for any longer :) )

If you're able to squash all your changes into a single commit on top of dev that would be best. If you're not comfortable with git to that level then LMK and I can give it a try / try to guide you through it.

@andyElking
Copy link
Contributor Author

andyElking commented Apr 23, 2024

I squashed everything and removed the last commit (which was just something to do with handling of RuntimeError and should now no longer be needed). From what I've seen the diff now only includes bits relevant to this work. Thanks for bearing with me on this one Patrick!

@patrick-kidger patrick-kidger merged commit 55d3c0f into patrick-kidger:dev Apr 27, 2024
2 checks passed
@patrick-kidger
Copy link
Owner

And merged! 🎉🎉

This was a pretty gigantic PR, so thank you very much for all your effort to implement this. I am incredibly excited that we'll be able to introduce this in the next release of Diffrax.

Great work on getting all of this done!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants