Accelerate the code, starting with the inner loops #340

Open
mrshirts opened this issue Feb 16, 2020 · 27 comments

@mrshirts
Collaborator

mrshirts commented Feb 16, 2020

We need to accelerate the inner loops.

I think the most important things to do would be to accelerate in this order

  • logsumexp.

This is the largest cost.

After that is accelerated, it's probably worth it to look at:

  • mbar_gradient - used the most in many algorithms, just a few lines wrapped around logsumexp
  • self_consistent_iteration - used a lot, also just a few lines.
  • mbar_hessian - not many lines either, though also calls some other routines such as mbar_W_nk that are themselves relatively short wrappers around logsumexp.
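
For reference, a weighted logsumexp computes log(sum_k b_k exp(a_k)) along one axis, subtracting the maximum first so that no exp overflows. The NumPy sketch below assumes scipy.special.logsumexp-like semantics for the b= and axis= arguments used in the excerpts in this thread; it is an illustration, not pymbar's own implementation.

```python
import numpy as np


def logsumexp(a, b=None, axis=None):
    """Numerically stable log(sum(b * exp(a))) along `axis` (illustrative sketch)."""
    a = np.asarray(a, dtype=np.float64)
    a_max = np.max(a, axis=axis, keepdims=True)
    # Guard against an all -inf slice, where a - a_max would be nan.
    a_max = np.where(np.isfinite(a_max), a_max, 0.0)
    tmp = np.exp(a - a_max)  # every term is now <= 1, so nothing overflows
    if b is not None:
        tmp = tmp * b  # weights (e.g. N_k) broadcast against a
    out = np.log(np.sum(tmp, axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis) if axis is not None else out.item()
```

With b=N_k and axis=1 on an (N, K) array such as f_k - u_kn.T, this returns one value per sample, which is how the denominator term is formed below.
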
@jaimergp
Member

I'd strongly suggest running python -m cProfile -o profile.dat whatever_pymbar_script.py to profile all the calls and then visualizing the output with snakeviz.

It will produce an interactive visualization of the call tree, which I found extremely informative for finding the bottlenecks.
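The same profile can also be collected from inside Python and summarized as text. A minimal sketch using the standard-library cProfile and pstats modules; workload() is a hypothetical stand-in for a pymbar script:

```python
import cProfile
import io
import os
import pstats
import tempfile

import numpy as np


def workload():
    # Hypothetical stand-in for a pymbar run: repeated stable log-sum-exp over rows.
    x = np.random.default_rng(0).normal(size=(500, 2000))
    for _ in range(3):
        x_max = x.max(axis=1, keepdims=True)
        lse = np.log(np.exp(x - x_max).sum(axis=1)) + x_max[:, 0]
    return lse


profiler = cProfile.Profile()
profiler.enable()
workload()
profiler.disable()

# Same file format as `python -m cProfile -o profile.dat script.py`, so snakeviz can open it.
profile_path = os.path.join(tempfile.gettempdir(), "profile.dat")
profiler.dump_stats(profile_path)

# Plain-text summary of the 10 most expensive entries by tottime
# (time spent in the function itself, excluding the functions it calls).
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("tottime").print_stats(10)
report = stream.getvalue()
print(report)
```

The saved file can then be opened interactively with snakeviz followed by the file path.
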

@mrshirts
Collaborator Author

mrshirts commented Feb 17, 2020

Ah, I had been using cProfile, but was poring over the text: I didn't know about snakeviz.

Here are three sample profiles of a set of harmonic oscillators, using the adaptive method; all of them converged in 3 self-consistent and 3 Newton-Raphson iterations.

100 states, 25000 samples per state: profile.100x25000.txt
1000 states, 300 samples per state: profile.1000x300.txt
500 states, 1000 samples per state: profile.500x1000.txt

I'm not posting the graphs here, since you need to do a little interactive work to see what is going on, but you can visualize them easily with snakeviz.

All of these took about 3-4 minutes. Note that anything bigger (the size scales roughly as (nsamples/state)*(n_states)^2) starts to cause memory issues and runs much more slowly because of page faults. All of these ran at about 10 GB and took about the same amount of time (1000x300 was slower, but bigger).
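A quick way to check the scaling estimate: a dense u_kn with every sample evaluated in every state has K x (K x n) entries. The helper below is hypothetical and assumes float64 (8 bytes per entry):

```python
def u_kn_gigabytes(n_states, samples_per_state, bytes_per_entry=8):
    """Size of one dense K x N reduced-potential array u_kn, in GB (1e9 bytes).

    Assumes every sample is evaluated in every state, so N = K * samples_per_state
    and the array holds K**2 * samples_per_state entries.
    """
    n_total = n_states * samples_per_state
    return n_states * n_total * bytes_per_entry / 1e9


for k, n in [(100, 25000), (1000, 300), (500, 1000)]:
    print(f"{k} states x {n} samples/state: {u_kn_gigabytes(k, n):.1f} GB per copy of u_kn")
```

All three cases come out at roughly 2.0 to 2.4 GB per copy, so the ~10 GB observed presumably reflects several full-size temporaries (such as f_k - u_kn.T) alive at once; that is an inference, not something the profiles show directly.
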

Interestingly, I noticed it cycling back and forth between 100% and 400% CPU usage, so some aspects are well-parallelized, and some are not.

So the parts I identified as slow were indeed slow, but the balance between them is not as clear.

logsumexp is the most time-consuming part, but the non-logsumexp parts of mbar_gradient, self_consistent_update, and mbar_hessian are expensive as well, if I'm interpreting it correctly.

For example, let's look at self_consistent_update, which consists entirely of

```python
def self_consistent_update(u_kn, N_k, f_k):
    # validate_inputs and logsumexp are pymbar helpers (not shown).
    u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)
    states_with_samples = (N_k > 0)

    # Only the states with samples can contribute to the denominator term.
    log_denominator_n = logsumexp(f_k[states_with_samples] - u_kn[states_with_samples].T, b=N_k[states_with_samples], axis=1)

    # All states can contribute to the numerator term.
    return -1. * logsumexp(-log_denominator_n - u_kn, axis=1)
```

Looking at the 1000x300 case, self_consistent_update takes 8.1 s/call (cumulative), but there are only two logsumexp calls, at 1.4 s/call each. The tottime per call of self_consistent_update is 5.351 s/call, which means that the array reshaping, etc., is taking up a lot of time.
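One plausible (unconfirmed) source of the non-logsumexp tottime is that each line allocates full-size temporaries: boolean-mask indexing always copies, and f_k - u_kn.T allocates a new N x K array even though the transpose itself is a free view. The hypothetical helper below skips the mask when every state has samples:

```python
import numpy as np


def denominator_args(u_kn, N_k, f_k):
    """Hypothetical helper: the (N, K') array and weights for the denominator logsumexp.

    When every state has samples it skips the boolean mask, so only the
    subtraction allocates a new array.
    """
    mask = N_k > 0
    if mask.all():
        return f_k - u_kn.T, N_k
    # Boolean indexing copies the selected rows; the subtraction then copies again.
    return f_k[mask] - u_kn[mask].T, N_k[mask]


rng = np.random.default_rng(0)
K, N = 50, 2000
u_kn = rng.normal(size=(K, N))
N_k = np.full(K, N // K)
f_k = rng.normal(size=K)
mask = N_k > 0

print(np.shares_memory(u_kn.T, u_kn))      # True: a transpose is just a view
print(np.shares_memory(u_kn[mask], u_kn))  # False: boolean indexing made a copy

args, weights = denominator_args(u_kn, N_k, f_k)
masked_args = f_k[mask] - u_kn[mask].T
print(np.array_equal(args, masked_args))   # True: same values, one fewer K x N copy
```
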

mbar_gradient is:

```python
def mbar_gradient(u_kn, N_k, f_k):
    u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)
    log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
    log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1)
    return -1 * N_k * (1.0 - np.exp(f_k + log_numerator_k))
```

This also has two logsumexp calls. The tottime per call is 2.8 s/call, which means that reformatting rather than logsumexp is about 50% of the time. (Also, validate_inputs takes about 0.003 s/call, so it's not that.)
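A useful correctness check when rewriting these routines (in JAX or otherwise) is that the gradient vanishes at the self-consistent solution: the update sets f_k = -log_numerator_k, which is exactly where f_k + log_numerator_k = 0 in mbar_gradient. A self-contained NumPy sketch on a hypothetical three-state toy problem, with validation and the states_with_samples mask omitted:

```python
import numpy as np


def logsumexp(a, b=None, axis=None):
    # Minimal stable weighted log-sum-exp (sketch; assumes finite inputs).
    a_max = np.max(a, axis=axis, keepdims=True)
    s = np.exp(a - a_max)
    if b is not None:
        s = s * b
    out = np.log(np.sum(s, axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


def self_consistent_update(u_kn, N_k, f_k):
    # Same two logsumexp calls as the excerpt above, for the case where all states are sampled.
    log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
    return -1.0 * logsumexp(-log_denominator_n - u_kn, axis=1)


def mbar_gradient(u_kn, N_k, f_k):
    log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
    log_numerator_k = logsumexp(-log_denominator_n - u_kn, axis=1)
    return -1 * N_k * (1.0 - np.exp(f_k + log_numerator_k))


# Hypothetical toy data: 3 overlapping harmonic states, u_k(x) = (x - c_k)^2 / 2.
rng = np.random.default_rng(1)
centers = np.array([0.0, 0.5, 1.0])
n_per_state = 200
x_n = np.concatenate([rng.normal(c, 1.0, n_per_state) for c in centers])
u_kn = 0.5 * (x_n[np.newaxis, :] - centers[:, np.newaxis]) ** 2
N_k = np.full(len(centers), n_per_state)

f_k = np.zeros(len(centers))
for _ in range(500):
    f_k = self_consistent_update(u_kn, N_k, f_k)
    f_k -= f_k[0]  # remove the arbitrary additive constant

grad = mbar_gradient(u_kn, N_k, f_k)
print(f_k, np.abs(grad).max())  # the gradient should be ~0 at the fixed point
```
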

mbar_hessian is:

```python
def mbar_hessian(u_kn, N_k, f_k):
    u_kn, N_k, f_k = validate_inputs(u_kn, N_k, f_k)

    W = mbar_W_nk(u_kn, N_k, f_k)
    H = W.T.dot(W)
    H *= N_k
    H *= N_k[:, np.newaxis]
    H -= np.diag(W.sum(0) * N_k)
    return -1.0 * H
```

This takes 10.557 s/call (cumulative), of which mbar_W_nk accounts for 7.705 s/call, with the dot product making up much of the remainder.
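For the Hessian, the W.T.dot(W) product is O(K^2 N) work: for the 1000x300 case (K = 1000, N = 300,000) that is about 3 x 10^11 multiply-adds. The sketch below assumes mbar_W_nk returns the standard MBAR weights W_nk = exp(f_k - u_kn) / sum_j N_j exp(f_j - u_jn); it is not copied from pymbar. It also checks two identities an accelerated version should preserve: sum_k N_k W_nk = 1 for every sample, and each row of the Hessian sums to zero.

```python
import numpy as np


def logsumexp(a, b=None, axis=None):
    # Minimal stable weighted log-sum-exp (sketch; assumes finite inputs).
    a_max = np.max(a, axis=axis, keepdims=True)
    s = np.exp(a - a_max)
    if b is not None:
        s = s * b
    out = np.log(np.sum(s, axis=axis, keepdims=True)) + a_max
    return np.squeeze(out, axis=axis)


def mbar_W_nk(u_kn, N_k, f_k):
    # Assumed weight matrix, shape (N, K), built in log space for stability.
    log_denominator_n = logsumexp(f_k - u_kn.T, b=N_k, axis=1)
    return np.exp(f_k - u_kn.T - log_denominator_n[:, np.newaxis])


def mbar_hessian(u_kn, N_k, f_k):
    W = mbar_W_nk(u_kn, N_k, f_k)  # (N, K)
    H = W.T.dot(W)                  # (K, N) @ (N, K): the O(K^2 N) step
    H *= N_k                        # scale columns by N_k
    H *= N_k[:, np.newaxis]         # scale rows by N_k
    H -= np.diag(W.sum(0) * N_k)
    return -1.0 * H


rng = np.random.default_rng(2)
K, N = 4, 300
u_kn = rng.normal(size=(K, N))
N_k = np.full(K, N // K).astype(float)
f_k = rng.normal(size=K)

W = mbar_W_nk(u_kn, N_k, f_k)
H = mbar_hessian(u_kn, N_k, f_k)
print(np.abs(W @ N_k - 1.0).max(), np.abs(H @ np.ones(K)).max())  # both ~0
```
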

Anyway, it appears that MOST of the cost is in just 4-5 short routines, but it's not ALL in logsumexp.

I'll let other people digest these for a while, and come back after I've thought about it more.

@mrshirts
Collaborator Author

mrshirts commented Apr 6, 2020

We are getting closer to getting all the cleanup done, and starting to look at the right way to accelerate. @jchodera, you seem to be getting rather busy - do you have some ideas for doing the acceleration that I could try?

@mrshirts
Collaborator Author

mrshirts commented Jun 29, 2020

OK, I started with JAX/JIT, with a couple of drafts. These are WIP PRs #398 (vanilla pymbar4), #399 (partial JAX/JIT), and #400 (full adaptive routine in JAX/JIT). A few comments:

  • I do not entirely know what I am doing here.

  • I wasn't able to get it to install from conda-forge, so I had to use pip, which is less ideal. I'm not sure this is ready for prime time until we sort that out; it's problematic if it can't just be conda-installed. It's also failing checks. Should we try something like PyTorch as well, which is currently easier to roll out automatically?

  • I tried two different things:

    • just turning the most expensive parts into jax, and jitifying them
    • turning the entire adaptive solver loop into a single jit function. This took some fiddling, since I needed to declare all of the variables to be static, which may have impacted performance. But it appears necessary, since the values of the variables are needed for control flow (deciding which branch to use).
  • I seem to need to force double precision; the loops fail with single precision.

I benchmarked on 5 states with 5M points each, and 100 states with 20K points each. The first took about 8GB of memory, the 2nd about 10 GB. I didn't go higher since I didn't want timings limited by writing to disk.

  • I have not tried to run on GPU yet. Can't do it on my laptop, haven't gotten to running it somewhere else yet.

  • I'm attaching a gzipped tarball with 6 profiles (listed with their total times) for those who want to inspect further.

    • profile_pymbar4_100x2p4.dat (147 s)
    • profile_partial_jit_100x2p4.dat (72 s)
    • profile_adaptive_jit_100x2p4.dat (208 s)
    • profile_pymbar4_5x5p6.dat (101 s)
    • profile_partial_jit_5x5p6.dat (59 s)
    • profile_adaptive_jit_5x5p6.dat (140 s)

Files here: profiles.gz

(Update! Bonus: I was able to get the embedded C code for pymbar 2 working again on the same system.) These are Python 2 profiles, which are harder to read, so I'll just include the total times here:

  • profile_oldC_5x5p6.dat (140 s)
  • profile_oldC_100x2p4.dat (274 s)

It appears that newer versions of numpy use multiple cores, whereas the embedded C code uses just one core, so hand-optimized code gains much less than it used to.

The winner here seems to be just JIT-compiling the most time-intensive functions. My theory was that this version wasn't even faster because it has to go in and out of the JIT code, but it slows down when I try to wrap more. I'm not sure why it slows down when wrapping the entire adaptive loop; I don't know this well enough (hence attaching the profiles).
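Converting the totals above into speedups relative to vanilla pymbar4 makes the comparison easier to read (plain arithmetic on the numbers listed; no new measurements):

```python
# Total wall times (s) of the profiles listed above (oldC = pymbar 2 embedded C code).
totals = {
    "100x2p4": {"pymbar4": 147, "partial_jit": 72, "adaptive_jit": 208, "oldC": 274},
    "5x5p6": {"pymbar4": 101, "partial_jit": 59, "adaptive_jit": 140, "oldC": 140},
}

# Speedup relative to vanilla pymbar4 (> 1 means faster).
speedups = {
    case: {variant: t["pymbar4"] / seconds for variant, seconds in t.items()}
    for case, t in totals.items()
}

for case, s in speedups.items():
    print(case, {variant: round(x, 2) for variant, x in s.items()})
```

Partial JIT comes out at about 2.0x (100x2p4) and 1.7x (5x5p6), while the full adaptive JIT and the old C code are both below 1x, i.e., slower than pymbar4.
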

We could let YTZ take a look - he's pretty invested in getting things like this working faster, and seems to enjoy it.

@jchodera
Member

jchodera commented Jun 29, 2020

@maxentile also has a lot of experience with JAX, and can hopefully give us some pointers!

numpy is hard to beat given that it's multithreaded, so it's impressive that JAX is doing better. But I think the GPU acceleration will be the key here (provided it supports things like embedded Intel iGPUs).

cc @proteneer as well.

@proteneer

I will comment in the PR directly.

@mrshirts
Collaborator Author

But I think the GPU acceleration will be the key here (provided it supports things like embedded Intel iGPUs).

Could someone (like @jaimergp) potentially test this? I don't have many GPU varieties around (just one, and I would still need to get it working with the driver).

@mrshirts
Collaborator Author

mrshirts commented Jun 30, 2020

Rewrote the JIT version so we just accelerate the inner loop of the adaptive method, rather than the entire method (this is PR #402, replacing PR #400, which was too slow). This brought the timing down to about the same as the partial JIT version (PR #399), which just optimized the most expensive functions.

For the 5-state problem with 5M samples each, compared to the previous fastest (#399), the time dropped from 59 s to 44 s. For 100 states with 20K samples each, the time stayed about the same, dropping from 72 s to 71 s. There appears to be enough variation in the timing that I'll need to design the profiling a bit better, though.
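One way to make these comparisons less noisy is to do a warm-up call (so JIT compilation is excluded), repeat each measurement several times, and compare minima or medians rather than single runs. A minimal standard-library sketch; the benchmark helper is hypothetical and the NumPy workload is only a placeholder:

```python
import statistics
import timeit

import numpy as np


def benchmark(fn, repeat=7, number=1):
    """Hypothetical helper: time fn() `repeat` times and summarize per-call times."""
    fn()  # warm-up call so compilation/caching does not pollute the timings
    times = timeit.repeat(fn, repeat=repeat, number=number)
    per_call = [t / number for t in times]
    return {
        "min": min(per_call),  # least affected by background noise
        "median": statistics.median(per_call),
        "stdev": statistics.stdev(per_call) if len(per_call) > 1 else 0.0,
    }


x = np.random.default_rng(0).normal(size=(200, 5000))
print(benchmark(lambda: np.exp(x).sum(axis=1), repeat=5, number=3))
```
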

@mrshirts
Collaborator Author

mrshirts commented Jun 30, 2020

So @jchodera and @jaimergp, will it be possible to make installation easy for people if jax and jaxlib need to be pulled down via pip instead of being conda-installed?

@jchodera
Member

Have you already tried the conda versions?
https://anaconda.org/conda-forge/jax/files

@mrshirts
Collaborator Author

Well, yes, that's what I tried first. Unfortunately, I didn't document the various things that went wrong. Currently I'm getting the following (this was not the first error):

```
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
Solving environment: failed

ResolvePackageNotFound:
  - conda=4.3.25
```

This seems to indicate an issue with resolving packages that could eventually be worked out in a clean environment, but I haven't had the chance to do that yet (though that wasn't the error I had before; I should have noted it more carefully).

@mrshirts
Collaborator Author

Investigating whether we can use single-precision floats: mbar_gradient does not appear to handle them. Even after preconditioning, the first line:

```python
log_denominator_n = logsumexp(f_k - u_kn.T, b=jNk, axis=1)
```

ends up with 94% of the entries being exactly zero, so the sum ends up wrong enough that the weights don't exactly sum to 0 at the end.

Will keep looking at alternate approaches. Maybe the jax automatic gradient will result in something more numerically stable?
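The underflow is easy to reproduce in isolation. A stable logsumexp subtracts the maximum first, which prevents overflow, but any term more than roughly 87 to 104 below the maximum is subnormal or exactly zero once exponentiated in float32 (the float64 limit is around 745). A NumPy sketch with hypothetical shifted exponents spread over 200 units:

```python
import numpy as np

# Hypothetical shifted exponents a - max(a), spread over 200 units, as happens
# when the reduced potentials differ strongly between states.
shifted = -np.linspace(0.0, 200.0, 1001)

zeros32 = np.mean(np.exp(shifted.astype(np.float32)) == 0.0)
zeros64 = np.mean(np.exp(shifted) == 0.0)
print(f"float32: {zeros32:.0%} of terms underflow to 0; float64: {zeros64:.0%}")

# float32 also keeps only ~7 significant digits (eps) and has a much larger
# smallest normal value (tiny) than float64.
print(np.finfo(np.float32).eps, np.finfo(np.float32).tiny)
```
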

@proteneer

proteneer commented Jul 1, 2020

I did some benchmarking on the iteration time, i.e., the time it takes to make one trip through the adaptive loop. I think we need to find a way to make the logsumexp more stable, as I'm also seeing the numerical precision issues reported by @mrshirts above. Generally, though, it looks like the GPU is about 10x (64-bit) to 14x (32-bit) faster than the CPU. The large initial iteration time is due to JIT compilation. Still, I suspect a handwritten kernel would probably be 30-50x faster overall.

CPU 64 bit
iteration time: 103.07413482666016
iteration time: 3.7705514430999756
iteration time: 3.673891305923462
iteration time: 3.706911087036133
iteration time: 3.643885850906372
iteration time: 3.8017430305480957

CPU 32 bit
iteration time: 60.54610633850098
iteration time: 0.7477660179138184
iteration time: 0.8013427257537842
iteration time: 0.7045493125915527
iteration time: 0.7167036533355713
iteration time: 0.731112003326416
iteration time: 0.7113027572631836
GPU 64 bit
iteration time: 88.80276942253113
iteration time: 0.33150577545166016
iteration time: 0.3132588863372803
iteration time: 0.3524589538574219
iteration time: 0.3147101402282715
iteration time: 0.3501741886138916

GPU 32 bit
iteration time: 73.53490352630615
iteration time: 0.056365013122558594
iteration time: 0.05306673049926758
iteration time: 0.0488278865814209
iteration time: 0.04876542091369629
iteration time: 0.051480770111083984
iteration time: 0.045372724533081055
iteration time: 0.04701972007751465
iteration time: 0.055937767028808594
iteration time: 0.05077195167541504
iteration time: 0.051758527755737305
iteration time: 0.05005002021789551
iteration time: 0.04569387435913086
iteration time: 0.046376705169677734
iteration time: 0.05036187171936035
iteration time: 0.056464433670043945
iteration time: 0.04788923263549805
iteration time: 0.04864335060119629
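Taking the median of the steady-state iterations (dropping the first, JIT-dominated one) reproduces the GPU speedups quoted above. Plain arithmetic on the logged numbers, rounded to 4 decimals:

```python
import statistics

# Steady-state iteration times (s) from the log above, first iteration dropped.
times = {
    "CPU 64 bit": [3.7706, 3.6739, 3.7069, 3.6439, 3.8017],
    "CPU 32 bit": [0.7478, 0.8013, 0.7045, 0.7167, 0.7311, 0.7113],
    "GPU 64 bit": [0.3315, 0.3133, 0.3525, 0.3147, 0.3502],
    "GPU 32 bit": [0.0564, 0.0531, 0.0488, 0.0488, 0.0515, 0.0454, 0.0470, 0.0559, 0.0508,
                   0.0518, 0.0501, 0.0457, 0.0464, 0.0504, 0.0565, 0.0479, 0.0486],
}
medians = {name: statistics.median(ts) for name, ts in times.items()}

speedups = {}
for bits in ("64 bit", "32 bit"):
    cpu, gpu = medians["CPU " + bits], medians["GPU " + bits]
    speedups[bits] = cpu / gpu
    print(f"{bits}: CPU {cpu:.4f} s, GPU {gpu:.4f} s, GPU speedup {speedups[bits]:.1f}x")
```

This gives roughly 11x for 64-bit and 14x for 32-bit.
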

@mrshirts
Collaborator Author

mrshirts commented Jul 1, 2020

Thanks so much for testing this! I was still getting my CUDA setup going on another machine.

Though I suspect a handwritten kernel will probably overall 30-50x faster.

Well, the ideal would be to move to a framework where we DON'T have to handwrite the kernel. I wrote a C++ kernel for pymbar 2.0 that was much faster than numpy, but it took a lot of work and became obsolete fairly soon (it's too slow now because it doesn't multithread, plus I'd have to rewrite it for Python 3).

I think we need to find a way to make the logsumexp more stable, as I'm also seeing the numerical precision issues reported by @mrshirts above.

I think that is indeed the quickest route to improved performance, though I'm not sure of the best way to do this. Fundamentally, exp(-energy) gets small fast, so even reordering the sum wouldn't help.

With the updated code in PR #402, on a 64-bit CPU, I'm getting 6 seconds for the first iteration and then 3.7-3.8 s for the rest. So that is probably an OK way to go in most cases, and there is no extra compilation time on subsequent runs of MBAR in the same loop.

Can you try the comparison above with PR #402? A reduction in time with the GPU from 2.3 + 11(3.7) ≈ 43 to 2.3 + 11(0.7) ≈ 10 would not be bad to get, even with 64-bit. Do most GPUs handle doubles now, or just the higher-end ones?

I'd be satisfied with something that was 2-6 times faster than pymbar3 (pymbar4 speed should be about equal).

@mrshirts
Collaborator Author

mrshirts commented Jul 1, 2020

I was able to get jax and jaxlib installed on a fresh conda install on my computer, so that can be overcome.

Note that the right CUDA developer kit needs to be installed from NVIDIA to run jax with GPUs. I don't think that will be automatable, since you need a (free) NVIDIA developer account to get it.

@jchodera
Member

jchodera commented Jul 1, 2020

You can install cudatoolkit via conda (at least for several OSs):
https://anaconda.org/anaconda/cudatoolkit/files

@mrshirts
Collaborator Author

mrshirts commented Jul 1, 2020

It apparently also needs cuDNN (https://developer.nvidia.com/cudnn) for some reason, which I don't think is in the toolkit. But I could be wrong...

@jaimergp
Member

jaimergp commented Jul 1, 2020

cudatoolkit does not ship cudnn, but cudnn is available as a separate package in defaults.

More info: conda-forge/conda-forge.github.io#687 (comment)

@mrshirts
Collaborator Author

mrshirts commented Jul 1, 2020

cudatoolkit does not ship cudnn, but cudnn is available as a separate package in defaults.

OK, good to know that that may be a route.

@jaimergp
Member

jaimergp commented Jul 3, 2020

Is this going to be optional or required? cudatoolkit + cudnn amount to ~700-900MB depending on the OS and version.

@mrshirts
Collaborator Author

mrshirts commented Jul 3, 2020

I think it would be better to make it optional, since many computers would not have GPUs anyway. If we can automatically install jax (with its JIT), that would be good, since that gives a 2x speedup in many applications. But it should run without GPU support.

I don't know if cudatoolkit is required, but I'll let you know as I keep playing around on this.

@jaimergp
Member

jaimergp commented Jul 3, 2020

Cool. We can play around with the conda packaging to provide easy options for both approaches.

@jchodera
Member

jchodera commented Jul 6, 2020

Looks like JAX may add ROCm support sometime soon.

jax-ml/jax#2012
Another possibility is for us to add pyopencl acceleration:

https://documen.tician.de/pyopencl/

@mrshirts
Collaborator Author

mrshirts commented Jul 6, 2020

I'd say baby steps. The #1 goal is ease of use (including installation); computational efficiency is good but a second priority. I think we are generally getting to the point where memory size, rather than floating-point throughput, may be the limiting factor for large systems.

@jchodera
Member

jchodera commented Apr 2, 2022

jax is now available on conda-forge, so we may want to dump the C extensions and switch to jax.

@mrshirts
Collaborator Author

mrshirts commented Apr 2, 2022

I started working on this in #399 and #400, though I probably won't be able to do full testing until end of the semester (~1 month).

@mrshirts
Collaborator Author

mrshirts commented Jun 29, 2022

OK, #447 is a new attempt; I'll be working to get this going for real over the next week or so. It is about 2x as fast on CPU. I'm still trying to figure out how to make it play well with all the options (e.g., it fails some tests when numpy and JAX arrays are mixed).
