diff --git a/R/checkers.R b/R/checkers.R index a43cc488..0d6841a8 100644 --- a/R/checkers.R +++ b/R/checkers.R @@ -1143,8 +1143,8 @@ check_positive_scalar <- function(x, } check_scalar <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ scalar <- is_scalar(x) if (!scalar){ cli::cli_abort( @@ -1159,8 +1159,8 @@ check_scalar <- function(x, } check_finite <- function(x, - arg = rlang::caller_arg(x), - call = rlang::caller_env()){ + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ not_finite <- !is.finite(x) if (not_finite){ cli::cli_abort( @@ -1217,6 +1217,513 @@ check_numeric_length_1 <- function(x, } +check_both_2d <- function(x, + y, + x_arg = rlang::caller_arg(x), + y_arg = rlang::caller_arg(y), + call = rlang::caller_env()){ + if (!is_2d(x) | !is_2d(y)) { + cli::cli_abort( + message = c( + "Only two-dimensional {.cls greta_array}s can be matrix-multiplied", + "Dimensions for each are:", + "{.arg {x_arg}}: {.val {pretty_dim(x)}}", + "{.arg {y_arg}}: {.val {pretty_dim(y)}}" + ), + call = call + ) + } +} + +check_compatible_dimensions <- function(x, + y, + call = rlang::caller_env()){ + + incompatible_dimensions <- dim(x)[2] != dim(y)[1] + if (incompatible_dimensions) { + cli::cli_abort( + message = c( + "Incompatible dimensions: \\ + {.val {paste0(dim(x), collapse = 'x')}} vs \\ + {.val {paste0(dim(y), collapse = 'x')}}" + ), + call = call + ) + } +} + +check_distribution_support <- function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + n_supports <- length(unique(x)) + if (n_supports != 1) { + supports_text <- vapply( + X = unique(x), + FUN = paste, + collapse = " to ", + FUN.VALUE = character(1) + ) + + cli::cli_abort( + message = c( + "Component distributions must have the same support", + "However the component distributions have different support:", + "{.val {paste(supports_text, collapse = ' vs. ')}}" + ), + call = call + ) + } + +} + +check_not_multivariate_univariate <- function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + is_multivariate_and_univariate <- !all(x) & !all(!x) + if (is_multivariate_and_univariate) { + cli::cli_abort( + message = c( + "Cannot construct a mixture from a combination of multivariate and \\ + univariate distributions" + ), + call = call + ) + } +} + +check_not_discrete_continuous <- function(x, + name, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + + is_discrete_and_continuous <- !all(x) & !all(!x) + if (is_discrete_and_continuous) { + cli::cli_abort( + message = c( + "Cannot construct a {name} distribution from a combination of discrete and \\ + continuous distributions" + ), + call = call + ) + } +} + +check_num_distributions <- function(n_distributions, + at_least, + name, + call = rlang::caller_env()){ + if (n_distributions < at_least) { + cli::cli_abort( + message = c( + "{.fun {name}} must be passed at least {.val {at_least}} \\ + distributions", + "The number of distributions passed was: {.val {n_distributions}}" + ), + call = call + ) + } + +} + +check_weights_dim <- function(weights_dim, + dim, + n_distributions, + arg = rlang::caller_arg(weights_dim), + call = rlang::caller_env()){ + + # weights should have n_distributions as the first dimension + if (weights_dim[1] != n_distributions) { + cli::cli_abort( + message = c( + "The first dimension of weights must be the number of \\ + distributions in the mixture ({.val {n_distributions}})", + "However it was {.val {weights_dim[1]}}" + ), + call = call + ) + } + + weights_extra_dim <- dim + n_extra_dim <- length(weights_extra_dim) + weights_last_dim_is_1 <- weights_extra_dim[n_extra_dim] == 1 + if (weights_last_dim_is_1) { + weights_extra_dim <- weights_extra_dim[-n_extra_dim] + } + + # remainder should be 1 or match weights_extra_dim + w_dim <- weights_dim[-1] + dim_1 <- length(w_dim) == 1 && w_dim == 1 + dim_same <- all(w_dim == weights_extra_dim) + incompatible_dims <- !(dim_1 | dim_same) + if (incompatible_dims) { + cli::cli_abort( + message = c( + "The dimension of weights must be either \\ + {.val {n_distributions}x1} or \\ + {.val {n_distributions}x{pretty_dim(dim)}}", + " but was {.val {pretty_dim(weights_dim)}}" + ), + call = call + ) + } + +} + +check_initials_are_named <- function(values, + call = rlang::caller_env()){ + names <- names(values) + initials_not_all_named <- length(names) != length(values) + if (initials_not_all_named) { + cli::cli_abort( + message = "All initial values must be named", + call = call + ) + } +} + +check_initials_are_numeric <- function(values, + call = rlang::caller_env()){ + are_numeric <- vapply(values, is.numeric, FUN.VALUE = FALSE) + if (!all(are_numeric)) { + cli::cli_abort( + message = "initial values must be numeric", + call = call + ) + } +} + +check_initial_values_match_chains <- function(initial_values, + n_chains, + call = rlang::caller_env()){ + n_sets <- length(initial_values) + + initial_values_do_not_match_chains <- n_sets != n_chains + if (initial_values_do_not_match_chains) { + cli::cli_abort( + message = c( + "The number of provided initial values does not match chains", + "{n_sets} set{?s} of initial values were provided, but there \\ + {cli::qty(n_chains)} {?is only/are} {n_chains} \\ + {cli::qty(n_chains)} chain{?s}" + ), + call = call + ) + } +} + +check_initial_values_correct_dim <- function(target_dims, + replacement_dims, + call = rlang::caller_env()){ + + same_dims <- mapply(identical, target_dims, replacement_dims) + + if (!all(same_dims)) { + cli::cli_abort( + message = "The initial values provided have different dimensions than \\ + the named {.cls greta_array}s", + call = call + ) + } + +} + +check_nodes_all_variable <- function(nodes, + call = rlang::caller_env()){ + types <- lapply(nodes, node_type) + are_variables <- are_identical(types, "variable") + + if (!all(are_variables)) { + cli::cli_abort( + "Initial values can only be set for variable {.cls greta_array}s" + ) + } + +} + +check_greta_arrays_associated_with_model <- function(tf_names, + call = rlang::caller_env()){ + missing_names <- is.na(tf_names) + if (any(missing_names)) { + bad <- names(tf_names)[missing_names] + cli::cli_abort( + c( + "Some {.cls greta_array}s passed to {.fun initials} are not \\ + associated with the model:", + "{.var {bad}}" + ) + ) + } +} + +check_not_data_greta_arrays <- function(model, + call = rlang::caller_env()){ + + # find variable names to label samples + target_greta_arrays <- model$target_greta_arrays + names <- names(target_greta_arrays) + + # check they're not data nodes, provide a useful error message if they are + are_data <- vapply( + target_greta_arrays, + function(x) is.data_node(get_node(x)), + FUN.VALUE = FALSE + ) + + if (any(are_data)) { + cli::cli_abort( + message = c( + "Data {.cls greta_array}s cannot be sampled", + "{.var {names[are_data]}} \\ + {?is a data/are data} {.cls greta_array}(s)" + ), + call = call + ) + } +} + +check_diagrammer_installed <- function(call = rlang::caller_env()){ + if (!is_DiagrammeR_installed()) { + cli::cli_abort( + message = c( + "The {.pkg DiagrammeR} package must be installed to plot \\ + {.pkg greta} models", + "Install {.pkg DiagrammeR} with:", + "{.code install.packages('DiagrammeR')}" + ), + call = call + ) + } +} + +check_unfixed_discrete_distributions <- function(dag, + call = rlang::caller_env()){ + + # check for unfixed discrete distributions + distributions <- dag$node_list[dag$node_types == "distribution"] + bad_nodes <- vapply( + X = distributions, + FUN = function(x) { + valid_target <- is.null(x$target) || is.data_node(x$target) + x$discrete && !valid_target + }, + FUN.VALUE = FALSE + ) + + if (any(bad_nodes)) { + cli::cli_abort( + "Model contains a discrete random variable that doesn't have a fixed \\ + value, so inference cannot be carried out." + ) + } +} + +check_greta_array_type <- function(x, + optional, + call = rlang::caller_env()){ + + if (!is.numeric(x) && !is.logical(x) && !optional){ + cli::cli_abort( + message = c( + "{.cls greta_array} must contain the same type", + "Cannot coerce {.cls matrix} to a {.cls greta_array} unless it is \\ + {.cls numeric}, {.cls integer} or {.cls logical}.", + "This {.cls matrix} had type: {.cls {class(as.vector(x))}}" + ), + call = call + ) + } +} + +check_greta_data_frame <- function(x, + optional, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + classes <- vapply(x, class, "") + valid <- classes %in% c("numeric", "integer", "logical") + + array_has_different_types <- !optional & !all(valid) + if (array_has_different_types) { + invalid_types <- unique(classes[!valid]) + cli::cli_abort( + message = c( + "{.cls greta_array} must contain the same type", + "Cannot coerce a {.cls data.frame} to a {.cls greta_array} unless \\ + all columns are {.cls numeric, integer} or {.cls logical}.", + "This dataframe had columns of type: {.cls {invalid_types}}" + ), + call = call + ) + } +} + +check_ncols_match <- function(x1, + x2, + call = rlang::caller_env()){ + if (ncol(x1) != ncol(x2)) { + cli::cli_abort( + message = c( + "{.var x1} and {.var x2} must have the same number of columns", + "However {.code ncol(x1)} = {ncol(x1)} and \\ + {.code ncol(x2)} = {ncol(x2)}" + ), + call = call + ) + } +} + +check_fields_installed <- function(){ + fields_installed <- requireNamespace("fields", quietly = TRUE) + if (!fields_installed) { + cli::cli_abort( + c( + "{.pkg fields} package must be installed to use {.fun rdist} on greta \\ + arrays", + "Install {.pkg fields} with:", + "{.code install.packages('fields')}" + ) + ) + } +} + +check_2_by_1 <- function(x, + call = rlang::caller_env()){ + dim_x <- dim(x) + is_2_by_1 <- is_2d(x) && dim_x[2] == 1L + if (!is_2_by_1) { + cli::cli_abort( + message = c( + "{.var x} must be 2D {.cls greta_array} with one column", + "However {.var x} has dimensions {paste(dim_x, collapse = 'x')}" + ), + call = call + ) + } +} + + +check_transpose <- function(x, + call = rlang::caller_env()){ + if (x) { + cli::cli_abort( + message = "{.arg transpose} must be FALSE for {.cls greta_array}s", + call = call + ) + } +} + +check_x_matches_ncol <- function(x, + ncol_of, + x_arg = rlang::caller_arg(x), + ncol_of_arg = rlang::caller_arg(ncol_of), + call = rlang::caller_env()){ + + if (x != ncol(ncol_of)) { + cli::cli_abort( + message = "{.arg {x}} must equal {.code ncol({ncol_of_arg})} for \\ + {.cls greta_array}s", + call = call + ) + } +} + +check_stats_dim_matches_x_dim <- function(x, + margin, + stats, + call = rlang::caller_env()){ + stats_dim_matches_x_dim <- dim(x)[margin] == dim(stats)[1] + if (!stats_dim_matches_x_dim) { + cli::cli_abort( + message = c( + "The number of elements of {.var stats} does not match \\ + {.code dim(x)[MARGIN]}" + ), + call = call + ) + } +} + +# STATS must be a column array +check_is_column_array <- function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + + is_column_array <- is_2d(x) && dim(x)[2] == 1 + if (!is_column_array) { + cli::cli_abort( + message = c( + "{.arg {arg}} not a column vector array", + "{.arg {arg}} must be a column vector array", + "x" = "{.arg {arg}} has dimensions:", + "{.val {pretty_dim(x)}}" + ), + call = call + ) + } +} + +check_rows_equal <- function(a, + b, + a_arg = rlang::caller_arg(a), + b_arg = rlang::caller_arg(b), + call = rlang::caller_env()){ + + # b must have the right number of rows + rows_not_equal <- dim(b)[1] != dim(a)[1] + if (rows_not_equal) { + cli::cli_abort( + message = c( + "Number of rows not equal", + "x" = "{.arg {b_arg}} must have the same number of rows as \\ + {.arg {a_arg}} ({.val {dim(a)[1]}}), but has \\ + {.val {dim(b)[1]}} rows instead" + ), + call = call + ) + } +} + +check_final_dim <- function(dim, + thing, + call = rlang::caller_env()){ + # dimension of the free state version + n_dim <- length(dim) + last_dim <- dim[n_dim] + n_last_dim <- length(last_dim) + last_dim_gt_1 <- !last_dim > 1 + if (last_dim_gt_1) { + cli::cli_abort( + message = c( + "The final dimension of a {thing} must have more than \\ + one element", + "The final dimension has: {.val {n_last_dim} element{?s}}" + ), + call = call + ) + } + +} + +check_param_greta_array <- function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + if (is.greta_array(x)) { + cli::cli_abort( + message = "{.arg {arg}} must be fixed, they cannot be another \\ + {.cls greta_array}", + call = call + ) + } +} + +check_not_greta_array <- function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + if (is.greta_array(x)) { + cli::cli_abort( + "{.arg {arg}} cannot be a {.cls greta_array}" + ) + } +} checks_module <- module( check_tf_version, diff --git a/R/extract_replace_combine.R b/R/extract_replace_combine.R index c06412b4..75090170 100644 --- a/R/extract_replace_combine.R +++ b/R/extract_replace_combine.R @@ -649,7 +649,7 @@ diag.greta_array <- function(x = 1, nrow, ncol) { # check the rank isn't too high if (!is_2d(x)) { cli::cli_abort( - "cannot only extract the diagonal from a node with exactly two \\ + "Cannot only extract the diagonal from a node with exactly two \\ dimensions" ) } @@ -657,7 +657,7 @@ diag.greta_array <- function(x = 1, nrow, ncol) { is_square <- dim[1] != dim[2] if (is_square) { cli::cli_abort( - "diagonal elements can only be extracted from square matrices" + "Diagonal elements can only be extracted from square matrices" ) } diff --git a/R/functions.R b/R/functions.R index 2b761d60..02d5d513 100644 --- a/R/functions.R +++ b/R/functions.R @@ -399,19 +399,7 @@ solve.greta_array <- function(a, b, ...) { check_2d(a) # check the matrix is square - a_not_square <- dim(a)[1] != dim(a)[2] - if (a_not_square) { - - cli::cli_abort( - c( - "{.var a} is not square", - "x" = "{.var a} must be square, but has \\ - {dim(a)[1]} rows and \\ - {dim(a)[2]} columns" - ) - ) - - } + check_square(a) # if they just want the matrix inverse, do that if (missing(b)) { @@ -426,18 +414,8 @@ solve.greta_array <- function(a, b, ...) { } else { check_2d(b) - # b must have the right number of rows - rows_not_equal <- dim(b)[1] != dim(a)[1] - if (rows_not_equal) { - cli::cli_abort( - c( - "Number of rows not equal", - "x" = "{.var b} must have the same number of rows as {.var a} \\ - ({dim(a)[1]}), but has {dim(b)[1]} rows instead" - ) - ) - } + check_rows_equal(a, b) # ... and solve the linear equations result <- op("solve", a, b, @@ -755,30 +733,9 @@ sweep.greta_array <- function(x, } check_2d(x) - - # STATS must be a column array - is_column_array <- is_2d(stats) && dim(stats)[2] == 1 - if (!is_column_array) { - cli::cli_abort( - c( - "{.var stats} not a column vector array", - "{.var stats} must be a column vector array", - "x" = "{.var stats} has dimensions \\ - {paste(dim(stats), collapse = 'x')}" - ) - ) - } - + check_is_column_array(stats) # STATS must have the same dimension as the correct dim of x - stats_dim_matches_x_dim <- dim(x)[margin] == dim(stats)[1] - if (!stats_dim_matches_x_dim) { - cli::cli_abort( - c( - "The number of elements of {.var stats} does not match \\ - {.code dim(x)[MARGIN]}" - ) - ) - } + check_stats_dim_matches_x_dim(x, margin, stats) op("sweep", x, stats, operation_args = list(margin = margin, fun = fun), @@ -869,19 +826,8 @@ backsolve.greta_array <- function(r, x, upper.tri = TRUE, transpose = FALSE) { # nolint end - if (k != ncol(r)) { - cli::cli_abort( - c( - "{.var k} must equal {.code ncol(r)} for {.cls greta_array}s" - ) - ) - } - - if (transpose) { - cli::cli_abort( - "transpose must be FALSE for {.cls greta_array}s" - ) - } + check_x_matches_ncol(x = k, ncol_of = r) + check_transpose(transpose) op("backsolve", r, x, operation_args = list(lower = !upper.tri), @@ -921,17 +867,8 @@ forwardsolve.greta_array <- function(l, x, upper.tri = FALSE, transpose = FALSE) { # nolint end - if (k != ncol(l)) { - cli::cli_abort( - "{.var k} must equal {.code ncol(l)} for {.cls greta_array}s" - ) - } - - if (transpose) { - cli::cli_abort( - "transpose must be FALSE for {.cls greta_array}s" - ) - } + check_x_matches_ncol(x = k, ncol_of = l) + check_transpose(transpose) op("forwardsolve", l, x, operation_args = list(lower = !upper.tri), @@ -968,11 +905,7 @@ apply.greta_array <- function(X, MARGIN, # nolint end fun <- match.arg(FUN) - if (is.greta_array(MARGIN)) { - cli::cli_abort( - "MARGIN cannot be a greta array" - ) - } + check_not_greta_array(MARGIN) margin <- as.integer(MARGIN) @@ -1061,27 +994,14 @@ tapply.greta_array <- function(X, INDEX, index <- INDEX fun <- match.arg(FUN) - if (is.greta_array(index)) { - cli::cli_abort( - "INDEX cannot be a greta array" - ) - } + check_not_greta_array(INDEX) # convert index to successive integers starting at 0 groups <- sort(unique(index)) id <- match(index, groups) - 1L len <- length(groups) - dim_x <- dim(x) - is_2_by_1 <- is_2d(x) && dim_x[2] == 1L - if (!is_2_by_1) { - cli::cli_abort( - c( - "{.var x} must be 2D greta array with one column", - "However {.var x} has dimensions {paste(dim_x, collapse = 'x')}" - ) - ) - } + check_2_by_1(x) op("tapply", x, operation_args = list( @@ -1128,6 +1048,7 @@ eigen.greta_array <- function(x, symmetric, is_square <- dims[1] == dims[2] is_not_2d_square_symmetric <- !is_2d(x) | !is_square | !symmetric + if (is_not_2d_square_symmetric) { cli::cli_abort( "only two-dimensional, square, symmetric {.cls greta_array}s can be \\ @@ -1181,17 +1102,8 @@ rdist <- function(x1, x2 = NULL, compact = FALSE) { #' @export rdist.default <- function(x1, x2 = NULL, compact = FALSE) { # error nicely if they don't have fields installed - fields_installed <- requireNamespace("fields", quietly = TRUE) - if (!fields_installed) { - cli::cli_abort( - c( - "{.pkg fields} package must be installed to use {.fun rdist} on greta \\ - arrays", - "Install {.pkg fields} with:", - "{.code install.packages('fields')}" - ) - ) - } + check_fields_installed() + fields::rdist( x1 = x1, x2 = x2, @@ -1235,15 +1147,7 @@ rdist.greta_array <- function(x1, x2 = NULL, compact = FALSE) { # error if they have different number of columns. fields::rdist allows # different numbers of columns, takes the number of columns from x1,and # sometimes gives nonsense results - if (ncol(x1) != ncol(x2)) { - cli::cli_abort( - c( - "{.var x1} and {.var x2} must have the same number of columns", - "However {.code ncol(x1)} = {ncol(x1)} and \\ - {.code ncol(x2)} = {ncol(x2)}" - ) - ) - } + check_ncols_match(x1, x2) n2 <- nrow(x2) diff --git a/R/greta_array_class.R b/R/greta_array_class.R index 33e9ed19..ac71bede 100644 --- a/R/greta_array_class.R +++ b/R/greta_array_class.R @@ -30,21 +30,7 @@ as.greta_array.logical <- function(x, optional = FALSE, original_x = x, ...) { #' @export as.greta_array.data.frame <- function(x, optional = FALSE, original_x = x, ...) { - classes <- vapply(x, class, "") - valid <- classes %in% c("numeric", "integer", "logical") - - array_has_different_types <- !optional & !all(valid) - if (array_has_different_types) { - invalid_types <- unique(classes[!valid]) - cli::cli_abort( - c( - "{.cls greta_array} must contain the same type", - "Cannot coerce a {.cls data.frame} to a {.cls greta_array} unless \\ - all columns are {.cls numeric, integer} or {.cls logical}. This \\ - dataframe had columns of type: {.cls {invalid_types}}" - ) - ) - } + check_greta_data_frame(x, optional) as.greta_array.numeric(as.matrix(x), optional = optional, @@ -57,22 +43,12 @@ as.greta_array.data.frame <- function(x, optional = FALSE, # or numeric #' @export as.greta_array.matrix <- function(x, optional = FALSE, original_x = x, ...) { - ## TODO better abstract these if else clauses - if (!is.numeric(x)) { - if (is.logical(x)) { + + check_greta_array_type(x, optional) + + if (!is.numeric(x) && is.logical(x)) { x[] <- as.numeric(x[]) - } else if (!optional) { - cli::cli_abort( - c( - "{.cls greta_array} must contain the same type", - "Cannot coerce {.cls matrix} to a {.cls greta_array} unless it is \\ - {.cls numeric}, {.cls integer} or {.cls logical}. This \\ - {.cls matrix} had type:", - "{.cls {class(as.vector(x))}}" - ) - ) } - } as.greta_array.numeric(x, optional = optional, @@ -85,21 +61,11 @@ as.greta_array.matrix <- function(x, optional = FALSE, original_x = x, ...) { # or numeric #' @export as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) { - ## TODO Better abstract out these if statements - if (!optional & !is.numeric(x)) { - if (is.logical(x)) { + + check_greta_array_type(x, optional) + + if (!optional && !is.numeric(x) && is.logical(x)) { x[] <- as.numeric(x[]) - } else { - cli::cli_abort( - c( - "{.cls greta_array} must contain the same type", - "Cannot coerce {.cls array} to a {.cls greta_array} unless it is \\ - {.cls numeric}, {.cls integer} or {.cls logical}. This {.cls array} \\ - had type:", - "{.cls {class(as.vector(x))}}" - ) - ) - } } as.greta_array.numeric(x, diff --git a/R/greta_model_class.R b/R/greta_model_class.R index 1bbcd0fd..f1953122 100644 --- a/R/greta_model_class.R +++ b/R/greta_model_class.R @@ -155,24 +155,7 @@ model <- function(..., } } - # check for unfixed discrete distributions - distributions <- dag$node_list[dag$node_types == "distribution"] - bad_nodes <- vapply( - distributions, - function(x) { - valid_target <- is.null(x$target) || - is.data_node(x$target) - x$discrete && !valid_target - }, - FALSE - ) - - if (any(bad_nodes)) { - cli::cli_abort( - "model contains a discrete random variable that doesn't have a fixed \\ - value, so inference cannot be carried out" - ) - } + check_unfixed_discrete_distributions(dag) # define the TF graph # dag$define_tf() @@ -230,15 +213,7 @@ plot.greta_model <- function(x, y, colour = "#996bc7", ...) { - if (!is_DiagrammeR_installed()) { - cli::cli_abort( - c( - "the {.pkg DiagrammeR} package must be installed to plot greta models", - "install {.pkg DiagrammeR} with:", - "{.code install.packages('DiagrammeR')}" - ) - ) - } + check_diagrammer_installed() # set up graph dag_mat <- x$dag$adjacency_matrix diff --git a/R/inference.R b/R/inference.R index 3b6a1b62..c679c202 100644 --- a/R/inference.R +++ b/R/inference.R @@ -211,26 +211,7 @@ mcmc <- function( # check the trace batch size trace_batch_size <- check_trace_batch_size(trace_batch_size) - # find variable names to label samples - target_greta_arrays <- model$target_greta_arrays - names <- names(target_greta_arrays) - - # check they're not data nodes, provide a useful error message if they are - are_data <- vapply( - target_greta_arrays, - function(x) is.data_node(get_node(x)), - FUN.VALUE = FALSE - ) - - if (any(are_data)) { - cli::cli_abort( - c( - "data {.cls greta_array}s cannot be sampled", - "{.var {names[are_data]}} \\ - {?is a data/are data} {.cls greta_array}(s)" - ) - ) - } + check_not_data_greta_arrays(model) # get the dag containing the target nodes dag <- model$dag @@ -582,7 +563,7 @@ to_free <- function(node, data) { unsupported_error <- function() { cli::cli_abort( - "some provided initial values are outside the range of values their \\ + "Some provided initial values are outside the range of values their \\ variables can take" ) } @@ -643,17 +624,7 @@ parse_initial_values <- function(initials, dag) { FUN.VALUE = "" ) - missing_names <- is.na(tf_names) - if (any(missing_names)) { - bad <- names(tf_names)[missing_names] - cli::cli_abort( - c( - "some {.cls greta_array}s passed to {.fun initials} are not associated with \\ - the model:", - "{.var {bad}}" - ) - ) - } + check_greta_arrays_associated_with_model(tf_names) params <- dag$example_parameters(free = FALSE) idx <- match(tf_names, names(params)) @@ -667,25 +638,12 @@ parse_initial_values <- function(initials, dag) { # find the corresponding nodes and check they are variable nodes forward_names <- glue::glue("all_forward_{dag$node_tf_names}") nodes <- dag$node_list[match(tf_names, forward_names)] - types <- lapply(nodes, node_type) - are_variables <- are_identical(types, "variable") - if (!all(are_variables)) { - cli::cli_abort( - "initial values can only be set for variable {.cls greta_array}s" - ) - } + check_nodes_all_variable(nodes) target_dims <- lapply(params[idx], dim) replacement_dims <- lapply(initials, dim) - same_dims <- mapply(identical, target_dims, replacement_dims) - - if (!all(same_dims)) { - cli::cli_abort( - "the initial values provided have different dimensions than the named \\ - {.cls greta_array}s" - ) - } + check_initial_values_correct_dim(target_dims, replacement_dims) # convert the initial values to their free states inits_free <- mapply(to_free, nodes, initials, SIMPLIFY = FALSE) @@ -727,19 +685,7 @@ prep_initials <- function(initial_values, n_chains, dag) { are_initials <- vapply(initial_values, is.initials, FUN.VALUE = FALSE) if (all(are_initials)) { - n_sets <- length(initial_values) - - initial_values_do_not_match_chains <- n_sets != n_chains - if (initial_values_do_not_match_chains) { - cli::cli_abort( - c( - "the number of provided initial values does not match chains", - "{n_sets} set{?s} of initial values were provided, but there \\ - {cli::qty(n_chains)} {?is only/are} {n_chains} \\ - {cli::qty(n_chains)} chain{?s}" - ) - ) - } + check_initial_values_match_chains(initial_values, n_chains) } else { initial_values <- NULL } @@ -777,22 +723,12 @@ initials <- function(...) { values <- list(...) names <- names(values) - initials_not_all_named <- length(names) != length(values) - if (initials_not_all_named) { - cli::cli_abort( - "all initial values must be named" - ) - } + check_initials_are_named(values) # coerce to greta-array-like shape values <- lapply(values, as_2d_array) - are_numeric <- vapply(values, is.numeric, FUN.VALUE = FALSE) - if (!all(are_numeric)) { - cli::cli_abort( - "initial values must be numeric" - ) - } + check_initials_are_numeric(values) class(values) <- c("initials", class(values)) values diff --git a/R/inference_class.R b/R/inference_class.R index 37f40e5b..c25df260 100644 --- a/R/inference_class.R +++ b/R/inference_class.R @@ -53,14 +53,14 @@ inference <- R6Class( if (file.exists(self$trace_log_file)) { # Append write.table(last_burst_values, self$trace_log_file, - append = TRUE, - row.names = FALSE, col.names = FALSE + append = TRUE, + row.names = FALSE, col.names = FALSE ) } else { # Create file write.table(last_burst_values, self$trace_log_file, - append = FALSE, - row.names = FALSE, col.names = TRUE + append = FALSE, + row.names = FALSE, col.names = TRUE ) } }, @@ -103,34 +103,43 @@ inference <- R6Class( attempts <- attempts + 1 } - if (!valid) { - cli::cli_abort( - message = c( - "Could not find reasonable starting values after \\ - {attempts} attempts.", - "Please specify initial values manually via the \\ - {.arg initial_values} argument" - ), - call = call - ) - } + self$check_reasonable_starting_values(valid, attempts) + } else { # if they were all provided, check they can be be used valid <- self$valid_parameters(inits) - if (!valid) { - cli::cli_abort( - c( - "The log density could not be evaluated at these initial values", - "Try using these initials as the values argument in \\ + self$check_valid_parameters(valid) + + } + + inits + }, + + check_reasonable_starting_values = function(valid, attempts){ + if (!valid) { + cli::cli_abort( + message = c( + "Could not find reasonable starting values after \\ + {attempts} attempts.", + "Please specify initial values manually via the \\ + {.arg initial_values} argument" + ) + ) + } + }, + + check_valid_parameters = function(valid){ + if (!valid) { + cli::cli_abort( + c( + "The log density could not be evaluated at these initial values", + "Try using these initials as the {.arg values} argument in \\ {.fun calculate} to see what values of subsequent \\ {.cls greta_array}s these initial values lead to." - ) ) - } + ) } - - inits }, # check and set a list of initial values @@ -152,11 +161,11 @@ inference <- R6Class( tf_parameters <- fl(array( data = parameters, dim = c(1, length(parameters)) - )) + )) ld <- lapply( dag$tf_log_prob_function(tf_parameters), as.numeric - ) + ) is.finite(ld$adjusted) && is.finite(ld$unadjusted) }, @@ -176,9 +185,9 @@ inference <- R6Class( # append the free state trace for each chain self$traced_free_state <- mapply(rbind, - self$traced_free_state, - self$last_burst_free_states, - SIMPLIFY = FALSE + self$traced_free_state, + self$last_burst_free_states, + SIMPLIFY = FALSE ) } @@ -189,9 +198,9 @@ inference <- R6Class( self$write_trace_to_log_file(last_burst_values) } self$traced_values <- mapply(rbind, - self$traced_values, - last_burst_values, - SIMPLIFY = FALSE + self$traced_values, + last_burst_values, + SIMPLIFY = FALSE ) } }, @@ -349,13 +358,13 @@ sampler <- R6Class( # create these objects if needed if (from_scratch) { self$traced_free_state <- replicate(self$n_chains, - matrix(NA, 0, self$n_free), - simplify = FALSE + matrix(NA, 0, self$n_free), + simplify = FALSE ) self$traced_values <- replicate(self$n_chains, - matrix(NA, 0, self$n_traced), - simplify = FALSE + matrix(NA, 0, self$n_traced), + simplify = FALSE ) } @@ -379,8 +388,8 @@ sampler <- R6Class( # split up warmup iterations into bursts of sampling burst_lengths <- self$burst_lengths(warmup, - ideal_burst_size, - warmup = TRUE + ideal_burst_size, + warmup = TRUE ) completed_iterations <- cumsum(burst_lengths) @@ -405,23 +414,23 @@ sampler <- R6Class( # update the progress bar/percentage log iterate_progress_bar(pb_warmup, - it = completed_iterations[burst], - rejects = self$numerical_rejections, - chains = self$n_chains, - file = self$pb_file + it = completed_iterations[burst], + rejects = self$numerical_rejections, + chains = self$n_chains, + file = self$pb_file ) self$write_percentage_log(warmup, - completed_iterations[burst], - stage = "warmup" + completed_iterations[burst], + stage = "warmup" ) } } # scrub the free state trace and numerical rejections self$traced_free_state <- replicate(self$n_chains, - matrix(NA, 0, self$n_free), - simplify = FALSE + matrix(NA, 0, self$n_free), + simplify = FALSE ) self$numerical_rejections <- 0 } @@ -464,15 +473,15 @@ sampler <- R6Class( # update the progress bar/percentage log iterate_progress_bar(pb_sampling, - it = completed_iterations[burst], - rejects = self$numerical_rejections, - chains = self$n_chains, - file = self$pb_file + it = completed_iterations[burst], + rejects = self$numerical_rejections, + chains = self$n_chains, + file = self$pb_file ) self$write_percentage_log(n_samples, - completed_iterations[burst], - stage = "sampling" + completed_iterations[burst], + stage = "sampling" ) } } @@ -519,8 +528,8 @@ sampler <- R6Class( # chain dimension trace_values = function(trace_batch_size) { self$traced_values <- lapply(self$traced_free_state, - self$model$dag$trace_values, - trace_batch_size = trace_batch_size + self$model$dag$trace_values, + trace_batch_size = trace_batch_size ) }, @@ -656,7 +665,7 @@ sampler <- R6Class( sampler_thin, sampler_param_vec # pass values through - ) { + ) { dag <- self$model$dag tfe <- dag$tf_environment @@ -683,20 +692,20 @@ sampler <- R6Class( # Need to work out how to get sampler_batch() to run as a TF function. # To do that we need to work out how to get the free state - sampler_batch <- tfp$mcmc$sample_chain( - num_results = tf$math$floordiv(sampler_burst_length, sampler_thin), - current_state = free_state, - kernel = sampler_kernel, - trace_fn = function(current_state, kernel_results) { - kernel_results - }, - num_burnin_steps = tf$constant(0L, dtype = tf$int32), - num_steps_between_results = sampler_thin, - parallel_iterations = 1L - ) - return( - sampler_batch - ) + sampler_batch <- tfp$mcmc$sample_chain( + num_results = tf$math$floordiv(sampler_burst_length, sampler_thin), + current_state = free_state, + kernel = sampler_kernel, + trace_fn = function(current_state, kernel_results) { + kernel_results + }, + num_burnin_steps = tf$constant(0L, dtype = tf$int32), + num_steps_between_results = sampler_thin, + parallel_iterations = 1L + ) + return( + sampler_batch + ) }, # run a burst of the sampler @@ -715,7 +724,7 @@ sampler <- R6Class( # sampler_values <- list( # # TF1/2 check - # do we need free state here anymore? + # do we need free state here anymore? # free_state = self$free_state, # sampler_burst_length = as.integer(n_samples), # sampler_thin = as.integer(thin) @@ -875,21 +884,21 @@ hmc_sampler <- R6Class( dtype = tf$float64 ) # TF1/2 check - # where is "free_state" pulled from, given that it is the - # argument to this function, "generate_log_prob_function" ? + # where is "free_state" pulled from, given that it is the + # argument to this function, "generate_log_prob_function" ? # log probability function # build the kernel # nolint start - sampler_kernel <- tfp$mcmc$HamiltonianMonteCarlo( - target_log_prob_fn = dag$tf_log_prob_function_adjusted, - step_size = hmc_step_sizes, - num_leapfrog_steps = hmc_l - ) + sampler_kernel <- tfp$mcmc$HamiltonianMonteCarlo( + target_log_prob_fn = dag$tf_log_prob_function_adjusted, + step_size = hmc_step_sizes, + num_leapfrog_steps = hmc_l + ) return( sampler_kernel - ) + ) # nolint end }, sampler_parameter_values = function() { @@ -935,8 +944,8 @@ rwmh_sampler <- R6Class( tfe <- dag$tf_environment tfe$rwmh_proposal <- switch(self$parameters$proposal, - normal = tfp$mcmc$random_walk_normal_fn, - uniform = tfp$mcmc$random_walk_uniform_fn + normal = tfp$mcmc$random_walk_normal_fn, + uniform = tfp$mcmc$random_walk_uniform_fn ) # TF1/2 check @@ -946,32 +955,32 @@ rwmh_sampler <- R6Class( # tfe$log_prob_fun <- dag$generate_log_prob_function() # tensors for sampler parameters - # rwmh_epsilon <- tf$compat$v1$placeholder(dtype = tf_float()) + # rwmh_epsilon <- tf$compat$v1$placeholder(dtype = tf_float()) # need to pass in the value for this placeholder as a matrix (shape(n, 1)) - # rwmh_diag_sd <- tf$compat$v1$placeholder( - # dtype = tf_float(), - # # TF1/2 check - # again what do we with with `free_state`? - # shape = shape(dim(free_state)[[2]], 1) - # ) + # rwmh_diag_sd <- tf$compat$v1$placeholder( + # dtype = tf_float(), + # # TF1/2 check + # again what do we with with `free_state`? + # shape = shape(dim(free_state)[[2]], 1) + # ) # but it step_sizes must be a vector (shape(n, )), so reshape it - rwmh_step_sizes <- tf$reshape( - rwmh_epsilon * (rwmh_diag_sd / tf$reduce_sum(rwmh_diag_sd)), - # TF1/2 check - # what are we to do about `free_state` here? - shape = shape(free_state_size) - ) + rwmh_step_sizes <- tf$reshape( + rwmh_epsilon * (rwmh_diag_sd / tf$reduce_sum(rwmh_diag_sd)), + # TF1/2 check + # what are we to do about `free_state` here? + shape = shape(free_state_size) + ) - new_state_fn <- tfe$rwmh_proposal(scale = rwmh_step_sizes) + new_state_fn <- tfe$rwmh_proposal(scale = rwmh_step_sizes) # build the kernel # nolint start - sampler_kernel <- tfp$mcmc$RandomWalkMetropolis( - target_log_prob_fn = dag$tf_log_prob_function_adjusted, - new_state_fn = new_state_fn - ) + sampler_kernel <- tfp$mcmc$RandomWalkMetropolis( + target_log_prob_fn = dag$tf_log_prob_function_adjusted, + new_state_fn = new_state_fn + ) return( sampler_kernel ) @@ -1011,11 +1020,11 @@ slice_sampler <- R6Class( # build the kernel # nolint start - sampler_kernel <- tfp$mcmc$SliceSampler( - target_log_prob_fn = dag$tf_log_prob_function_adjusted, - step_size = fl(1), - max_doublings = slice_max_doublings - ) + sampler_kernel <- tfp$mcmc$SliceSampler( + target_log_prob_fn = dag$tf_log_prob_function_adjusted, + step_size = fl(1), + max_doublings = slice_max_doublings + ) return( sampler_kernel diff --git a/R/joint.R b/R/joint.R index 3c86a8e9..39663e98 100644 --- a/R/joint.R +++ b/R/joint.R @@ -53,14 +53,7 @@ joint_distribution <- R6Class( initialize = function(dots, dim) { n_distributions <- length(dots) - if (n_distributions < 2) { - cli::cli_abort( - c( - "{.fun joint} must be passed at least two distributions", - "The number of distributions passed was {n_distributions}" - ) - ) - } + check_num_distributions(n_distributions, at_least = 2, name = "joint") # check the dimensions of the variables in dots single_dim <- do.call(check_dims, c(dots, target_dim = dim)) @@ -95,13 +88,8 @@ joint_distribution <- R6Class( # check the distributions are all either discrete or continuous discrete <- vapply(distribs, member, "discrete", FUN.VALUE = FALSE) - is_discrete_and_continuous <- !all(discrete) & !all(!discrete) - if (is_discrete_and_continuous) { - cli::cli_abort( - "cannot construct a joint distribution from a combination of \\ - discrete and continuous distributions" - ) - } + check_not_discrete_continuous(discrete, "joint") + n_components <- length(dot_nodes) # work out the support of the resulting distribution, and add as the diff --git a/R/mixture.R b/R/mixture.R index ef69ab23..55dd4e73 100644 --- a/R/mixture.R +++ b/R/mixture.R @@ -83,14 +83,7 @@ mixture_distribution <- R6Class( initialize = function(dots, weights, dim) { n_distributions <- length(dots) - if (n_distributions < 2) { - cli::cli_abort( - c( - "{.fun mixture} must be passed at least two distributions", - "The number of distributions passed was: {.val {n_distributions}}" - ) - ) - } + check_num_distributions(n_distributions, at_least = 2, name = "mixture") # check the dimensions of the variables in dots dim <- do.call(check_dims, c(dots, target_dim = dim)) @@ -104,42 +97,10 @@ mixture_distribution <- R6Class( self$weights_is_log <- TRUE } - # weights should have n_distributions as the first dimension - if (weights_dim[1] != n_distributions) { - cli::cli_abort( - c( - "the first dimension of weights must be the number of \\ - distributions in the mixture ({.val {n_distributions}})", - "However it was {.val {weights_dim[1]}}" - ) - ) - } - # drop a trailing 1 from dim, so user doesn't need to deal with it # Ugh, need to get rid of column vector thing soon. # TODO get rid of column vector thing? - weights_extra_dim <- dim - n_extra_dim <- length(weights_extra_dim) - weights_last_dim_is_1 <- weights_extra_dim[n_extra_dim] == 1 - if (weights_last_dim_is_1) { - weights_extra_dim <- weights_extra_dim[-n_extra_dim] - } - - # remainder should be 1 or match weights_extra_dim - w_dim <- weights_dim[-1] - dim_1 <- length(w_dim) == 1 && w_dim == 1 - dim_same <- all(w_dim == weights_extra_dim) - incompatible_dims <- !(dim_1 | dim_same) - if (incompatible_dims) { - cli::cli_abort( - c( - "the dimension of weights must be either \\ - {.val {n_distributions}x1} or \\ - {.val {n_distributions}x{paste(dim, collapse = 'x')}}", - " but was {.val {paste(weights_dim, collapse = 'x')}}" - ) - ) - } + check_weights_dim(weights_dim, dim, n_distributions) dot_nodes <- lapply(dots, get_node) @@ -153,13 +114,7 @@ mixture_distribution <- R6Class( # check the distributions are all either discrete or continuous discrete <- vapply(distribs, member, "discrete", FUN.VALUE = logical(1)) - is_discrete_and_continuous <- !all(discrete) & !all(!discrete) - if (is_discrete_and_continuous) { - cli::cli_abort( - "cannot construct a mixture from a combination of discrete and \\ - continuous distributions" - ) - } + check_not_discrete_continuous(discrete, name = "mixture") # check the distributions are all either multivariate or univariate multivariate <- vapply(distribs, @@ -168,13 +123,7 @@ mixture_distribution <- R6Class( FUN.VALUE = logical(1) ) - is_multivariate_and_univariate <- !all(multivariate) & !all(!multivariate) - if (is_multivariate_and_univariate) { - cli::cli_abort( - "cannot construct a mixture from a combination of multivariate and \\ - univariate distributions" - ) - } + check_not_multivariate_univariate(multivariate) # ensure the support and bounds of each of the distributions is the same truncations <- lapply(distribs, member, "truncation") @@ -184,23 +133,7 @@ mixture_distribution <- R6Class( supports <- bounds supports[truncated] <- truncations[truncated] - n_supports <- length(unique(supports)) - if (n_supports != 1) { - supports_text <- vapply( - X = unique(supports), - FUN = paste, - collapse = " to ", - FUN.VALUE = character(1) - ) - - cli::cli_abort( - c( - "component distributions must have the same support", - "However the component distributions have different support:", - "{.val {paste(supports_text, collapse = ' vs. ')}}" - ) - ) - } + check_distribution_support(supports) # get the maximal bounds for all component distributions bounds <- c( diff --git a/R/node_class.R b/R/node_class.R index 2241d50e..6b54683f 100644 --- a/R/node_class.R +++ b/R/node_class.R @@ -220,7 +220,7 @@ node <- R6Class( # check it if (!is.distribution_node(distribution)) { cli::cli_abort( - "invalid distribution" + "Invalid distribution" ) } diff --git a/R/operators.R b/R/operators.R index 7f1bd7c8..a0dca594 100644 --- a/R/operators.R +++ b/R/operators.R @@ -157,28 +157,8 @@ NULL #' @export `%*%.greta_array` <- function(x, y) { # nolint - # check they're matrices - if (!is_2d(x) | !is_2d(y)) { - cli::cli_abort( - c( - "only two-dimensional {.cls greta_array}s can be matrix-multiplied", - "dimensions recorded were {dim(x)}" - ) - ) - } - - # check the dimensions match - # check_incompatible_dimensions(x, y) - incompatible_dimensions <- dim(x)[2] != dim(y)[1] - if (incompatible_dimensions) { - cli::cli_abort( - c( - "incompatible dimensions: \\ - {.val {paste0(dim(x), collapse = 'x')}} vs \\ - {.val {paste0(dim(y), collapse = 'x')}}" - ) - ) - } + check_both_2d(x,y) + check_compatible_dimensions(x, y) op("matrix multiply", x, y, dim = c(nrow(x), ncol(y)), diff --git a/R/optimiser_class.R b/R/optimiser_class.R index 695dc090..38aa1c42 100644 --- a/R/optimiser_class.R +++ b/R/optimiser_class.R @@ -161,24 +161,33 @@ tf_optimiser <- R6Class( # The objective value can reach numerical overflow, so we error and # suggest changing initial values or changing sampler, e.g., `adam` - if (!is.finite(obj_numeric)){ - cli::cli_abort( - c( - "Detected numerical overflow during optimisation", - "Please try one of the following:", - "i" = "Using different initial values", - "i" = "Using another optimiser. (E.g., instead of \\ - {.fun {self$name}}, try {.fun adam})" - ) - ) - } + self$check_numerical_overflow(obj_numeric) self$diff <- abs(self$old_obj - obj_numeric) self$old_obj <- obj_numeric } tfe$free_state <- free_state } + }, + + check_numerical_overflow = function(x, + arg = rlang::caller_arg(x), + call = rlang::caller_env()){ + if (!is.finite(x)){ + cli::cli_abort( + message = c( + "Detected numerical overflow during optimisation", + "Please try one of the following:", + "i" = "Using different initial values", + "i" = "Using another optimiser. (E.g., instead of \\ + {.fun {self$name}}, try {.fun adam})" + ), + call = call + ) + } } + + ) ) diff --git a/R/probability_distributions.R b/R/probability_distributions.R index f9f8cb9d..cba0eeb3 100644 --- a/R/probability_distributions.R +++ b/R/probability_distributions.R @@ -5,13 +5,8 @@ uniform_distribution <- R6Class( min = NA, max = NA, initialize = function(min, max, dim) { - if (is.greta_array(min) | is.greta_array(max)) { - cli::cli_abort( - "{.arg min} and {.arg max} must be fixed, they cannot be another \\ - greta array" - ) - } - + check_param_greta_array(min) + check_param_greta_array(max) check_numeric_length_1(min) check_numeric_length_1(max) check_finite(min) diff --git a/R/utils.R b/R/utils.R index 6dae1802..2c0468e8 100644 --- a/R/utils.R +++ b/R/utils.R @@ -1184,3 +1184,4 @@ outside_version_range <- function(provided, range) { outside_range } +pretty_dim <- function(x) paste0(dim(x), collapse = "x") diff --git a/R/variable.R b/R/variable.R index 7d0c1fef..9c9f5e93 100644 --- a/R/variable.R +++ b/R/variable.R @@ -45,13 +45,8 @@ #' } variable <- function(lower = -Inf, upper = Inf, dim = NULL) { check_tf_version("error") - - if (is.greta_array(lower) | is.greta_array(upper)) { - cli::cli_abort( - "{.arg lower} and {.arg upper} must be fixed, they cannot be another \\ - {.cls greta_array}" - ) - } + check_param_greta_array(lower) + check_param_greta_array(upper) node <- variable_node$new(lower, upper, dim) as.greta_array(node) @@ -144,17 +139,10 @@ simplex_variable <- function(dim) { } dim <- check_dims(target_dim = dim) - - # dimension of the free state version n_dim <- length(dim) last_dim <- dim[n_dim] - if (!last_dim > 1) { - cli::cli_abort( - "the final dimension of a simplex variable must have more than one \\ - element", - "The final dimension has: {.val {length(last_dim)} elements}" - ) - } + # dimension of the free state version + check_final_dim(dim, thing = "simplex variable") raw_dim <- dim raw_dim[n_dim] <- last_dim - 1 @@ -192,17 +180,7 @@ ordered_variable <- function(dim) { dim <- check_dims(target_dim = dim) - # dimension of the free state version - n_dim <- length(dim) - last_dim <- dim[n_dim] - - if (!last_dim > 1) { - cli::cli_abort( - "the final dimension of an ordered variable must have more than \\ - one element", - "the final dimension has: {.val {length(last_dim)} elements}" - ) - } + check_final_dim(dim, thing = "ordered variable") # create variable node node <- vble(truncation = c(-Inf, Inf), dim = dim) diff --git a/tests/testthat/_snaps/as_data.md b/tests/testthat/_snaps/as_data.md index 91978394..594e87ba 100644 --- a/tests/testthat/_snaps/as_data.md +++ b/tests/testthat/_snaps/as_data.md @@ -1,62 +1,116 @@ # as_data errors informatively - Object cannot be coerced to - Objects of class cannot be coerced to a + Code + as_data(NULL) + Condition + Error in `as.greta_array()`: + ! Object cannot be coerced to + Objects of class cannot be coerced to a --- - Object cannot be coerced to - Objects of class cannot be coerced to a + Code + as_data(list()) + Condition + Error in `as.greta_array()`: + ! Object cannot be coerced to + Objects of class cannot be coerced to a --- - Object cannot be coerced to - Objects of class cannot be coerced to a + Code + as_data(environment()) + Condition + Error in `as.greta_array()`: + ! Object cannot be coerced to + Objects of class cannot be coerced to a --- - Object cannot be coerced to - Objects of class cannot be coerced to a + Code + as_data(cha_vec) + Condition + Error in `as.greta_array()`: + ! Object cannot be coerced to + Objects of class cannot be coerced to a --- - must contain the same type - Cannot coerce to a unless it is , or . This had type: - + Code + as_data(cha_mat) + Condition + Error in `as.greta_array()`: + ! must contain the same type + Cannot coerce to a unless it is , or . + This had type: --- - must contain the same type - Cannot coerce to a unless it is , or . This had type: - + Code + as_data(cha_arr) + Condition + Error in `as.greta_array()`: + ! must contain the same type + Cannot coerce to a unless it is , or . + This had type: --- - must contain the same type - Cannot coerce a to a unless all columns are or . This dataframe had columns of type: + Code + as_data(cha_df) + Condition + Error in `as.greta_array()`: + ! must contain the same type + Cannot coerce a to a unless all columns are or . + This dataframe had columns of type: --- - must contain the same type - Cannot coerce a to a unless all columns are or . This dataframe had columns of type: + Code + as_data(cha_df2) + Condition + Error in `as.greta_array()`: + ! must contain the same type + Cannot coerce a to a unless all columns are or . + This dataframe had columns of type: --- - must not contain missing or infinite values + Code + as_data(arr_inf) + Condition + Error in `as.greta_array.numeric()`: + ! must not contain missing or infinite values --- - must not contain missing or infinite values + Code + as_data(arr_minf) + Condition + Error in `as.greta_array.numeric()`: + ! must not contain missing or infinite values --- - must not contain missing or infinite values + Code + as_data(arr_na) + Condition + Error in `as.greta_array.numeric()`: + ! must not contain missing or infinite values --- - cannot coerce a non-data to data + Code + as_data(stoch) + Condition + Error in `as_data()`: + ! cannot coerce a non-data to data --- - cannot coerce a non-data to data + Code + as_data(op) + Condition + Error in `as_data()`: + ! cannot coerce a non-data to data diff --git a/tests/testthat/_snaps/diagrammer-installed.md b/tests/testthat/_snaps/diagrammer-installed.md index f63e660a..48486402 100644 --- a/tests/testthat/_snaps/diagrammer-installed.md +++ b/tests/testthat/_snaps/diagrammer-installed.md @@ -1,6 +1,10 @@ # DiagrammeR installation is checked - the DiagrammeR package must be installed to plot greta models - install DiagrammeR with: - `install.packages('DiagrammeR')` + Code + plot(m) + Condition + Error in `plot()`: + ! The DiagrammeR package must be installed to plot greta models + Install DiagrammeR with: + `install.packages('DiagrammeR')` diff --git a/tests/testthat/_snaps/functions.md b/tests/testthat/_snaps/functions.md index a0767523..a40cdaa3 100644 --- a/tests/testthat/_snaps/functions.md +++ b/tests/testthat/_snaps/functions.md @@ -70,8 +70,8 @@ solve(a, a) Condition Error in `solve()`: - ! `a` is not square - x `a` must be square, but has 5 rows and 25 columns + ! Not 2D square greta array + x expected a 2D square greta array, but object `x` had dimension: 5x25 --- @@ -79,8 +79,8 @@ solve(a) Condition Error in `solve()`: - ! `a` is not square - x `a` must be square, but has 5 rows and 25 columns + ! Not 2D square greta array + x expected a 2D square greta array, but object `x` had dimension: 5x25 --- @@ -124,7 +124,8 @@ Error in `sweep()`: ! `stats` not a column vector array `stats` must be a column vector array - x `stats` has dimensions 1x5 + x `stats` has dimensions: + "1x5" --- @@ -170,28 +171,28 @@ # forwardsolve and backsolve error as expected - `k` must equal `ncol(l)` for s + `1` must equal `ncol(l)` for s --- - `k` must equal `ncol(r)` for s + `1` must equal `ncol(r)` for s --- - transpose must be FALSE for s + `transpose` must be FALSE for s --- - transpose must be FALSE for s + `transpose` must be FALSE for s # tapply errors as expected - `x` must be 2D greta array with one column + `x` must be 2D with one column However `x` has dimensions 10x2 --- - INDEX cannot be a greta array + `INDEX` cannot be a # ignored options are errored/warned about diff --git a/tests/testthat/_snaps/inference.md b/tests/testthat/_snaps/inference.md index d2937fe0..2a0edb04 100644 --- a/tests/testthat/_snaps/inference.md +++ b/tests/testthat/_snaps/inference.md @@ -1,7 +1,7 @@ # bad mcmc proposals are rejected The log density could not be evaluated at these initial values - Try using these initials as the values argument in `calculate()` to see what values of subsequent s these initial values lead to. + Try using these initials as the `values` argument in `calculate()` to see what values of subsequent s these initial values lead to. --- @@ -10,12 +10,12 @@ # mcmc handles initial values nicely - the number of provided initial values does not match chains + The number of provided initial values does not match chains 3 sets of initial values were provided, but there are 2 chains --- - the initial values provided have different dimensions than the named s + The initial values provided have different dimensions than the named s --- @@ -77,7 +77,7 @@ --- - all initial values must be named + All initial values must be named --- @@ -101,28 +101,28 @@ --- - some s passed to `initials()` are not associated with the model: + Some s passed to `initials()` are not associated with the model: `g` --- - initial values can only be set for variable s + Initial values can only be set for variable s --- - initial values can only be set for variable s + Initial values can only be set for variable s --- - some provided initial values are outside the range of values their variables can take + Some provided initial values are outside the range of values their variables can take --- - some provided initial values are outside the range of values their variables can take + Some provided initial values are outside the range of values their variables can take --- - some provided initial values are outside the range of values their variables can take + Some provided initial values are outside the range of values their variables can take # samplers print informatively diff --git a/tests/testthat/_snaps/joint.md b/tests/testthat/_snaps/joint.md index 12e149bc..cee49ad3 100644 --- a/tests/testthat/_snaps/joint.md +++ b/tests/testthat/_snaps/joint.md @@ -1,16 +1,16 @@ # joint of fixed and continuous distributions errors - cannot construct a joint distribution from a combination of discrete and continuous distributions + Cannot construct a joint distribution from a combination of discrete and continuous distributions # joint with insufficient distributions errors - `joint()` must be passed at least two distributions - The number of distributions passed was 1 + `joint()` must be passed at least 2 distributions + The number of distributions passed was: 1 --- - `joint()` must be passed at least two distributions - The number of distributions passed was 0 + `joint()` must be passed at least 2 distributions + The number of distributions passed was: 0 # joint with non-scalar distributions errors diff --git a/tests/testthat/_snaps/misc.md b/tests/testthat/_snaps/misc.md index 10384610..ed413f69 100644 --- a/tests/testthat/_snaps/misc.md +++ b/tests/testthat/_snaps/misc.md @@ -45,7 +45,7 @@ --- - model contains a discrete random variable that doesn't have a fixed value, so inference cannot be carried out + Model contains a discrete random variable that doesn't have a fixed value, so inference cannot be carried out. --- @@ -53,7 +53,7 @@ --- - data s cannot be sampled + Data s cannot be sampled `x` is a data (s) # check_dims errors informatively diff --git a/tests/testthat/_snaps/mixture.md b/tests/testthat/_snaps/mixture.md index 0b5fac14..5a44f127 100644 --- a/tests/testthat/_snaps/mixture.md +++ b/tests/testthat/_snaps/mixture.md @@ -1,35 +1,35 @@ # mixtures of fixed and continuous distributions errors - cannot construct a mixture from a combination of discrete and continuous distributions + Cannot construct a mixture distribution from a combination of discrete and continuous distributions # mixtures of multivariate and univariate errors - cannot construct a mixture from a combination of multivariate and univariate distributions + Cannot construct a mixture from a combination of multivariate and univariate distributions # mixtures of supports errors - component distributions must have the same support + Component distributions must have the same support However the component distributions have different support: "0 to Inf vs. -Inf to Inf" --- - component distributions must have the same support + Component distributions must have the same support However the component distributions have different support: "0 to Inf vs. -Inf to Inf" # incorrectly-shaped weights errors - the first dimension of weights must be the number of distributions in the mixture (2) + The first dimension of weights must be the number of distributions in the mixture (2) However it was 1 # mixtures with insufficient distributions errors - `mixture()` must be passed at least two distributions + `mixture()` must be passed at least 2 distributions The number of distributions passed was: 1 --- - `mixture()` must be passed at least two distributions + `mixture()` must be passed at least 2 distributions The number of distributions passed was: 0 diff --git a/tests/testthat/_snaps/operators.md b/tests/testthat/_snaps/operators.md index 21e5cf7f..866f04bb 100644 --- a/tests/testthat/_snaps/operators.md +++ b/tests/testthat/_snaps/operators.md @@ -1,11 +1,13 @@ # %*% errors informatively - incompatible dimensions: "3x4" vs "1x4" + Incompatible dimensions: "3x4" vs "1x4" --- - only two-dimensional s can be matrix-multiplied - dimensions recorded were 3 and 4 + Only two-dimensional s can be matrix-multiplied + Dimensions for each are: + `x`: "3x4" + `y`: "2x2x2" # %*% works when one is a non-greta array diff --git a/tests/testthat/_snaps/variables.md b/tests/testthat/_snaps/variables.md index 5aad289a..a8d6ba3c 100644 --- a/tests/testthat/_snaps/variables.md +++ b/tests/testthat/_snaps/variables.md @@ -52,17 +52,21 @@ --- - the final dimension of a simplex variable must have more than one element + The final dimension of a simplex variable must have more than one element + The final dimension has: "1 element" --- - the final dimension of a simplex variable must have more than one element + The final dimension of a simplex variable must have more than one element + The final dimension has: "1 element" --- - the final dimension of an ordered variable must have more than one element + The final dimension of a ordered variable must have more than one element + The final dimension has: "1 element" --- - the final dimension of an ordered variable must have more than one element + The final dimension of a ordered variable must have more than one element + The final dimension has: "1 element" diff --git a/tests/testthat/test-diagrammer-installed.R b/tests/testthat/test-diagrammer-installed.R index e4a5dd71..7577b333 100644 --- a/tests/testthat/test-diagrammer-installed.R +++ b/tests/testthat/test-diagrammer-installed.R @@ -1,13 +1,12 @@ test_that("DiagrammeR installation is checked", { skip_if_not(check_tf_version()) skip_on_cran() - mockery::stub( - where = plot.greta_model, - what = 'is_DiagrammeR_installed', - how = FALSE - ) + local_mocked_bindings( + is_DiagrammeR_installed = function() FALSE + ) m <- model(normal(0,1)) - expect_snapshot_error( + expect_snapshot( + error = TRUE, x = plot(m) ) }) diff --git a/tests/testthat/test_as_data.R b/tests/testthat/test_as_data.R index e05f1b26..792bf0eb 100644 --- a/tests/testthat/test_as_data.R +++ b/tests/testthat/test_as_data.R @@ -108,15 +108,18 @@ test_that("as_data errors informatively", { # wrong class of object - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(NULL) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(list()) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(environment()) ) @@ -127,23 +130,28 @@ test_that("as_data errors informatively", { cha_df <- as.data.frame(cha_mat, stringsAsFactors = FALSE) cha_df2 <- as.data.frame(cha_mat, stringsAsFactors = TRUE) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(cha_vec) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(cha_mat) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(cha_arr) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(cha_df) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(cha_df2) ) @@ -155,15 +163,18 @@ test_that("as_data errors informatively", { arr_na <- randn(3, 3) arr_na[1, 3] <- NA - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(arr_inf) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(arr_minf) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(arr_na) ) @@ -171,10 +182,12 @@ test_that("as_data errors informatively", { stoch <- normal(0, 1, dim = c(2, 3)) op <- stoch^2 - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(stoch) ) - expect_snapshot_error( + expect_snapshot( + error = TRUE, as_data(op) ) })