diff --git a/aehmc/window_adaptation.py b/aehmc/window_adaptation.py index b626ab0..de5fcbd 100644 --- a/aehmc/window_adaptation.py +++ b/aehmc/window_adaptation.py @@ -1,8 +1,9 @@ -from typing import Callable, List, Tuple +from typing import List, Tuple import aesara import aesara.tensor as at from aesara import config +from aesara.ifelse import ifelse from aesara.tensor.shape import shape_tuple from aesara.tensor.var import TensorVariable @@ -11,7 +12,7 @@ def run( - kernel_factory, + kernel, initial_state, num_steps=1000, *, @@ -20,13 +21,12 @@ def run( target_acceptance_rate=0.80 ): - init, update, final = window_adaptation( - kernel_factory, is_mass_matrix_full, initial_step_size, target_acceptance_rate + init_adapt, update_adapt, final_adapt = window_adaptation( + num_steps, is_mass_matrix_full, initial_step_size, target_acceptance_rate ) def one_step( - stage, # schedule - is_middle_window_end, + warmup_step, q, # chain state potential_energy, potential_energy_grad, @@ -35,61 +35,65 @@ def one_step( log_step_size_avg, gradient_avg, mu, - inverse_mass_matrix, # inverse mass matrix mean, # mass matrix adaptation state m2, sample_size, + step_size, # parameters + inverse_mass_matrix, ): chain_state = (q, potential_energy, potential_energy_grad) - warmup_state = ( (step, log_step_size, log_step_size_avg, gradient_avg, mu), - inverse_mass_matrix, (mean, m2, sample_size), ) + parameters = (step_size, inverse_mass_matrix) + + # Advance the chain by one step + chain_state, inner_updates = kernel(*chain_state, *parameters) - (chain_state, warmup_state), inner_updates = update( - stage, is_middle_window_end, chain_state, warmup_state + # Update the warmup state and parameters + warmup_state, parameters = update_adapt( + warmup_step, warmup_state, parameters, chain_state ) return ( - *chain_state, + chain_state[0], # q + chain_state[1], # potential_energy + chain_state[2], # potential_energy_grad *warmup_state[0], - warmup_state[1], - *warmup_state[2], + *warmup_state[1], + *parameters, ), inner_updates - schedule = build_schedule(num_steps) - stage = at.as_tensor([s[0] for s in schedule]) - is_middle_window_end = at.as_tensor([s[1] for s in schedule]) + (da_state, mm_state), parameters = init_adapt(initial_state) - da_state, inverse_mass_matrix, wc_state = init(initial_state) + warmup_steps = at.arange(0, num_steps) state, updates = aesara.scan( fn=one_step, - outputs_info=(*initial_state, *da_state, inverse_mass_matrix, *wc_state), - sequences=(stage, is_middle_window_end), + outputs_info=(*initial_state, *da_state, *mm_state, *parameters), + sequences=(warmup_steps,), + name="window_adaptation", ) last_chain_state = (state[0][-1], state[1][-1], state[2][-1]) - last_warmup_state = ( - (state[3][-1], state[4][-1], state[5][-1], state[6][-1], state[7][-1]), - state[8][-1], - (state[9][-1], state[10][-1], state[11][-1]), - ) - - step_size, inverse_mass_matrix = final(last_warmup_state) + step_size = state[-2][-1] + inverse_mass_matrix = state[-1][-1] return last_chain_state, (step_size, inverse_mass_matrix), updates def window_adaptation( - kernel_factory: Callable[[TensorVariable], Callable], + num_steps: int, is_mass_matrix_full: bool = False, initial_step_size: TensorVariable = at.as_tensor(1.0, dtype=config.floatX), target_acceptance_rate: TensorVariable = 0.80, ): mm_init, mm_update, mm_final = covariance_adaptation(is_mass_matrix_full) da_init, da_update = dual_averaging_adaptation(target_acceptance_rate) + schedule = build_schedule(num_steps) + + schedule_stage = at.as_tensor([s[0] for s in schedule]) + schedule_middle_window = at.as_tensor([s[1] for s in schedule]) def init(initial_chain_state: Tuple): if initial_chain_state[0].ndim == 0: @@ -98,114 +102,88 @@ def init(initial_chain_state: Tuple): num_dims = shape_tuple(initial_chain_state[0])[0] inverse_mass_matrix, mm_state = mm_init(num_dims) - step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init( - initial_step_size - ) + da_state = da_init(initial_step_size) + step_size = at.exp(da_state[1]) - return ( - (step, logstepsize, logstepsize_avg, gradient_avg, mu), - inverse_mass_matrix, - mm_state, - ) + warmup_state = (da_state, mm_state) + parameters = (step_size, inverse_mass_matrix) + return warmup_state, parameters + + def fast_update(p_accept, warmup_state, parameters): + da_state, mm_state = warmup_state + _, inverse_mass_matrix = parameters + + new_da_state = da_update(p_accept, *da_state) + step_size = at.exp(new_da_state[1]) - def fast_update(p_accept, da_state, inverse_mass_matrix, mm_state): - da_state = da_update(p_accept, *da_state) - return (da_state, inverse_mass_matrix, mm_state) + return (new_da_state, mm_state), (step_size, inverse_mass_matrix) - def slow_update(position, p_accept, da_state, inverse_mass_matrix, mm_state): - da_state = da_update(p_accept, *da_state) - mm_state = mm_update(position, mm_state) - return (da_state, inverse_mass_matrix, mm_state) + def slow_update(position, p_accept, warmup_state, parameters): + da_state, mm_state = warmup_state + _, inverse_mass_matrix = parameters + + new_da_state = da_update(p_accept, *da_state) + new_mm_state = mm_update(position, mm_state) + step_size = at.exp(new_da_state[1]) + + return (new_da_state, new_mm_state), (step_size, inverse_mass_matrix) def slow_final(warmup_state): """We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'.""" - da_state, inverse_mass_matrix, mm_state = warmup_state + da_state, mm_state = warmup_state - new_inverse_mass_matrix = mm_final(mm_state) - _, new_mm_state = mm_init(inverse_mass_matrix.ndim) + inverse_mass_matrix = mm_final(mm_state) + + if inverse_mass_matrix.ndim == 0: + num_dims = 0 + else: + num_dims = shape_tuple(inverse_mass_matrix)[0] + _, new_mm_state = mm_init(num_dims) step_size = at.exp(da_state[1]) - step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init(step_size) - return ( - (step, logstepsize, logstepsize_avg, gradient_avg, mu), - new_inverse_mass_matrix, - new_mm_state, - ) + new_da_state = da_init(step_size) - def update( - stage: int, is_middle_window_end: bool, chain_state: Tuple, warmup_state: Tuple - ): - da_state, inverse_mass_matrix, mm_state = warmup_state + warmup_state = (new_da_state, new_mm_state) + parameters = (step_size, inverse_mass_matrix) + return warmup_state, parameters - step_size = at.exp(da_state[1]) - kernel = kernel_factory(inverse_mass_matrix) - (*chain_state, p_accept, _, _, _), updates = kernel( - *chain_state, step_size, inverse_mass_matrix - ) + def final( + warmup_state: Tuple, parameters: Tuple + ) -> Tuple[TensorVariable, TensorVariable]: + da_state, _ = warmup_state + _, inverse_mass_matrix = parameters + step_size = at.exp(da_state[2]) # return stepsize_avg at the end + return step_size, inverse_mass_matrix + + def update(step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Tuple): + position, _, _, p_accept, *_ = chain_state - warmup_state = where_warmup_state( + stage = schedule_stage[step] + warmup_state, parameters = where_warmup_state( at.eq(stage, 0), - fast_update(p_accept, da_state, inverse_mass_matrix, mm_state), - slow_update( - chain_state[0], p_accept, da_state, inverse_mass_matrix, mm_state - ), + fast_update(p_accept, warmup_state, parameters), + slow_update(position, p_accept, warmup_state, parameters), ) - warmup_state = where_warmup_state( - is_middle_window_end, slow_final(warmup_state), warmup_state + + is_middle_window_end = schedule_middle_window[step] + warmup_state, parameters = where_warmup_state( + is_middle_window_end, slow_final(warmup_state), (warmup_state, parameters) ) - return (chain_state, warmup_state), updates + is_last_step = at.eq(step, num_steps - 1) + parameters = ifelse(is_last_step, final(warmup_state, parameters), parameters) - def final(warmup_state: Tuple) -> Tuple[TensorVariable, TensorVariable]: - da_state, inverse_mass_matrix, mm_state = warmup_state - step_size = at.exp(da_state[2]) # return stepsize_avg at the end - return step_size, inverse_mass_matrix + return warmup_state, parameters def where_warmup_state(do_pick_left, left_warmup_state, right_warmup_state): - ( - left_step, - left_logstepsize, - left_logstepsize_avg, - left_gradient_avg, - left_mu, - ) = left_warmup_state[0] - ( - right_step, - right_logstepsize, - right_logstepsize_avg, - right_gradient_avg, - right_mu, - ) = right_warmup_state[0] - - step = at.where(do_pick_left, left_step, right_step) - logstepsize = at.where(do_pick_left, left_logstepsize, right_logstepsize) - logstepsize_avg = at.where( - do_pick_left, left_logstepsize_avg, right_logstepsize_avg - ) - gradient_avg = at.where(do_pick_left, left_gradient_avg, right_gradient_avg) - mu = at.where(do_pick_left, left_mu, right_mu) + (left_da_state, left_mm_state), left_params = left_warmup_state + (right_da_state, right_mm_state), right_params = right_warmup_state - left_inverse_mass_matrix = left_warmup_state[1] - right_inverse_mass_matrix = right_warmup_state[1] - inverse_mass_matrix = at.where( - do_pick_left, left_inverse_mass_matrix, right_inverse_mass_matrix - ) - - right_mean, right_m2, right_sample_size = right_warmup_state[2] - left_mean, left_m2, left_sample_size = left_warmup_state[2] - mean = at.where(do_pick_left, left_mean, right_mean) - m2 = at.where(do_pick_left, left_m2, right_m2) - sample_size = at.where(do_pick_left, left_sample_size, right_sample_size) + da_state = ifelse(do_pick_left, left_da_state, right_da_state) + mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state) + params = ifelse(do_pick_left, left_params, right_params) - return ( - (step, logstepsize, logstepsize_avg, gradient_avg, mu), - inverse_mass_matrix, - ( - mean, - m2, - sample_size, - ), - ) + return (da_state, mm_state), params return init, update, final diff --git a/tests/test_hmc.py b/tests/test_hmc.py index 4284954..2758497 100644 --- a/tests/test_hmc.py +++ b/tests/test_hmc.py @@ -20,14 +20,12 @@ def logprob_fn(y: TensorVariable): logprob = joint_logprob({Y_rv: y}) return logprob - def kernel_factory(inverse_mass_matrix: TensorVariable): - return nuts.new_kernel(srng, logprob_fn) - y_vv = Y_rv.clone() + kernel = nuts.new_kernel(srng, logprob_fn) initial_state = nuts.new_state(y_vv, logprob_fn) state, (step_size, inverse_mass_matrix), updates = window_adaptation.run( - kernel_factory, initial_state, num_steps=1000 + kernel, initial_state, num_steps=1000 ) # Compile the warmup and execute to get a value for the step size and the @@ -42,6 +40,7 @@ def kernel_factory(inverse_mass_matrix: TensorVariable): assert final_state[0] != 3.0 # the chain has moved assert np.ndim(step_size) == 0 # scalar step size + assert step_size != 1.0 # step size changed assert step_size > 0.1 and step_size < 2 # stable range for the step size assert np.ndim(inverse_mass_matrix) == 0 # scalar mass matrix assert inverse_mass_matrix == pytest.approx(4, rel=1.0) @@ -61,14 +60,12 @@ def logprob_fn(y: TensorVariable): logprob = joint_logprob({Y_rv: y}) return logprob - def kernel_factory(inverse_mass_matrix: TensorVariable): - return nuts.new_kernel(srng, logprob_fn) - y_vv = Y_rv.clone() + kernel = nuts.new_kernel(srng, logprob_fn) initial_state = nuts.new_state(y_vv, logprob_fn) state, (step_size, inverse_mass_matrix), updates = window_adaptation.run( - kernel_factory, initial_state, num_steps=1000 + kernel, initial_state, num_steps=1000 ) # Compile the warmup and execute to get a value for the step size and the