Skip to content

Commit

Permalink
Continue to add many checking functions:
Browse files Browse the repository at this point in the history
* check_not_greta_array()
* check_param_greta_array()
* check_final_dim()
* check_rows_equal()
* check_is_column_array()
* check_stats_dim_matches_x_dim()
* check_x_matches_ncol()
* check_transpose()
* check_2_by_1()
* check_fields_installed()
* check_diagrammer_installed()
* check_ncols_match()
* check_greta_data_frame()
* check_greta_array_type()
* check_unfixed_discrete_distributions()
* check_not_data_greta_arrays()
* check_greta_arrays_associated_with_model()
* check_nodes_all_variable()
* check_initial_values_correct_dim()
* check_initial_values_match_chains()
* check_initials_are_numeric()
* check_initials_are_named()
* check_weights_dim()
* check_num_distributions()
* check_not_discrete_continuous()
* check_not_multivariate_univariate()
* check_distribution_support()
* check_compatible_dimensions()
* check_both_2d()
  • Loading branch information
njtierney committed Aug 14, 2024
1 parent 374f52c commit 016464b
Show file tree
Hide file tree
Showing 26 changed files with 862 additions and 604 deletions.
515 changes: 511 additions & 4 deletions R/checkers.R

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions R/extract_replace_combine.R
Original file line number Diff line number Diff line change
Expand Up @@ -649,15 +649,15 @@ diag.greta_array <- function(x = 1, nrow, ncol) {
# check the rank isn't too high
if (!is_2d(x)) {
cli::cli_abort(
"cannot only extract the diagonal from a node with exactly two \\
"Cannot only extract the diagonal from a node with exactly two \\

Check warning on line 652 in R/extract_replace_combine.R

View check run for this annotation

Codecov / codecov/patch

R/extract_replace_combine.R#L652

Added line #L652 was not covered by tests
dimensions"
)
}

is_square <- dim[1] != dim[2]
if (is_square) {
cli::cli_abort(
"diagonal elements can only be extracted from square matrices"
"Diagonal elements can only be extracted from square matrices"

Check warning on line 660 in R/extract_replace_combine.R

View check run for this annotation

Codecov / codecov/patch

R/extract_replace_combine.R#L660

Added line #L660 was not covered by tests
)
}

Expand Down
126 changes: 15 additions & 111 deletions R/functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -399,19 +399,7 @@ solve.greta_array <- function(a, b, ...) {
check_2d(a)

# check the matrix is square
a_not_square <- dim(a)[1] != dim(a)[2]
if (a_not_square) {

cli::cli_abort(
c(
"{.var a} is not square",
"x" = "{.var a} must be square, but has \\
{dim(a)[1]} rows and \\
{dim(a)[2]} columns"
)
)

}
check_square(a)

# if they just want the matrix inverse, do that
if (missing(b)) {
Expand All @@ -426,18 +414,8 @@ solve.greta_array <- function(a, b, ...) {
} else {

check_2d(b)

# b must have the right number of rows
rows_not_equal <- dim(b)[1] != dim(a)[1]
if (rows_not_equal) {
cli::cli_abort(
c(
"Number of rows not equal",
"x" = "{.var b} must have the same number of rows as {.var a} \\
({dim(a)[1]}), but has {dim(b)[1]} rows instead"
)
)
}
check_rows_equal(a, b)

# ... and solve the linear equations
result <- op("solve", a, b,
Expand Down Expand Up @@ -755,30 +733,9 @@ sweep.greta_array <- function(x,
}

check_2d(x)

# STATS must be a column array
is_column_array <- is_2d(stats) && dim(stats)[2] == 1
if (!is_column_array) {
cli::cli_abort(
c(
"{.var stats} not a column vector array",
"{.var stats} must be a column vector array",
"x" = "{.var stats} has dimensions \\
{paste(dim(stats), collapse = 'x')}"
)
)
}

check_is_column_array(stats)
# STATS must have the same dimension as the correct dim of x
stats_dim_matches_x_dim <- dim(x)[margin] == dim(stats)[1]
if (!stats_dim_matches_x_dim) {
cli::cli_abort(
c(
"The number of elements of {.var stats} does not match \\
{.code dim(x)[MARGIN]}"
)
)
}
check_stats_dim_matches_x_dim(x, margin, stats)

op("sweep", x, stats,
operation_args = list(margin = margin, fun = fun),
Expand Down Expand Up @@ -869,19 +826,8 @@ backsolve.greta_array <- function(r, x,
upper.tri = TRUE,
transpose = FALSE) {
# nolint end
if (k != ncol(r)) {
cli::cli_abort(
c(
"{.var k} must equal {.code ncol(r)} for {.cls greta_array}s"
)
)
}

if (transpose) {
cli::cli_abort(
"transpose must be FALSE for {.cls greta_array}s"
)
}
check_x_matches_ncol(x = k, ncol_of = r)
check_transpose(transpose)

op("backsolve", r, x,
operation_args = list(lower = !upper.tri),
Expand Down Expand Up @@ -921,17 +867,8 @@ forwardsolve.greta_array <- function(l, x,
upper.tri = FALSE,
transpose = FALSE) {
# nolint end
if (k != ncol(l)) {
cli::cli_abort(
"{.var k} must equal {.code ncol(l)} for {.cls greta_array}s"
)
}

if (transpose) {
cli::cli_abort(
"transpose must be FALSE for {.cls greta_array}s"
)
}
check_x_matches_ncol(x = k, ncol_of = l)
check_transpose(transpose)

op("forwardsolve", l, x,
operation_args = list(lower = !upper.tri),
Expand Down Expand Up @@ -968,11 +905,7 @@ apply.greta_array <- function(X, MARGIN,
# nolint end
fun <- match.arg(FUN)

if (is.greta_array(MARGIN)) {
cli::cli_abort(
"MARGIN cannot be a greta array"
)
}
check_not_greta_array(MARGIN)

margin <- as.integer(MARGIN)

Expand Down Expand Up @@ -1061,27 +994,14 @@ tapply.greta_array <- function(X, INDEX,
index <- INDEX
fun <- match.arg(FUN)

if (is.greta_array(index)) {
cli::cli_abort(
"INDEX cannot be a greta array"
)
}
check_not_greta_array(INDEX)

# convert index to successive integers starting at 0
groups <- sort(unique(index))
id <- match(index, groups) - 1L
len <- length(groups)

dim_x <- dim(x)
is_2_by_1 <- is_2d(x) && dim_x[2] == 1L
if (!is_2_by_1) {
cli::cli_abort(
c(
"{.var x} must be 2D greta array with one column",
"However {.var x} has dimensions {paste(dim_x, collapse = 'x')}"
)
)
}
check_2_by_1(x)

op("tapply", x,
operation_args = list(
Expand Down Expand Up @@ -1128,6 +1048,7 @@ eigen.greta_array <- function(x, symmetric,

is_square <- dims[1] == dims[2]
is_not_2d_square_symmetric <- !is_2d(x) | !is_square | !symmetric

if (is_not_2d_square_symmetric) {
cli::cli_abort(
"only two-dimensional, square, symmetric {.cls greta_array}s can be \\
Expand Down Expand Up @@ -1181,17 +1102,8 @@ rdist <- function(x1, x2 = NULL, compact = FALSE) {
#' @export
rdist.default <- function(x1, x2 = NULL, compact = FALSE) {
# error nicely if they don't have fields installed
fields_installed <- requireNamespace("fields", quietly = TRUE)
if (!fields_installed) {
cli::cli_abort(
c(
"{.pkg fields} package must be installed to use {.fun rdist} on greta \\
arrays",
"Install {.pkg fields} with:",
"{.code install.packages('fields')}"
)
)
}
check_fields_installed()

fields::rdist(
x1 = x1,
x2 = x2,
Expand Down Expand Up @@ -1235,15 +1147,7 @@ rdist.greta_array <- function(x1, x2 = NULL, compact = FALSE) {
# error if they have different number of columns. fields::rdist allows
# different numbers of columns, takes the number of columns from x1,and
# sometimes gives nonsense results
if (ncol(x1) != ncol(x2)) {
cli::cli_abort(
c(
"{.var x1} and {.var x2} must have the same number of columns",
"However {.code ncol(x1)} = {ncol(x1)} and \\
{.code ncol(x2)} = {ncol(x2)}"
)
)
}
check_ncols_match(x1, x2)

n2 <- nrow(x2)

Expand Down
52 changes: 9 additions & 43 deletions R/greta_array_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -30,21 +30,7 @@ as.greta_array.logical <- function(x, optional = FALSE, original_x = x, ...) {
#' @export
as.greta_array.data.frame <- function(x, optional = FALSE,
original_x = x, ...) {
classes <- vapply(x, class, "")
valid <- classes %in% c("numeric", "integer", "logical")

array_has_different_types <- !optional & !all(valid)
if (array_has_different_types) {
invalid_types <- unique(classes[!valid])
cli::cli_abort(
c(
"{.cls greta_array} must contain the same type",
"Cannot coerce a {.cls data.frame} to a {.cls greta_array} unless \\
all columns are {.cls numeric, integer} or {.cls logical}. This \\
dataframe had columns of type: {.cls {invalid_types}}"
)
)
}
check_greta_data_frame(x, optional)

as.greta_array.numeric(as.matrix(x),
optional = optional,
Expand All @@ -57,22 +43,12 @@ as.greta_array.data.frame <- function(x, optional = FALSE,
# or numeric
#' @export
as.greta_array.matrix <- function(x, optional = FALSE, original_x = x, ...) {
## TODO better abstract these if else clauses
if (!is.numeric(x)) {
if (is.logical(x)) {

check_greta_array_type(x, optional)

if (!is.numeric(x) && is.logical(x)) {
x[] <- as.numeric(x[])
} else if (!optional) {
cli::cli_abort(
c(
"{.cls greta_array} must contain the same type",
"Cannot coerce {.cls matrix} to a {.cls greta_array} unless it is \\
{.cls numeric}, {.cls integer} or {.cls logical}. This \\
{.cls matrix} had type:",
"{.cls {class(as.vector(x))}}"
)
)
}
}

as.greta_array.numeric(x,
optional = optional,
Expand All @@ -85,21 +61,11 @@ as.greta_array.matrix <- function(x, optional = FALSE, original_x = x, ...) {
# or numeric
#' @export
as.greta_array.array <- function(x, optional = FALSE, original_x = x, ...) {
## TODO Better abstract out these if statements
if (!optional & !is.numeric(x)) {
if (is.logical(x)) {

check_greta_array_type(x, optional)

if (!optional && !is.numeric(x) && is.logical(x)) {
x[] <- as.numeric(x[])
} else {
cli::cli_abort(
c(
"{.cls greta_array} must contain the same type",
"Cannot coerce {.cls array} to a {.cls greta_array} unless it is \\
{.cls numeric}, {.cls integer} or {.cls logical}. This {.cls array} \\
had type:",
"{.cls {class(as.vector(x))}}"
)
)
}
}

as.greta_array.numeric(x,
Expand Down
29 changes: 2 additions & 27 deletions R/greta_model_class.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,24 +155,7 @@ model <- function(...,
}
}

# check for unfixed discrete distributions
distributions <- dag$node_list[dag$node_types == "distribution"]
bad_nodes <- vapply(
distributions,
function(x) {
valid_target <- is.null(x$target) ||
is.data_node(x$target)
x$discrete && !valid_target
},
FALSE
)

if (any(bad_nodes)) {
cli::cli_abort(
"model contains a discrete random variable that doesn't have a fixed \\
value, so inference cannot be carried out"
)
}
check_unfixed_discrete_distributions(dag)

# define the TF graph
# dag$define_tf()
Expand Down Expand Up @@ -230,15 +213,7 @@ plot.greta_model <- function(x,
y,
colour = "#996bc7",
...) {
if (!is_DiagrammeR_installed()) {
cli::cli_abort(
c(
"the {.pkg DiagrammeR} package must be installed to plot greta models",
"install {.pkg DiagrammeR} with:",
"{.code install.packages('DiagrammeR')}"
)
)
}
check_diagrammer_installed()

# set up graph
dag_mat <- x$dag$adjacency_matrix
Expand Down
Loading

0 comments on commit 016464b

Please sign in to comment.