Skip to content

Commit

Permalink
add hmc test
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Jun 3, 2021
1 parent 0cc92f1 commit 3cd82d1
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 38 deletions.
7 changes: 4 additions & 3 deletions aehmc/hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def step(
potential_energy: TensorVariable,
potential_energy_grad: TensorVariable,
):
"""Perform a single step of the HMC algorithm."""
p = momentum_generator(srng)
(
q_new,
Expand Down Expand Up @@ -75,11 +76,11 @@ def propose(
)

# flip the momentum to keep detailed balance
new_p = -1.0 * new_p
flipped_p = -1.0 * new_p

# compute transition-related quantities
energy = potential_energy + kinetic_energy(p)
new_energy = new_potential_energy + kinetic_energy(new_p)
new_energy = new_potential_energy + kinetic_energy(flipped_p)
delta_energy = energy - new_energy
delta_energy = aet.where(aet.isnan(delta_energy), -np.inf, delta_energy)
# is_transition_divergence = aet.abs(delta_energy) > divergence_threshold
Expand All @@ -93,7 +94,7 @@ def propose(
final_potential_energy_grad,
) = ifelse(
do_accept,
(new_q, new_p, new_potential_energy, new_potential_energy_grad),
(new_q, flipped_p, new_potential_energy, new_potential_energy_grad),
(q, p, potential_energy, potential_energy_grad),
)

Expand Down
76 changes: 41 additions & 35 deletions tests/test_hmc.py
Original file line number Diff line number Diff line change
@@ -1,61 +1,67 @@
from typing import Callable

import aesara
import aesara.tensor as aet
import aesara_hmc.hmc as hmc
import numpy as np
import pytest
from aesara.tensor.random.utils import RandomStream
from aesara.tensor.var import TensorVariable


def test_hmc():
def potential_fn(q: TensorVariable) -> TensorVariable:
return -1.0 / aet.power(aet.square(q[0]) + aet.square(q[1]), 0.5)
def normal_logp(q: TensorVariable):
return aet.sum(aet.square(q - 3.0))

srng = RandomStream(seed=59)

step_size = aet.scalar("step_size")
inverse_mass_matrix = aet.vector("inverse_mass_matrix")
num_integration_steps = aet.scalar("num_integration_steps", dtype="int32")

def build_trajectory_generator(
srng: RandomStream,
kernel_generator: Callable,
potential_fn: Callable,
num_states: int,
) -> Callable:
q = aet.vector("q")
potential_energy = potential_fn(q)
potential_energy_grad = aesara.grad(potential_energy, wrt=q)

kernel = hmc.kernel(
srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps
)
next_step = kernel(q, potential_energy, potential_energy_grad)
step_size = aet.scalar("step_size")
inverse_mass_matrix = aet.vector("inverse_mass_matrix")
num_integration_steps = aet.scalar("num_integration_steps", dtype="int32")

# Compile a function that returns the next state
step_fn = aesara.function(
(q, step_size, inverse_mass_matrix, num_integration_steps), next_step
kernel = kernel_generator(
srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps
)

# Compile a function that integrates the trajectory integrating several times
trajectory, _ = aesara.scan(
trajectory, updates = aesara.scan(
fn=kernel,
outputs_info=[
{"initial": q},
{"initial": potential_energy},
{"initial": potential_energy_grad},
],
n_steps=1000,
n_steps=num_states,
)
trajectory_generator = aesara.function(
(q, step_size, inverse_mass_matrix, num_integration_steps),
trajectory,
updates=updates,
)

return trajectory_generator


def test_hmc():
"""Test the HMC kernel on a simple potential."""
srng = RandomStream(seed=59)
step_size = 0.003
num_integration_steps = 10
initial_position = np.array([1.0])
inverse_mass_matrix = np.array([1.0])

trajectory_generator = build_trajectory_generator(
srng, hmc.kernel, normal_logp, 10_000
)
traj = aesara.function(
(q, step_size, inverse_mass_matrix, num_integration_steps), trajectory
positions, *_ = trajectory_generator(
initial_position, step_size, inverse_mass_matrix, num_integration_steps
)

# Run
step_size = 0.01
num_integration_steps = 3
q = np.array([1.0, 1.0])
inverse_mass_matrix = np.array([1.0, 1.0])

# This works
samples = []
for _ in range(1_000):
q, *_ = step_fn(q, step_size, inverse_mass_matrix, num_integration_steps)
samples.append(q)
print(np.mean(np.array(samples)))

# This doesn't
print(traj(q, step_size, inverse_mass_matrix, num_integration_steps))
assert np.mean(positions[9000:], axis=0) == pytest.approx(3, 1e-1)

0 comments on commit 3cd82d1

Please sign in to comment.