Skip to content

Commit

Permalink
Merge pull request #327 from n-kall/doc-updates
Browse files Browse the repository at this point in the history
Pareto diagnostic defaults and doc updates
  • Loading branch information
paul-buerkner authored Jan 9, 2024
2 parents c79a40d + 9a59fb5 commit 627d2e8
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 58 deletions.
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -223,10 +223,16 @@ S3method(order_draws,draws_list)
S3method(order_draws,draws_matrix)
S3method(order_draws,draws_rvars)
S3method(order_draws,rvar)
S3method(pareto_convergence_rate,default)
S3method(pareto_convergence_rate,rvar)
S3method(pareto_diags,default)
S3method(pareto_diags,rvar)
S3method(pareto_khat,default)
S3method(pareto_khat,rvar)
S3method(pareto_khat_threshold,default)
S3method(pareto_khat_threshold,rvar)
S3method(pareto_min_ss,default)
S3method(pareto_min_ss,rvar)
S3method(pareto_smooth,default)
S3method(pareto_smooth,rvar)
S3method(pillar_shaft,rvar)
Expand Down Expand Up @@ -455,8 +461,11 @@ export(ndraws)
export(niterations)
export(nvariables)
export(order_draws)
export(pareto_convergence_rate)
export(pareto_diags)
export(pareto_khat)
export(pareto_khat_threshold)
export(pareto_min_ss)
export(pareto_smooth)
export(quantile2)
export(r_scale)
Expand Down
101 changes: 76 additions & 25 deletions R/pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,24 +181,23 @@ pareto_diags.rvar <- function(x, ...) {
#'
#' @template args-pareto
#' @param return_k (logical) Should the Pareto khat be included in
#' output? If `TRUE`, output will be a list containing of smoothed
#' draws and diagnostics. Default is `TRUE`.
#' output? If `TRUE`, output will be a list containing smoothed
#' draws and diagnostics, otherwise it will be a numeric of the
#' smoothed draws. Default is `FALSE`.
#' @param extra_diags (logical) Should extra Pareto khat diagnostics
#' be included in output? If `TRUE`, `min_ss`, `khat_threshold` and
#' `convergence_rate` for the estimated k value will be
#' returned. Default is `FALSE`.
#' @template args-methods-dots
#' @template ref-vehtari-paretosmooth-2022
#' @return Either a vector `x` of smoothed values or a named list
#' containing the vector `x` and a named list `diagnostics` containing Pareto smoothing
#' diagnostics:
#' * `khat`: estimated Pareto k shape parameter, and
#' optionally
#' * `min_ss`: minimum sample size for reliable Pareto
#' smoothed estimate
#' * `khat_threshold`: khat-threshold for reliable
#' containing the vector `x` and a named list `diagnostics`
#' containing Pareto smoothing diagnostics:
#' * `khat`: estimated Pareto k shape parameter, and optionally
#' * `min_ss`: minimum sample size for reliable Pareto smoothed
#'   estimate
#' * `khat_threshold`: khat-threshold for reliable Pareto smoothed
#'   estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed
#'   estimates
#' * `convergence_rate`: Relative convergence rate for Pareto smoothed estimates
#'
#' @seealso [`pareto_khat`] for only calculating khat, and
#' [`pareto_diags`] for additional diagnostics.
Expand All @@ -213,7 +212,7 @@ pareto_smooth <- function(x, ...) UseMethod("pareto_smooth")

#' @rdname pareto_smooth
#' @export
pareto_smooth.rvar <- function(x, return_k = TRUE, extra_diags = FALSE, ...) {
pareto_smooth.rvar <- function(x, return_k = FALSE, extra_diags = FALSE, ...) {

if (extra_diags) {
return_k <- TRUE
Expand Down Expand Up @@ -253,9 +252,9 @@ pareto_smooth.default <- function(x,
tail = c("both", "right", "left"),
r_eff = 1,
ndraws_tail = NULL,
return_k = TRUE,
return_k = FALSE,
extra_diags = FALSE,
verbose = FALSE,
verbose = TRUE,
are_log_weights = FALSE,
...) {

Expand Down Expand Up @@ -370,6 +369,65 @@ pareto_smooth.default <- function(x,
return(out)
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold <- function(x, ...)
  # S3 generic: the class of `x` selects the method that is run.
  UseMethod("pareto_khat_threshold")

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.default <- function(x, ...) {
  # Only the number of elements of `x` is used here; the values themselves
  # never enter the computation. `...` is accepted but not used.
  n_draws <- length(x)
  threshold <- ps_khat_threshold(n_draws)
  # Returned as a length-one vector named `khat_threshold`.
  c(khat_threshold = threshold)
}

#' @rdname pareto_diags
#' @export
pareto_khat_threshold.rvar <- function(x, ...) {
  # The sample size comes from `ndraws()` rather than `length()`. `...` is
  # accepted but not used.
  draws_count <- ndraws(x)
  threshold <- ps_khat_threshold(draws_count)
  # Returned as a length-one vector named `khat_threshold`.
  c(khat_threshold = threshold)
}

#' @rdname pareto_diags
#' @export
pareto_min_ss <- function(x, ...)
  # S3 generic: the class of `x` selects the method that is run.
  UseMethod("pareto_min_ss")

#' @rdname pareto_diags
#' @export
pareto_min_ss.default <- function(x, ...) {
  # Forward `...` so that arguments meant for the khat estimate are honoured
  # instead of being silently dropped.
  # Use the full element name `khat`: `$k` only works through partial
  # matching and would yield NULL as soon as the list carries a second
  # element whose name starts with "k".
  khat <- pareto_khat(x, ...)$khat
  # Returned as a vector named `min_ss`.
  c(min_ss = ps_min_ss(khat))
}

#' @rdname pareto_diags
#' @export
pareto_min_ss.rvar <- function(x, ...) {
  # Forward `...` so that arguments meant for the khat estimate are honoured
  # instead of being silently dropped.
  # Use the full element name `khat`: `$k` only works through partial
  # matching and would yield NULL as soon as the list carries a second
  # element whose name starts with "k".
  khat <- pareto_khat(x, ...)$khat
  # Returned as a vector named `min_ss`.
  c(min_ss = ps_min_ss(khat))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate <- function(x, ...)
  # S3 generic: the class of `x` selects the method that is run.
  UseMethod("pareto_convergence_rate")

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.default <- function(x, ...) {
  # Forward `...` so that arguments meant for the khat estimate are honoured
  # instead of being silently dropped.
  khat <- pareto_khat(x, ...)$khat
  # The sample size passed on is the number of elements of `x`.
  n_draws <- length(x)
  # Returned as a vector named `convergence_rate`.
  c(convergence_rate = ps_convergence_rate(khat, n_draws))
}

#' @rdname pareto_diags
#' @export
pareto_convergence_rate.rvar <- function(x, ...) {
  # Extract the `khat` element before handing it on, as the sibling methods
  # in this file do; previously the whole return value of `pareto_khat()` was
  # passed to `ps_convergence_rate()`.
  # NOTE(review): assumes `pareto_khat()` on an rvar returns a list with a
  # `khat` element, as `pareto_min_ss.rvar` also assumes -- confirm.
  # `...` is forwarded so that arguments meant for the khat estimate are
  # honoured instead of being silently dropped.
  khat <- pareto_khat(x, ...)$khat
  # The sample size comes from `ndraws()` rather than `length()`.
  draws_count <- ndraws(x)
  # Returned as a vector named `convergence_rate`.
  c(convergence_rate = ps_convergence_rate(khat, draws_count))
}


#' Pareto smooth tail
#' internal function to pareto smooth the tail of a vector
#' @noRd
Expand Down Expand Up @@ -493,7 +551,6 @@ ps_min_ss <- function(k, ...) {
out
}


#' Pareto-smoothing k-hat threshold
#'
#' Given sample size S computes khat threshold for reliable Pareto
Expand Down Expand Up @@ -561,26 +618,20 @@ pareto_k_diagmsg <- function(diags, are_weights = FALSE, ...) {
if (!are_weights) {

if (khat > 1) {
msg <- paste0(msg, "All estimates are unreliable. If the distribution of draws is bounded,\n",
"further draws may improve the estimates, but it is not possible to predict\n",
"whether any feasible sample size is sufficient.")
msg <- paste0(msg, " Mean does not exist, making empirical mean estimate of the draws not applicable.")
} else {
if (khat > khat_threshold) {
msg <- paste0(msg, "S is too small, and sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
} else {
msg <- paste0(msg, "To halve the RMSE, approximately ", round(2^(2 / convergence_rate), 1), " times bigger S is needed.\n")
msg <- paste0(msg, " Sample size is too small, for given Pareto k-hat. Sample size larger than ", round(min_ss, 0), " is needed for reliable results.\n")
}
if (khat > 0.7) {
msg <- paste0(msg, "Bias dominates RMSE, and the variance based MCSE is underestimated.\n")
msg <- paste0(msg, " Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n")
}
}

} else {

if (khat > khat_threshold || khat > 0.7) {
msg <- paste0(msg, "Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
msg <- paste0(msg, " Pareto khat for weights is high (", round(khat, 1) ,"). This indicates a single or few weights dominate.\n", "Inference based on weighted draws will be unreliable.\n")
}
}
message(msg)
message("Pareto k-hat = ", round(khat, 2), ".", msg)
invisible(diags)
}
27 changes: 27 additions & 0 deletions man/pareto_diags.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

27 changes: 12 additions & 15 deletions man/pareto_smooth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 16 additions & 18 deletions tests/testthat/test-pareto_smooth.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,24 @@ test_that("pareto_khat diagnostics messages are as expected", {
)

expect_message(pareto_k_diagmsg(diags),
paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.'))
paste0("Pareto k-hat = 0.5.\n"))

diags$khat <- 0.6

expect_message(pareto_k_diagmsg(diags),
paste0('S is too small, and sample size larger than 10 is needed for reliable results.\n'))
paste0("Pareto k-hat = 0.6. Sample size is too small, for given Pareto k-hat. Sample size larger than 10 is needed for reliable results.\n"))

diags$khat <- 0.71
diags$khat_threshold <- 0.8

expect_message(pareto_k_diagmsg(diags),
paste0('To halve the RMSE, approximately 4.1 times bigger S is needed.\n', 'Bias dominates RMSE, and the variance based MCSE is underestimated.\n'))
paste0("Pareto k-hat = 0.71. Bias dominates when k-hat > 0.7, making empirical mean estimate of the Pareto-smoothed draws unreliable.\n"))


diags$khat <- 1.1

expect_message(pareto_k_diagmsg(diags),
paste0('All estimates are unreliable. If the distribution of draws is bounded,\n',
'further draws may improve the estimates, but it is not possible to predict\n',
'whether any feasible sample size is sufficient.'))
paste0("Pareto k-hat = 1.1. Mean does not exist, making empirical mean estimate of the draws not applicable.\n"))

})

Expand Down Expand Up @@ -131,8 +129,8 @@ test_that("pareto_khat functions work with matrix with chains", {
expect_equal(pareto_khat(tau_chains, ndraws_tail = 20),
pareto_khat(tau_nochains, ndraws_tail = 20))

ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE)

expect_equal(as.numeric(ps_chains$x), as.numeric(ps_nochains$x))

Expand All @@ -159,22 +157,22 @@ test_that("pareto_khat functions work with rvars with and without chains", {
expect_equal(pareto_diags(tau_rvar_chains, ndraws_tail = 20),
pareto_diags(tau_rvar_nochains, ndraws_tail = 20))

ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, return_k = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, return_k = TRUE)

ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, return_k = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, return_k = TRUE)

expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE))

expect_equal(ps_rvar_nochains$x, rvar(ps_nochains$x))


ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE)
ps_chains <- pareto_smooth(tau_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)
ps_rvar_chains <- pareto_smooth(tau_rvar_chains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)

ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE)
ps_nochains <- pareto_smooth(tau_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)
ps_rvar_nochains <- pareto_smooth(tau_rvar_nochains, ndraws_tail = 20, extra_diags = TRUE, return_k = TRUE)

expect_equal(ps_rvar_chains$x, rvar(ps_chains$x, with_chains = TRUE))

Expand All @@ -185,7 +183,7 @@ test_that("pareto_khat functions work with rvars with and without chains", {
test_that("pareto_smooth returns x with smoothed tail", {
tau <- extract_variable_matrix(example_draws(), "tau")

tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right")$x
tau_smoothed <- pareto_smooth(tau, ndraws_tail = 10, tail = "right", return_k = TRUE)$x

expect_equal(sort(tau)[1:390], sort(tau_smoothed)[1:390])

Expand All @@ -197,7 +195,7 @@ test_that("pareto_smooth works for log_weights", {
w <- c(1:25, 1e3, 1e3, 1e3)
lw <- log(w)

ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10)
ps <- pareto_smooth(lw, are_log_weights = TRUE, verbose = FALSE, ndraws_tail = 10, return_k = TRUE)

# only right tail is smoothed
expect_equal(ps$x[1:15], lw[1:15])
Expand Down

0 comments on commit 627d2e8

Please sign in to comment.