From b4df84aef248497699b9eb38c5c2bcf1ee4d7f5e Mon Sep 17 00:00:00 2001 From: njtierney Date: Wed, 18 Dec 2024 09:44:23 +1000 Subject: [PATCH 1/3] add warmup information to sampler class and pass to print method --- R/greta_mcmc_list.R | 2 ++ R/inference.R | 8 ++++++-- R/inference_class.R | 2 ++ R/utils.R | 5 +++++ 4 files changed, 15 insertions(+), 2 deletions(-) diff --git a/R/greta_mcmc_list.R b/R/greta_mcmc_list.R index 5b049147..88980c87 100644 --- a/R/greta_mcmc_list.R +++ b/R/greta_mcmc_list.R @@ -62,6 +62,7 @@ window.greta_mcmc_list <- function(x, start, end, thin, ...) { #' @export print.greta_mcmc_list <- function(x, ..., n = 5){ + n_warmup <- n_warmup(x) n_chain <- coda::nchain(x) n_iter <- coda::niter(x) n_thin <- coda::thin(x) @@ -69,6 +70,7 @@ print.greta_mcmc_list <- function(x, ..., n = 5){ cli::cli_bullets( c( "*" = "Iterations = {n_iter}", + "*" = "Warmup = {n_warmup}", "*" = "Chains = {n_chain}", "*" = "Thinning = {n_thin}" ) diff --git a/R/inference.R b/R/inference.R index cb5e9f41..472c3de5 100644 --- a/R/inference.R +++ b/R/inference.R @@ -394,7 +394,10 @@ run_samplers <- function(samplers, greta_stash$trace_log_files <- trace_log_files greta_stash$percentage_log_files <- percentage_log_files greta_stash$progress_bar_log_files <- progress_bar_log_files - greta_stash$mcmc_info <- list(n_samples = n_samples) + greta_stash$mcmc_info <- list( + n_samples = n_samples, + warmup = warmup + ) } if (plan_is$parallel) { @@ -514,7 +517,8 @@ stashed_samples <- function() { model_info <- list( raw_draws = free_state_draws, samplers = samplers, - model = samplers[[1]]$model + model = samplers[[1]]$model, + warmup = samplers[[1]]$warmup ) values_draws <- as_greta_mcmc_list(values_draws, model_info) diff --git a/R/inference_class.R b/R/inference_class.R index 38591ff0..bb560a24 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -244,6 +244,7 @@ sampler <- R6Class( n_chains = 1, numerical_rejections = 0, thin = 1, + warmup = 1, # tuning information mean_accept_stat = 0.5, @@ -336,6 +337,7 @@ sampler <- R6Class( one_by_one, plan_is, n_cores, float_type, trace_batch_size, from_scratch = TRUE) { + self$warmup <- warmup self$thin <- thin dag <- self$model$dag diff --git a/R/utils.R b/R/utils.R index 0c10987a..34b92ab8 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1174,3 +1174,8 @@ are_initials <- function(x){ FUN.VALUE = logical(1) ) } + +n_warmup <- function(x){ + x_info <- attr(x, "model_info") + x_info$warmup +} From 4ca534b85dd9e91a57a122c468843739790920ea Mon Sep 17 00:00:00 2001 From: njtierney Date: Wed, 18 Dec 2024 13:18:29 +1000 Subject: [PATCH 2/3] add warmup info to relevant print tests --- tests/testthat/_snaps/greta_mcmc_list_class.md | 8 ++++++++ tests/testthat/_snaps/print_calculate.md | 1 + 2 files changed, 9 insertions(+) diff --git a/tests/testthat/_snaps/greta_mcmc_list_class.md b/tests/testthat/_snaps/greta_mcmc_list_class.md index d59586c3..baa06ac5 100644 --- a/tests/testthat/_snaps/greta_mcmc_list_class.md +++ b/tests/testthat/_snaps/greta_mcmc_list_class.md @@ -6,6 +6,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 10 + * Warmup = 10 * Chains = 4 * Thinning = 1 @@ -36,6 +37,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 20 + * Warmup = 20 * Chains = 4 * Thinning = 1 @@ -66,6 +68,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 20 + * Warmup = 20 * Chains = 4 * Thinning = 1 @@ -109,6 +112,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 20 + * Warmup = 20 * Chains = 4 * Thinning = 1 @@ -153,6 +157,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 20 + * Warmup = 20 * Chains = 4 * Thinning = 1 @@ -196,6 +201,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 2 + * Warmup = 2 * Chains = 4 * Thinning = 1 @@ -221,6 +227,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 2 + * Warmup = 2 * Chains = 4 * Thinning = 1 @@ -247,6 +254,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 2 + * Warmup = 2 * Chains = 4 * Thinning = 1 diff --git a/tests/testthat/_snaps/print_calculate.md b/tests/testthat/_snaps/print_calculate.md index 99b21ce6..201d879b 100644 --- a/tests/testthat/_snaps/print_calculate.md +++ b/tests/testthat/_snaps/print_calculate.md @@ -28,6 +28,7 @@ -- MCMC draws from greta ------------------------------------------------------- * Iterations = 10 + * Warmup = 10 * Chains = 4 * Thinning = 1 From 450dceeb64a1975acbaea16b82c4f1b0194ffc89 Mon Sep 17 00:00:00 2001 From: njtierney Date: Thu, 19 Dec 2024 10:14:46 +1000 Subject: [PATCH 3/3] bump NEWS --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 4c8d954d..598d19fb 100644 --- a/NEWS.md +++ b/NEWS.md @@ -3,6 +3,7 @@ ## Changes - `log.greta_array()` function warns if user uses the `base` arg, as it was unused, (#597). +- Add warmup information to MCMC print method (#652, resolved by #755). # greta 0.5.0