diff --git a/aehmc/hmc.py b/aehmc/hmc.py index c42e701..326bed5 100644 --- a/aehmc/hmc.py +++ b/aehmc/hmc.py @@ -13,7 +13,7 @@ def kernel( srng: RandomStream, - potential_fn: Callable[[TensorVariable], TensorVariable], + logprob_fn: TensorVariable, step_size: TensorVariable, inverse_mass_matrix: TensorVariable, num_integration_steps: TensorVariable, @@ -25,8 +25,9 @@ def kernel( ---------- srng RandomStream object. - potential_fn - A function that returns the potential energy of a chain at a given position. + logprob_fn + A function that returns the value of the log-probability density + function of a chain at a given position. step_size The step size used in the symplectic integrator inverse_mass_matrix @@ -46,6 +47,9 @@ def kernel( """ + def potential_fn(x): + return -logprob_fn(x) + momentum_generator, kinetic_energy_fn, _ = metrics.gaussian_metric( inverse_mass_matrix ) diff --git a/tests/test_hmc.py b/tests/test_hmc.py index e792fd2..82640b9 100644 --- a/tests/test_hmc.py +++ b/tests/test_hmc.py @@ -11,7 +11,7 @@ import aehmc.nuts as nuts -def normal_logp(q: TensorVariable): +def normal_potential(q: TensorVariable): return aet.sum(aet.square(q - 3.0)) @@ -29,8 +29,11 @@ def build_hmc_trajectory_generator( inverse_mass_matrix = aet.vector("inverse_mass_matrix") num_integration_steps = aet.scalar("num_integration_steps", dtype="int32") + def logprob_fn(x): + return -potential_fn(x) + kernel = kernel_generator( - srng, potential_fn, step_size, inverse_mass_matrix, num_integration_steps + srng, logprob_fn, step_size, inverse_mass_matrix, num_integration_steps ) trajectory, updates = aesara.scan( @@ -60,7 +63,7 @@ def test_hmc(): inverse_mass_matrix = np.array([1.0]) trajectory_generator = build_hmc_trajectory_generator( - srng, hmc.kernel, normal_logp, 50_000 + srng, hmc.kernel, normal_potential, 50_000 ) positions, *_ = trajectory_generator( initial_position, step_size, inverse_mass_matrix, num_integration_steps @@ -74,13 +77,13 @@ def test_nuts(): srng = RandomStream(seed=59) q = aet.vector("q") - potential_energy = normal_logp(q) + potential_energy = normal_potential(q) potential_energy_grad = aesara.grad(potential_energy, wrt=q) step_size = aet.scalar("step_size", dtype="float64") inverse_mass_matrix = aet.vector("inverse_mass_matrix", dtype="float64") - kernel = nuts.kernel(srng, normal_logp, step_size, inverse_mass_matrix) + kernel = nuts.kernel(srng, normal_potential, step_size, inverse_mass_matrix) result, updates = kernel(q, potential_energy, potential_energy_grad) trajectory_generator = aesara.function(