Skip to content

Commit

Permalink
Build HMC with logprob instead of potential
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Sep 1, 2021
1 parent b253dc0 commit 789d1d7
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
10 changes: 7 additions & 3 deletions aehmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

def kernel(
srng: RandomStream,
potential_fn: Callable[[TensorVariable], TensorVariable],
logprob_fn: TensorVariable,
step_size: TensorVariable,
inverse_mass_matrix: TensorVariable,
num_integration_steps: TensorVariable,
Expand All @@ -25,8 +25,9 @@ def kernel(
----------
srng
RandomStream object.
potential_fn
A function that returns the potential energy of a chain at a given position.
logprob_fn
A function that returns the value of the log-probability density
function of a chain at a given position.
step_size
The step size used in the symplectic integrator
inverse_mass_matrix
Expand All @@ -46,6 +47,9 @@ def kernel(
"""

def potential_fn(x):
return -logprob_fn(x)

momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_metric(
inverse_mass_matrix
)
Expand Down
13 changes: 8 additions & 5 deletions tests/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
import aehmc.nuts as nuts


def normal_logp(q: TensorVariable):
def normal_potential(q: TensorVariable):
return aet.sum(aet.square(q - 3.0))


Expand All @@ -29,8 +29,11 @@ def build_hmc_trajectory_generator(
inverse_mass_matrix = aet.vector("inverse_mass_matrix")
num_integration_steps = aet.scalar("num_integration_steps", dtype="int32")

def logprob_fn(x):
return -potential_fn(x)

kernel = kernel_generator(
srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps
srng, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps
)

trajectory, updates = aesara.scan(
Expand Down Expand Up @@ -60,7 +63,7 @@ def test_hmc():
inverse_mass_matrix = np.array([1.0])

trajectory_generator = build_hmc_trajectory_generator(
srng, hmc.kernel, normal_logp, 50_000
srng, hmc.kernel, normal_potential, 50_000
)
positions, *_ = trajectory_generator(
initial_position, step_size, inverse_mass_matrix, num_integration_steps
Expand All @@ -74,13 +77,13 @@ def test_nuts():
srng = RandomStream(seed=59)

q = aet.vector("q")
potential_energy = normal_logp(q)
potential_energy = normal_potential(q)
potential_energy_grad = aesara.grad(potential_energy, wrt=q)

step_size = aet.scalar("step_size", dtype="float64")
inverse_mass_matrix = aet.vector("inverse_mass_matrix", dtype="float64")

kernel = nuts.kernel(srng, normal_logp, step_size, inverse_mass_matrix)
kernel = nuts.kernel(srng, normal_potential, step_size, inverse_mass_matrix)
result, updates = kernel(q, potential_energy, potential_energy_grad)

trajectory_generator = aesara.function(
Expand Down

0 comments on commit 789d1d7

Please sign in to comment.