Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Implement the grapevine method for faster HMC with steady states #25

Open
wants to merge 2 commits into
base: main
Choose a base branch
from

Conversation

teddygroves
Copy link
Contributor

@teddygroves teddygroves commented Oct 4, 2024

This change introduces a new sampling method called "grapevine" (as in "I heard it on the grapevine"), as required to address #26.

The idea is to speed up the process of simulating a Hamiltonian trajectory, which an HMC sampler needs to do once per MCMC iteration.

To simulate a trajectory, the sampler needs to repeatedly take little steps through log parameter space and evaluate the model's log probability function and its gradients. When the model involves numerically finding the solution to a system of equations this can be quite slow. The grapevine method exploits the fact that the steps are small, so that the solution at the previous step is likely to be a pretty good guess for the next solution. By saving each solution, we can hopefully make systematically better guesses, which is well-known to be a good way to speed up equation solvers.

Since an HMC trajectory looks kind of like a grapevine, and the idea is kind of like a whisper game, I thought that "grapevine" was a nice name for the method.

Blackjax provides all the tools we need to implement the grapevine method with the NUTS sampler. The main challenge is to make a new integrator that can handle a log density function that both takes in and returns an equation solution guess.

Current status: the example in scripts/mcmc_demo.py works (and goes pretty fast: on my laptop one chain with 200 warmup and samples took 5:29 vs 38:04 on the main branch). Models with a number of state variables different from 5 won't work because that is currently hard coded.

Checklist:

  • tests pass
  • README.md up to date
  • docs up to date
  • link to any relevant issues

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants