Skip to content

Commit

Permalink
fix geweke test
Browse files Browse the repository at this point in the history
  • Loading branch information
goldingn committed Nov 1, 2024
1 parent 900ba51 commit 99fdaf3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
28 changes: 22 additions & 6 deletions tests/testthat/helpers.R
Original file line number Diff line number Diff line change
Expand Up @@ -721,30 +721,46 @@ p_theta_greta <- function(
verbose = FALSE
)

# set up a progress bar and do a first increment
cli::cli_progress_bar("Geweke test iterations", total = niter)
cli::cli_progress_update()

# now loop through, sampling and updating x and returning theta
for (i in 2:niter) {

# update the progress bar
cli::cli_progress_update()

# sample x given theta
x <- p_x_bar_theta(theta[i - 1])

# put x in the data list
# replace x in the node
dag <- model$dag
target_name <- dag$tf_name(get_node(data))
x_array <- array(x, dim = c(1, dim(data)))
dag$tf_environment$data_list[[target_name]] <- x_array
x_node <- get_node(data)
x_node$value(as.matrix(x))

# put theta in the free state
# rewrite the log prob tf function, and the tf function for the posterior
# samples, now using this value of x (slow, but necessary in eager mode)
dag$tf_log_prob_function <- NULL
dag$define_tf_log_prob_function()
sampler <- attr(draws, "model_info")$samplers[[1]]
sampler$free_state <- as.matrix(theta[i - 1])
sampler$define_tf_evaluate_sample_batch()

# take anoteher sample
draws <- extra_samples(
draws,
n_samples = 1,
verbose = FALSE
)

# trace the sample
theta[i] <- tail(as.numeric(draws[[1]]), 1)

}

# kill the progress_bar
cli::cli_progress_done()

theta
}

Expand Down
4 changes: 2 additions & 2 deletions tests/testthat/test_posteriors_geweke.R
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,14 @@ test_that("samplers pass geweke tests", {
model = model,
data = x,
p_theta = p_theta,
p_x_bar_theta = p_x_bar_theta,
p_x_bar_theta = p_x_bar_theta
)

geweke_qq(geweke_hmc, title = "HMC Geweke test")

geweke_stat_hmc <- geweke_ks(geweke_hmc)

testthat::expect_gte(geweke_hmc_stat$p.value, 0.005)
testthat::expect_gte(geweke_stat_hmc$p.value, 0.005)

geweke_hmc_rwmh <- check_geweke(
sampler = rwmh(),
Expand Down

0 comments on commit 99fdaf3

Please sign in to comment.