Skip to content

Commit

Permalink
Add periodic orbital MCMC
Browse files Browse the repository at this point in the history
  • Loading branch information
albcab authored and rlouf committed Apr 29, 2022
1 parent 8bb3761 commit 6e31022
Show file tree
Hide file tree
Showing 8 changed files with 1,412 additions and 5 deletions.
2 changes: 2 additions & 0 deletions blackjax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
hmc,
mala,
nuts,
orbital_hmc,
rmh,
tempered_smc,
window_adaptation,
Expand All @@ -16,6 +17,7 @@
"hmc", # mcmc
"mala",
"nuts",
"orbital_hmc",
"rmh",
"window_adaptation", # mcmc adaptation
"adaptive_tempered_smc", # smc
Expand Down
79 changes: 79 additions & 0 deletions blackjax/kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"hmc",
"mala",
"nuts",
"orbital_hmc",
"rmh",
"tempered_smc",
"window_adaptation",
Expand Down Expand Up @@ -545,3 +546,81 @@ def step_fn(rng_key: PRNGKey, state):
)

return SamplingAlgorithm(init_fn, step_fn)


class orbital_hmc:
"""Implements the (basic) user interface for the Periodic orbital MCMC kernel
Each iteration of the periodic orbital MCMC outputs ``period`` weighted samples from
a single Hamiltonian orbit connecting the previous sample and momentum (latent) variable
with precision matrix ``inverse_mass_matrix``, evaluated using the ``bijection`` as an
integrator with discretization parameter ``step_size``.
Examples
--------
A new Periodic orbital MCMC kernel can be initialized and used with the following code:
.. code::
per_orbit = blackjax.orbital_hmc(logprob_fn, step_size, inverse_mass_matrix, period)
state = per_orbit.init(position)
new_state, info = per_orbit.step(rng_key, state)
We can JIT-compile the step function for better performance
.. code::
step = jax.jit(per_orbit.step)
new_state, info = step(rng_key, state)
Parameters
----------
logprob_fn
The logarithm of the probability density function we wish to draw samples from. This
is minus the potential energy function.
step_size
The value to use for the step size in for the symplectic integrator to buid the orbit.
inverse_mass_matrix
The value to use for the inverse mass matrix when drawing a value for
the momentum and computing the kinetic energy.
period
The number of steps used to build the orbit.
bijection
(algorithm parameter) The symplectic integrator to use to build the orbit.
Returns
-------
A ``SamplingAlgorithm``.
"""

init = staticmethod(mcmc.periodic_orbital.init)
kernel = staticmethod(mcmc.periodic_orbital.kernel)

def __new__( # type: ignore[misc]
cls,
logprob_fn: Callable,
step_size: float,
inverse_mass_matrix: Array, # assume momentum is always Gaussian
period: int,
*,
bijection: Callable = mcmc.integrators.velocity_verlet,
) -> SamplingAlgorithm:

step = cls.kernel(bijection)

def init_fn(position: PyTree):
return cls.init(position, logprob_fn, period)

def step_fn(rng_key: PRNGKey, state):
return step(
rng_key,
state,
logprob_fn,
step_size,
inverse_mass_matrix,
period,
)

return SamplingAlgorithm(init_fn, step_fn)
4 changes: 2 additions & 2 deletions blackjax/mcmc/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from . import hmc, mala, nuts, rmh
from . import hmc, mala, nuts, periodic_orbital, rmh

__all__ = ["hmc", "mala", "nuts", "rmh"]
__all__ = ["hmc", "mala", "nuts", "periodic_orbital", "rmh"]
Loading

0 comments on commit 6e31022

Please sign in to comment.