Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make using JAX (or any accelerator) an Instanced Python class (and toggle) #509

Open
wants to merge 14 commits into
base: master
Choose a base branch
from

Conversation

Lnaden
Copy link
Contributor

@Lnaden Lnaden commented Jun 22, 2023

Supersedes #508

This PR overhauls how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself.

This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.

  • The mbar_solvers file has now been moved to its own module.
  • An abstract API class is created for both the MBAR-specific methods, and for the internal accelerator/solver approach methods which have to be done. Together these make an MBARSolver general class as was suggested in #508 by @jchodera
  • One of the methods is instanced as a "default" for the mbar_solvers import, so any attempt to import mbar_solvers or from mbar_solvers import X,Y,Z,* behave identically to how the main branch currently is. Thus the API is preserved and can keep the 4.y.z version scheme.
  • Default is JAX
  • I'm calling numpy an "accelerator" even though its the fall back.
  • I can confirm this implementation is just (post import) as fast as the one in Make using JAX (or any accelerator) a toggle #508 which should be comparable to main.

The complicated part of this is casting the actual functions (e.g. gradient, log W nk, precondition u_kn, etc.) have to be generated as static methods, not tied to the actual MBARSolver class. The way JIT for JAX works is it will serialize anything in the function, so any object (i.e. class and self in function) has every one of its methods/parameters serialized as well at compile time, resulting in a massive slowdown if I try to just leave each of the methods as class methods. I get around this with a number of generate_static... methods and replace the actual definitions in the __init__ while still preserving the doc strings and API. In case you're wondering details, see the MBARSolver.__init__ doc string
Edit: After testing, correct function definitions, and PyTree assignments, this is not a problem and implementation 99% as fast as static methods on average.

Given the drastically different way the mbar_solvers is loaded, and the massive implementation change, I would like to again formally request input from @mikemhenry and @mrshirts about this approach. @jchodera as you suggested the API idea in #508, I'd like to as for your feedback as well given the implementation was not as clean due to the JIT shenanigans.

@invemichele this is the full implementation of the outlined API and features in #496, so any input you have would also be appreciated.

Lnaden and others added 8 commits June 16, 2023 14:31
…s of work still, including how to delay JIT

Rebased from master
This PR overhuals how the accelerator logic is chosen, and gives that power to the MBAR instantiation process as well as if someone is just using the functional solver library itself.

This is a re-thinking on how to handle different libraries with identical functioning methods where we only want to select one or the other in Python. This also adds some future-proofing if other accelerators are wanted in the future.

* Importing is all handled through an `init_accelerator` method with matching name (i.e. `init_numpy` or `init_jax`).
* All items which need to be set in the `mbar_solvers` namespace are set through the `global` word of Python in the `init_X` method and therefore are cast up to the full `mbar_solvers` namespace.
* The `mbar_solvers` module now has state of the whole module and exists as ONE OR THE OTHER at any given time depending on when the last time the accelerator was set. I.e. You cannot have one MBAR object set as numpy and another set as JAX in the same code and expect them to operate with different libraries.
* Default is JAX
* I'm calling numpy an "accelerator" even though its the fall back.
Relies on creating classes with an exposed API. The problem is that JAX doesn't like acting on class methods so I am having to build around it.
…ecompile. Seems to run much slower in tests right now.
Casts methods to static methods every time to ensure that JAX is not serializing the class itself as constants, dramatically slowing down the code execution. Makes for slightly more complicated method call jumping, but otherwise uses the same code paths.
@codecov
Copy link

codecov bot commented Jun 22, 2023

Codecov Report

Merging #509 (00a9cd7) into master (cfe49fc) will increase coverage by 0.88%.
The diff coverage is 89.47%.

Lnaden and others added 3 commits June 22, 2023 16:03
…o go into source code to find this flag)

Fix lint complaining about jax import on no-jax systems by wrapping JAX and raising appropriately
…ls are weird, so added a "real_jit" property that can be set on implementation.
@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 23, 2023

Note: After some more testing, I think I was instancing the Solver classes everytime a new MBAR was created and that was what was slowing down my testing, that and I had a native Python sum function exposed instead of numpy/jax .sum which also didnt help.

I think I can remove all of the static method generator functions and go back to just clean class methods, and do the pytree registration for good measure just in case. It will clean up the code and maintain speed. Something for me to check and benchmark when I can.

…nerated methods is negligible and my earlier testing was the fact I was re-instancing the solver, and thus re-JIT'ing everything each time a new MBAR was called (fixed in earlier commit).

After testing, here are the results:

Testing the timing of test_protocols test using static-generated methods as a relative baseline:
The test is 99% as fast on average with PyTree registration.
The test is 95% as fast on average without the PyTree registration.

So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods.

Also updated the readme to reflect the new option.
@Lnaden
Copy link
Contributor Author

Lnaden commented Jun 26, 2023

After some testing, I found the speed gain from having pure static generated methods is negligible and my earlier testing was the fact I was re-instancing the solver, and thus re-JIT'ing everything each time a new MBAR was called (fixed in earlier commit).

After testing, here are the results:

Testing the timing of test_protocols test using static-generated methods as a relative baseline:
The test is 99% as fast on average with PyTree registration.
The test is 95% as fast on average without the PyTree registration.

So I've opted to use the JAX PyTree registration method and simplify the code substantially by moving all methods back into self methods.

So this version is more pythonic, easier to read, almost as fast as pure static methods, and overall implements the API listed in #496.

This is ready for review

The only outstanding question I have is a naming convention: Do we want to keep the name "accelerator" as I have in most places, or "solver" which I have in a few others, they are just accelerated by the different libraries. The only API concern here is the keyword accelerator=... I added to the MBAR object. Whatever we set, we wont want to change until 5.0, so I want to set it now.

They say the pre 1.9 behavior was a bug, hence no depreciation warning.
@@ -96,6 +96,7 @@ def __init__(
n_bootstraps=0,
bootstrap_solver_protocol=None,
rseed=None,
accelerator=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this name is good

"Either install with pip or conda:\n"
" pip install pybar[jax] \n"
" OR \n"
" conda install pymbar \n"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should add a -c conda-forge in here, I didn't want to use the web UI suggestion features since I wouldn't be able to maintain the formatting of this warning

@mikemhenry
Copy link
Contributor

Also you can either fix this in your PR, or merge this one in #510 to fix the RTD builds

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants