[WIP] Implement the grapevine method for faster HMC with steady states #25
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This change introduces a new sampling method called "grapevine" (as in "I heard it on the grapevine"), as required to address #26.
The idea is to speed up the process of simulating a Hamiltonian trajectory, which an HMC sampler needs to do once per MCMC iteration.
To simulate a trajectory, the sampler needs to repeatedly take little steps through log parameter space and evaluate the model's log probability function and its gradients. When the model involves numerically finding the solution to a system of equations this can be quite slow. The grapevine method exploits the fact that the steps are small, so that the solution at the previous step is likely to be a pretty good guess for the next solution. By saving each solution, we can hopefully make systematically better guesses, which is well-known to be a good way to speed up equation solvers.
Since an HMC trajectory looks kind of like a grapevine, and the idea is kind of like a whisper game, I thought that "grapevine" was a nice name for the method.
Blackjax provides all the tools we need to implement the grapevine method with the NUTS sampler. The main challenge is to make a new integrator that can handle a log density function that both takes in and returns an equation solution guess.
Current status: the example in
scripts/mcmc_demo.py
works (and goes pretty fast: on my laptop one chain with 200 warmup and samples took 5:29 vs 38:04 on the main branch). Models with a number of state variables different from 5 won't work because that is currently hard coded.Checklist:
README.md
up to date