Skip to content

Commit

Permalink
Decouple the adaptation from MCMC kernels
Browse files (browse the repository at this point in the history)
  • Loading branch information
rlouf committed Jul 7, 2022
1 parent ba55e6a commit ea21742
Show file tree
Hide file tree
Showing 2 changed files with 98 additions and 123 deletions.
208 changes: 93 additions & 115 deletions aehmc/window_adaptation.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import Callable, List, Tuple
from typing import List, Tuple

import aesara
import aesara.tensor as at
from aesara import config
from aesara.ifelse import ifelse
from aesara.tensor.shape import shape_tuple
from aesara.tensor.var import TensorVariable

Expand All @@ -11,7 +12,7 @@


def run(
kernel_factory,
kernel,
initial_state,
num_steps=1000,
*,
Expand All @@ -20,13 +21,12 @@ def run(
target_acceptance_rate=0.80
):

init, update, final = window_adaptation(
kernel_factory, is_mass_matrix_full, initial_step_size, target_acceptance_rate
init_adapt, update_adapt, final_adapt = window_adaptation(
num_steps, is_mass_matrix_full, initial_step_size, target_acceptance_rate
)

def one_step(
stage, # schedule
is_middle_window_end,
warmup_step,
q, # chain state
potential_energy,
potential_energy_grad,
Expand All @@ -35,61 +35,65 @@ def one_step(
log_step_size_avg,
gradient_avg,
mu,
inverse_mass_matrix, # inverse mass matrix
mean, # mass matrix adaptation state
m2,
sample_size,
step_size, # parameters
inverse_mass_matrix,
):
chain_state = (q, potential_energy, potential_energy_grad)

warmup_state = (
(step, log_step_size, log_step_size_avg, gradient_avg, mu),
inverse_mass_matrix,
(mean, m2, sample_size),
)
parameters = (step_size, inverse_mass_matrix)

# Advance the chain by one step
chain_state, inner_updates = kernel(*chain_state, *parameters)

(chain_state, warmup_state), inner_updates = update(
stage, is_middle_window_end, chain_state, warmup_state
# Update the warmup state and parameters
warmup_state, parameters = update_adapt(
warmup_step, warmup_state, parameters, chain_state
)

return (
*chain_state,
chain_state[0], # q
chain_state[1], # potential_energy
chain_state[2], # potential_energy_grad
*warmup_state[0],
warmup_state[1],
*warmup_state[2],
*warmup_state[1],
*parameters,
), inner_updates

schedule = build_schedule(num_steps)
stage = at.as_tensor([s[0] for s in schedule])
is_middle_window_end = at.as_tensor([s[1] for s in schedule])
(da_state, mm_state), parameters = init_adapt(initial_state)

da_state, inverse_mass_matrix, wc_state = init(initial_state)
warmup_steps = at.arange(0, num_steps)
state, updates = aesara.scan(
fn=one_step,
outputs_info=(*initial_state, *da_state, inverse_mass_matrix, *wc_state),
sequences=(stage, is_middle_window_end),
outputs_info=(*initial_state, *da_state, *mm_state, *parameters),
sequences=(warmup_steps,),
name="window_adaptation",
)

last_chain_state = (state[0][-1], state[1][-1], state[2][-1])
last_warmup_state = (
(state[3][-1], state[4][-1], state[5][-1], state[6][-1], state[7][-1]),
state[8][-1],
(state[9][-1], state[10][-1], state[11][-1]),
)

step_size, inverse_mass_matrix = final(last_warmup_state)
step_size = state[-2][-1]
inverse_mass_matrix = state[-1][-1]

return last_chain_state, (step_size, inverse_mass_matrix), updates


def window_adaptation(
kernel_factory: Callable[[TensorVariable], Callable],
num_steps: int,
is_mass_matrix_full: bool = False,
initial_step_size: TensorVariable = at.as_tensor(1.0, dtype=config.floatX),
target_acceptance_rate: TensorVariable = 0.80,
):
mm_init, mm_update, mm_final = covariance_adaptation(is_mass_matrix_full)
da_init, da_update = dual_averaging_adaptation(target_acceptance_rate)
schedule = build_schedule(num_steps)

schedule_stage = at.as_tensor([s[0] for s in schedule])
schedule_middle_window = at.as_tensor([s[1] for s in schedule])

def init(initial_chain_state: Tuple):
if initial_chain_state[0].ndim == 0:
Expand All @@ -98,114 +102,88 @@ def init(initial_chain_state: Tuple):
num_dims = shape_tuple(initial_chain_state[0])[0]
inverse_mass_matrix, mm_state = mm_init(num_dims)

step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init(
initial_step_size
)
da_state = da_init(initial_step_size)
step_size = at.exp(da_state[1])

return (
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
inverse_mass_matrix,
mm_state,
)
warmup_state = (da_state, mm_state)
parameters = (step_size, inverse_mass_matrix)
return warmup_state, parameters

def fast_update(p_accept, warmup_state, parameters):
da_state, mm_state = warmup_state
_, inverse_mass_matrix = parameters

new_da_state = da_update(p_accept, *da_state)
step_size = at.exp(new_da_state[1])

def fast_update(p_accept, da_state, inverse_mass_matrix, mm_state):
da_state = da_update(p_accept, *da_state)
return (da_state, inverse_mass_matrix, mm_state)
return (new_da_state, mm_state), (step_size, inverse_mass_matrix)

def slow_update(position, p_accept, da_state, inverse_mass_matrix, mm_state):
da_state = da_update(p_accept, *da_state)
mm_state = mm_update(position, mm_state)
return (da_state, inverse_mass_matrix, mm_state)
def slow_update(position, p_accept, warmup_state, parameters):
da_state, mm_state = warmup_state
_, inverse_mass_matrix = parameters

new_da_state = da_update(p_accept, *da_state)
new_mm_state = mm_update(position, mm_state)
step_size = at.exp(new_da_state[1])

return (new_da_state, new_mm_state), (step_size, inverse_mass_matrix)

def slow_final(warmup_state):
"""We recompute the inverse mass matrix and re-initialize the dual averaging scheme at the end of each 'slow window'."""
da_state, inverse_mass_matrix, mm_state = warmup_state
da_state, mm_state = warmup_state

new_inverse_mass_matrix = mm_final(mm_state)
_, new_mm_state = mm_init(inverse_mass_matrix.ndim)
inverse_mass_matrix = mm_final(mm_state)

if inverse_mass_matrix.ndim == 0:
num_dims = 0
else:
num_dims = shape_tuple(inverse_mass_matrix)[0]
_, new_mm_state = mm_init(num_dims)

step_size = at.exp(da_state[1])
step, logstepsize, logstepsize_avg, gradient_avg, mu = da_init(step_size)
return (
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
new_inverse_mass_matrix,
new_mm_state,
)
new_da_state = da_init(step_size)

def update(
stage: int, is_middle_window_end: bool, chain_state: Tuple, warmup_state: Tuple
):
da_state, inverse_mass_matrix, mm_state = warmup_state
warmup_state = (new_da_state, new_mm_state)
parameters = (step_size, inverse_mass_matrix)
return warmup_state, parameters

step_size = at.exp(da_state[1])
kernel = kernel_factory(inverse_mass_matrix)
(*chain_state, p_accept, _, _, _), updates = kernel(
*chain_state, step_size, inverse_mass_matrix
)
def final(
warmup_state: Tuple, parameters: Tuple
) -> Tuple[TensorVariable, TensorVariable]:
da_state, _ = warmup_state
_, inverse_mass_matrix = parameters
step_size = at.exp(da_state[2]) # return stepsize_avg at the end
return step_size, inverse_mass_matrix

def update(step: int, warmup_state: Tuple, parameters: Tuple, chain_state: Tuple):
position, _, _, p_accept, *_ = chain_state

warmup_state = where_warmup_state(
stage = schedule_stage[step]
warmup_state, parameters = where_warmup_state(
at.eq(stage, 0),
fast_update(p_accept, da_state, inverse_mass_matrix, mm_state),
slow_update(
chain_state[0], p_accept, da_state, inverse_mass_matrix, mm_state
),
fast_update(p_accept, warmup_state, parameters),
slow_update(position, p_accept, warmup_state, parameters),
)
warmup_state = where_warmup_state(
is_middle_window_end, slow_final(warmup_state), warmup_state

is_middle_window_end = schedule_middle_window[step]
warmup_state, parameters = where_warmup_state(
is_middle_window_end, slow_final(warmup_state), (warmup_state, parameters)
)

return (chain_state, warmup_state), updates
is_last_step = at.eq(step, num_steps - 1)
parameters = ifelse(is_last_step, final(warmup_state, parameters), parameters)

def final(warmup_state: Tuple) -> Tuple[TensorVariable, TensorVariable]:
da_state, inverse_mass_matrix, mm_state = warmup_state
step_size = at.exp(da_state[2]) # return stepsize_avg at the end
return step_size, inverse_mass_matrix
return warmup_state, parameters

def where_warmup_state(do_pick_left, left_warmup_state, right_warmup_state):
(
left_step,
left_logstepsize,
left_logstepsize_avg,
left_gradient_avg,
left_mu,
) = left_warmup_state[0]
(
right_step,
right_logstepsize,
right_logstepsize_avg,
right_gradient_avg,
right_mu,
) = right_warmup_state[0]

step = at.where(do_pick_left, left_step, right_step)
logstepsize = at.where(do_pick_left, left_logstepsize, right_logstepsize)
logstepsize_avg = at.where(
do_pick_left, left_logstepsize_avg, right_logstepsize_avg
)
gradient_avg = at.where(do_pick_left, left_gradient_avg, right_gradient_avg)
mu = at.where(do_pick_left, left_mu, right_mu)
(left_da_state, left_mm_state), left_params = left_warmup_state
(right_da_state, right_mm_state), right_params = right_warmup_state

left_inverse_mass_matrix = left_warmup_state[1]
right_inverse_mass_matrix = right_warmup_state[1]
inverse_mass_matrix = at.where(
do_pick_left, left_inverse_mass_matrix, right_inverse_mass_matrix
)

right_mean, right_m2, right_sample_size = right_warmup_state[2]
left_mean, left_m2, left_sample_size = left_warmup_state[2]
mean = at.where(do_pick_left, left_mean, right_mean)
m2 = at.where(do_pick_left, left_m2, right_m2)
sample_size = at.where(do_pick_left, left_sample_size, right_sample_size)
da_state = ifelse(do_pick_left, left_da_state, right_da_state)
mm_state = ifelse(do_pick_left, left_mm_state, right_mm_state)
params = ifelse(do_pick_left, left_params, right_params)

return (
(step, logstepsize, logstepsize_avg, gradient_avg, mu),
inverse_mass_matrix,
(
mean,
m2,
sample_size,
),
)
return (da_state, mm_state), params

return init, update, final

Expand Down
13 changes: 5 additions & 8 deletions tests/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,12 @@ def logprob_fn(y: TensorVariable):
logprob = joint_logprob({Y_rv: y})
return logprob

def kernel_factory(inverse_mass_matrix: TensorVariable):
return nuts.new_kernel(srng, logprob_fn)

y_vv = Y_rv.clone()
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
kernel_factory, initial_state, num_steps=1000
kernel, initial_state, num_steps=1000
)

# Compile the warmup and execute to get a value for the step size and the
Expand All @@ -42,6 +40,7 @@ def kernel_factory(inverse_mass_matrix: TensorVariable):

assert final_state[0] != 3.0 # the chain has moved
assert np.ndim(step_size) == 0 # scalar step size
assert step_size != 1.0 # step size changed
assert step_size > 0.1 and step_size < 2 # stable range for the step size
assert np.ndim(inverse_mass_matrix) == 0 # scalar mass matrix
assert inverse_mass_matrix == pytest.approx(4, rel=1.0)
Expand All @@ -61,14 +60,12 @@ def logprob_fn(y: TensorVariable):
logprob = joint_logprob({Y_rv: y})
return logprob

def kernel_factory(inverse_mass_matrix: TensorVariable):
return nuts.new_kernel(srng, logprob_fn)

y_vv = Y_rv.clone()
kernel = nuts.new_kernel(srng, logprob_fn)
initial_state = nuts.new_state(y_vv, logprob_fn)

state, (step_size, inverse_mass_matrix), updates = window_adaptation.run(
kernel_factory, initial_state, num_steps=1000
kernel, initial_state, num_steps=1000
)

# Compile the warmup and execute to get a value for the step size and the
Expand Down

0 comments on commit ea21742

Please sign in to comment.