From 3cd82d1fa363401757f2517c24e5a420cca98969 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Tue, 1 Jun 2021 12:18:12 +0200
Subject: [PATCH] add hmc test

---
 aehmc/hmc.py      |  7 +++--
 tests/test_hmc.py | 76 +++++++++++++++++++++++++----------------------
 2 files changed, 45 insertions(+), 38 deletions(-)

diff --git a/aehmc/hmc.py b/aehmc/hmc.py
index b514fb2..1b33125 100644
--- a/aehmc/hmc.py
+++ b/aehmc/hmc.py
@@ -38,6 +38,7 @@ def step(
     potential_energy: TensorVariable,
     potential_energy_grad: TensorVariable,
 ):
+    """Perform a single step of the HMC algorithm."""
     p = momentum_generator(srng)
     (
         q_new,
@@ -75,11 +76,11 @@ def propose(
     )
 
     # flip the momentum to keep detailed balance
-    new_p = -1.0 * new_p
+    flipped_p = -1.0 * new_p
 
     # compute transition-related quantities
     energy = potential_energy + kinetic_energy(p)
-    new_energy = new_potential_energy + kinetic_energy(new_p)
+    new_energy = new_potential_energy + kinetic_energy(flipped_p)
     delta_energy = energy - new_energy
     delta_energy = aet.where(aet.isnan(delta_energy), -np.inf, delta_energy)
     # is_transition_divergence = aet.abs(delta_energy) > divergence_threshold
@@ -93,7 +94,7 @@ def propose(
         final_potential_energy_grad,
     ) = ifelse(
         do_accept,
-        (new_q, new_p, new_potential_energy, new_potential_energy_grad),
+        (new_q, flipped_p, new_potential_energy, new_potential_energy_grad),
         (q, p, potential_energy, potential_energy_grad),
     )
 
diff --git a/tests/test_hmc.py b/tests/test_hmc.py
index 22f1206..8f5c58c 100644
--- a/tests/test_hmc.py
+++ b/tests/test_hmc.py
@@ -1,61 +1,67 @@
+from typing import Callable
+
 import aesara
 import aesara.tensor as aet
 import aesara_hmc.hmc as hmc
 import numpy as np
+import pytest
 from aesara.tensor.random.utils import RandomStream
 from aesara.tensor.var import TensorVariable
 
 
-def test_hmc():
-    def potential_fn(q: TensorVariable) -> TensorVariable:
-        return -1.0 / aet.power(aet.square(q[0]) + aet.square(q[1]), 0.5)
+def normal_logp(q: TensorVariable):
+    return aet.sum(aet.square(q - 3.0))
 
-    srng = RandomStream(seed=59)
-
-    step_size = aet.scalar("step_size")
-    inverse_mass_matrix = aet.vector("inverse_mass_matrix")
-    num_integration_steps = aet.scalar("num_integration_steps", dtype="int32")
 
+def build_trajectory_generator(
+    srng: RandomStream,
+    kernel_generator: Callable,
+    potential_fn: Callable,
+    num_states: int,
+) -> Callable:
     q = aet.vector("q")
     potential_energy = potential_fn(q)
     potential_energy_grad = aesara.grad(potential_energy, wrt=q)
 
-    kernel = hmc.kernel(
-        srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps
-    )
-    next_step = kernel(q, potential_energy, potential_energy_grad)
+    step_size = aet.scalar("step_size")
+    inverse_mass_matrix = aet.vector("inverse_mass_matrix")
+    num_integration_steps = aet.scalar("num_integration_steps", dtype="int32")
 
-    # Compile a function that returns the next state
-    step_fn = aesara.function(
-        (q, step_size, inverse_mass_matrix, num_integration_steps), next_step
+    kernel = kernel_generator(
+        srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps
     )
 
-    # Compile a function that integrates the trajectory integrating several times
-    trajectory, _ = aesara.scan(
+    trajectory, updates = aesara.scan(
         fn=kernel,
         outputs_info=[
             {"initial": q},
             {"initial": potential_energy},
             {"initial": potential_energy_grad},
         ],
-        n_steps=1000,
+        n_steps=num_states,
+    )
+    trajectory_generator = aesara.function(
+        (q, step_size, inverse_mass_matrix, num_integration_steps),
+        trajectory,
+        updates=updates,
+    )
+
+    return trajectory_generator
+
+
+def test_hmc():
+    """Test the HMC kernel on a simple potential."""
+    srng = RandomStream(seed=59)
+    step_size = 0.003
+    num_integration_steps = 10
+    initial_position = np.array([1.0])
+    inverse_mass_matrix = np.array([1.0])
+
+    trajectory_generator = build_trajectory_generator(
+        srng, hmc.kernel, normal_logp, 10_000
     )
-    traj = aesara.function(
-        (q, step_size, inverse_mass_matrix, num_integration_steps), trajectory
+    positions, *_ = trajectory_generator(
+        initial_position, step_size, inverse_mass_matrix, num_integration_steps
     )
 
-    # Run
-    step_size = 0.01
-    num_integration_steps = 3
-    q = np.array([1.0, 1.0])
-    inverse_mass_matrix = np.array([1.0, 1.0])
-
-    # This works
-    samples = []
-    for _ in range(1_000):
-        q, *_ = step_fn(q, step_size, inverse_mass_matrix, num_integration_steps)
-        samples.append(q)
-    print(np.mean(np.array(samples)))
-
-    # This doesn't
-    print(traj(q, step_size, inverse_mass_matrix, num_integration_steps))
+    assert np.mean(positions[9000:], axis=0) == pytest.approx(3, 1e-1)