Skip to content

Commit

Permalink
Replace get_prior_names with get_param_names (related to #3)
Browse files Browse the repository at this point in the history
  • Loading branch information
bodkan committed Dec 8, 2023
1 parent 26216a5 commit 9d4c914
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions R/plot_prior.R
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
plot_prior <- function(x, param = NULL, facets = FALSE, file = NULL,
replicates = 10000, geom = ggplot2::geom_density, ...) {
priors <- if (inherits(x, "demografr_abc.abc")) attr(x, "priors") else x
all_params <- get_prior_names(priors); names(priors) <- all_params
all_params <- get_param_names(priors); names(priors) <- all_params
subset_params <- subset_parameters(subset = param, all = all_params)
priors <- priors[subset_params]; unname(priors)

Expand All @@ -42,7 +42,7 @@ plot_prior <- function(x, param = NULL, facets = FALSE, file = NULL,
simulate_priors <- function(priors, replicates = 1000) {
if (!is.list(priors)) priors <- list(priors)

vars <- get_prior_names(priors)
vars <- get_param_names(priors)

samples_list <- lapply(
seq_along(priors), \(i)
Expand Down
8 changes: 4 additions & 4 deletions R/validate_abc.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ validate_abc <- function(model, priors, functions, observed,

prior_samples <- list()

prior_names <- get_prior_names(priors)
prior_names <- get_param_names(priors)

cat("Testing sampling of each prior parameter:\n")

Expand Down Expand Up @@ -110,13 +110,13 @@ validate_abc <- function(model, priors, functions, observed,

# first expand any generic "..." prior sampling expressions (if needed)
priors <- tryCatch(
expand_priors(model, priors, model_args),
expand_formulas(priors, model, model_args),
error = function(e) {
cat(" \u274C\n\n")
stop(e$message, call. = FALSE)
})
# prior names generated above have to be re-generated after templating
prior_names <- get_prior_names(priors)
prior_names <- get_param_names(priors)

missing_priors <- setdiff(prior_names, methods::formalArgs(model))
if (length(missing_priors) > 0) {
Expand All @@ -142,7 +142,7 @@ validate_abc <- function(model, priors, functions, observed,

cat("The model is a custom user-defined", script_engine, "script\n")

prior_names <- get_prior_names(priors)
prior_names <- get_param_names(priors)
template_priors <- grepl("\\.\\.\\.", prior_names)
if (any(template_priors)) {
cat(" \u274C\n\n")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test-prior-expansion.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ test_that("Templated prior sampling expressions validate correctly", {

test_that("expand_fromulas() produces parameters equal to formal arguments of a model", {
expanded_priors <- expand_formulas(templated_priors, model)
expect_true(all(get_prior_names(expanded_priors) == names(formals(model))))
expect_true(all(get_param_names(expanded_priors) == names(formals(model))))
})

test_that("With the same seed, both sets of priors give the same tree sequence", {
Expand Down

0 comments on commit 9d4c914

Please sign in to comment.