Accelerate the code, starting with the inner loops #340
I'd strongly suggest running snakeviz on the profile output. It will produce a visualization similar to this one, which I found extremely informative for finding the bottlenecks.
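For anyone following along, a minimal way to generate and inspect such a profile; the toy data and the bare MBAR call here are placeholders, not the harmonic-oscillator setup used for the attached profiles:

```python
import cProfile
import pstats

import numpy as np
from pymbar import MBAR

# toy data just so the snippet is self-contained; real runs would use the
# harmonic-oscillator test systems discussed in this thread
u_kn = np.random.rand(5, 500)       # K x N reduced potentials
N_k = np.array([100] * 5)           # samples per state, sums to N

# dump profiling stats for the whole MBAR construction (solver included)
cProfile.run("MBAR(u_kn, N_k)", "profile.out")

# quick text summary, sorted by cumulative time
pstats.Stats("profile.out").sort_stats("cumulative").print_stats(20)

# interactive sunburst view (run from a shell):  snakeviz profile.out
```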
Ah, I had been using cProfile, but was poring over the text output: I didn't know about snakeviz. Here are three sample profiles of a set of harmonic oscillators, using the adaptive method - all of them converged in 3 self-consistent and 3 Newton-Raphson iterations. 100 states, 25000 samples per state: profile.100x25000.txt. I'm not posting the graphs here, since you need to do a little interactive work to see what is going on, but you can visualize them easily with snakeviz.

All of these took about 3-4 minutes. Note that anything bigger (the size scales roughly as (nsamples/state)*(n_states)^2) starts to cause memory issues and runs much more slowly because of page faults. All of these ran at about 10 GB, and took about the same amount of time (1000x300 was slower, but bigger).

Interestingly, I noticed it cycling back and forth between 100% and 400% CPU usage, so some aspects are well-parallelized and some are not. So the parts I identified as slow were indeed slow, but the balance is not as clear as I expected.
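As a rough sanity check on that memory number (my own arithmetic, not taken from the profiles):

```python
# back-of-the-envelope estimate of the dominant K x N float64 arrays
K = 100                            # states
n_per_state = 25_000               # samples per state
N = K * n_per_state                # 2.5 million total samples
per_matrix_gb = K * N * 8 / 1e9    # one K x N float64 matrix (u_kn, weights, ...)
print(per_matrix_gb)               # ~2.0 GB, so a handful of such arrays plus
                                   # temporaries plausibly reaches the ~10 GB seen
```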
For example, looking at the 1000x300 case: 10.557 cumtime/call, with mbar_W_kn at 7.705 cumtime/call, and the dot product making up a lot of the remainder. Anyway, it appears that MOST of the cost is in just 4-5 short routines, but it's not ALL in one place. I'll let other people digest these for a bit, and come back later after I've thought about it a bit.
We are getting closer to having all the cleanup done, and are starting to look at the right way to accelerate. @jchodera, you seem to be getting rather busy - do you have some ideas for doing the acceleration that I could try?
OK, I started with JAX/jit, with a couple of drafts. These are WIP PRs: #398 (vanilla pymbar4), #399 (partial jax/jit), and #400 (full adaptive routine in jax/jit). A few comments:
I benchmarked on 5 states with 5M points each, and 100 states with 20K points each. The first took about 8 GB of memory, the second about 10 GB. I didn't go higher since I didn't want the timings to be limited by swapping to disk.
Files here: profiles.gz (Update! Bonus: I was able to get the embedded C code working for pymbar 2 again on the same system). These are Python 2 profiles, so they're harder to read; I'll just include the total times here.
It appears that the newer versions of numpy use multiple cores, whereas the embedded C code uses just one core, so the hand-optimized code helps much less than it used to. The winner here seems to be just doing jit for the most time-intensive functions. My theory was that the full-jit version wasn't faster because of the overhead of moving in and out of the JIT code, but it actually slows down when I try to wrap more. I'm not sure why it slows down when wrapping the entire adaptive loop - I don't understand things at this level well (hence attaching the profiling). We could let YTZ take a look - he's pretty invested in getting things like this working faster, and seems to enjoy it.
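For reference, a minimal sketch of what "jit just the most expensive functions" means here; the function below is an illustrative logsumexp-based self-consistent update written against jax.numpy, not the actual pymbar internals:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)  # JAX defaults to float32

@jax.jit
def self_consistent_update(u_kn, N_k, f_k):
    """One fixed-point update of the dimensionless free energies f_k (illustrative)."""
    # log denominator: ln sum_k N_k exp(f_k - u_kn), one value per sample
    log_denom_n = logsumexp(f_k[:, None] - u_kn, b=N_k[:, None], axis=0)
    # updated free energies, shifted so that f_k[0] = 0
    f_k_new = -logsumexp(-u_kn - log_denom_n, axis=1)
    return f_k_new - f_k_new[0]
```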
@maxentile also has a lot of experience with JAX, and can hopefully give us some pointers!
cc @proteneer as well.
I will comment in the PR directly.
Could someone (like @jaimergp) potentially test this? I don't have many GPU varieties around (just one, and I would still need to get the driver set up).
Rewrote the jit so we just accelerate the inner loop of the adaptive method rather than the entire method (this is PR #402, replacing PR #400, which was too slow). This brought the timing down to about the same as the partial-jit pass (PR #399) that just optimized the most expensive functions. For the 5-state, 5M-samples-each problem, compared to the previous fastest (#399), the time dropped from 59 s to 44 s. For 100 states with 20K samples each, the time stayed about the same, dropping from 72 s to 71 s. There appears to be enough variation in the timing that I'll need to design the profiling a bit better, though.
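To make the structure concrete, a minimal sketch of "jit the array-heavy pieces, keep the control flow in Python" (simplified to Newton-Raphson only; names and details are illustrative, not the actual PR #402 code):

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

jax.config.update("jax_enable_x64", True)

def mbar_objective(f_k, u_kn, N_k):
    # objective whose stationary point gives the MBAR free energies
    log_denom_n = logsumexp(f_k[:, None] - u_kn, b=N_k[:, None], axis=0)
    return jnp.sum(log_denom_n) - jnp.dot(N_k, f_k)

# jit only the expensive, purely numerical pieces
grad_fn = jax.jit(jax.grad(mbar_objective))
hess_fn = jax.jit(jax.hessian(mbar_objective))

def solve(u_kn, N_k, f_k, tol=1e-10, max_iter=100):
    """Outer loop in plain Python; convergence checks stay outside the JIT."""
    for _ in range(max_iter):
        g = grad_fn(f_k, u_kn, N_k)
        if jnp.linalg.norm(g) < tol:
            break
        # f_k[0] is pinned at 0, so Newton-Raphson acts on the remaining K-1 entries
        H = hess_fn(f_k, u_kn, N_k)
        delta = jnp.linalg.solve(H[1:, 1:], g[1:])
        f_k = f_k.at[1:].add(-delta)
    return f_k - f_k[0]
```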
Have you already tried the conda versions?
Well, yes, that's what I tried first. Unfortunately, I didn't document the various things that went wrong. Currently, I'm getting (this was not the first error):
Solving environment: failed with initial frozen solve. Retrying with flexible solve.
ResolvePackageNotFound:
This seems to indicate that there is some issue with resolving different packages that could be worked out eventually in a clean environment, but I haven't had the chance to do that yet (though that wasn't the error I had before - I should have noted it more carefully).
Investigating whether we can do single-precision floats - mbar_gradient does not appear to handle them. Even after preconditioning, the first line:
Ends up having 94% of its entries as pure zeros, and so the sum ends up sufficiently wrong that the weights don't exactly sum to 0 at the end. I will keep looking at alternate approaches. Maybe the JAX automatic gradient will result in something more numerically stable?
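To make the precision problem concrete, a toy illustration (my own numbers, not from the profiles) of why exp-space sums lose everything in float32 while the max-shifted logsumexp form survives:

```python
import numpy as np
from scipy.special import logsumexp

a = np.array([-200.0, -190.0, -180.0], dtype=np.float32)  # e.g. -u values

naive = np.log(np.sum(np.exp(a)))  # exp(-180) etc. underflow to 0 in float32,
                                   # so this returns -inf
stable = logsumexp(a)              # shifts by max(a) before exponentiating

print(naive, stable)               # -inf  vs  ~-179.99995
```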
I did some benchmarking on the iteration time, i.e., the time it takes to cycle through one trip of the adaptive code. I think we need to find a way to make the logsumexp more stable, as I'm also seeing the numerical precision issues reported by @mrshirts above. Generally, though, it looks like the GPU is about 10x (64-bit) to 14x (32-bit) faster than the CPU. The large initial iteration time is due to JIT. Though I suspect a handwritten kernel would probably be 30-50x faster overall.
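A caveat on measuring per-iteration times with JAX (a sketch of one way to measure, not how the numbers above were produced): the warm-up call absorbs the one-time compilation cost, and block_until_ready() is needed so asynchronous GPU dispatch doesn't make iterations look artificially fast. The jitted_update argument is a placeholder for whatever jitted function is being timed.

```python
import time

def time_iterations(jitted_update, u_kn, N_k, f_k, n_iter=10):
    # warm-up call: triggers JIT compilation, excluded from the timings
    f_k = jitted_update(u_kn, N_k, f_k).block_until_ready()
    times = []
    for _ in range(n_iter):
        t0 = time.perf_counter()
        f_k = jitted_update(u_kn, N_k, f_k).block_until_ready()
        times.append(time.perf_counter() - t0)
    return times
```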
Thanks so much for testing this! I was still getting CUDA set up on another machine.
Well, the ideal would be to move to a framework where we DON'T have to handwrite the kernel. I wrote a C++ kernel for pymbar 2.0 that was much faster than numpy, but it took a lot of work and became obsolete fairly soon (it's too slow now because it doesn't multithread, plus I'd have to rewrite it for Python 3).
I think that is indeed the quickest route to improved performance, though I'm not sure of the best way to do this. Fundamentally, exp(-energy) gets small fast, so even reordering the sum wouldn't help. For the updated code in PR #402, on a 64-bit CPU I'm getting 6 seconds for the first iteration and then 3.7-3.8 s for the rest of them, so that is probably an OK way to go in most cases. There is no extra compilation time on subsequent runs of MBAR in the same loop. Can you try the comparison above with PR #402? A reduction in time with the GPU from 2.3 + 11*(3.7) ≈ 43 s to 2.3 + 11*(0.7) ≈ 10 s is not bad to get, even with 64-bit. Do most GPUs handle doubles now, or just the higher-end ones? I'd be satisfied with something that was 2-6 times faster than pymbar3 (pymbar4 speed should be about equal).
I was able to get jax and jaxlib installed in a fresh conda environment on my computer, so that can be overcome. Note that the right CUDA developer kit needs to be installed from NVIDIA to run jax with GPUs - I don't think that will be automatable, since you need a (free) NVIDIA developer account to get it.
You can install
It apparently needs https://developer.nvidia.com/cudnn for some reason, which I don't think is in the toolkit. But I could be wrong...
OK, it's good to know that there may be a route.
Is this going to be optional or required?
I think it would be better to make it optional, since many computers don't have GPUs anyway. If we can automatically download jax and jit support, that would be good, since it gives a ~2x speedup in many applications, but it should still run without GPU support. I don't know if cudatoolkit is required, but I'll let you know as I keep playing around with this.
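One possible shape for "optional" (just a sketch of the idea, not a committed design): import JAX when it is available and fall back to plain numpy/scipy otherwise, so the accelerated path is purely opportunistic.

```python
try:
    import jax
    import jax.numpy as jnp
    from jax.scipy.special import logsumexp

    jax.config.update("jax_enable_x64", True)  # keep results in float64

    def maybe_jit(fn):
        return jax.jit(fn)

except ImportError:
    import numpy as jnp
    from scipy.special import logsumexp

    def maybe_jit(fn):
        # JAX not installed: run the plain numpy implementation unchanged
        return fn
```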
Cool. We can play around with the conda packaging to provide easy options for both approaches.
Looks like JAX may add ROCm support sometime soon: jax-ml/jax#2012
I'd say baby steps. The #1 goal is ease of use (including installation); computational efficiency is good, but a second priority. I think we are generally getting to the point where memory size, rather than flops per second, is the limit for large systems.
OK, #447 is a new attempt; I will be working to get this going for real over the next week or so. It is about 2x as fast on CPU. I'm still trying to figure out how to make it play well with all the options (i.e., it fails some tests when numpy and JAX mix).
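On the numpy/JAX mixing, one pattern that may help (a guess at the class of failure, not the actual fix in #447): keep JAX arrays internal to the solver and convert back to numpy at the public API boundary. The jitted_solver argument below is a hypothetical placeholder for the jitted update used internally.

```python
import numpy as np
import jax.numpy as jnp

def compute_free_energies(u_kn, N_k, jitted_solver, n_iter=50):
    """Hypothetical wrapper: numpy in, numpy out, JAX arrays only inside."""
    u_kn_j = jnp.asarray(u_kn)
    N_k_j = jnp.asarray(N_k)
    f_k = jnp.zeros(len(N_k))
    for _ in range(n_iter):
        f_k = jitted_solver(u_kn_j, N_k_j, f_k)
    return np.asarray(f_k)  # downstream numpy code never sees JAX arrays
```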
We need to accelerate the inner loops.
I think the most important things to do would be to accelerate, in this order:
This is the largest cost.
After that is accelerated, it's probably worth it to look at: