Skip to content

Commit

Permalink
Merge pull request #342 from stan-dev/indices
Browse files Browse the repository at this point in the history
extract_variable_array and the with_indices parameter for `variables()`
  • Loading branch information
mjskay authored Feb 1, 2024
2 parents a63b894 + fe6f23b commit c312846
Show file tree
Hide file tree
Showing 41 changed files with 1,127 additions and 503 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,5 @@ LazyData: false
URL: https://mc-stan.org/posterior/, https://discourse.mc-stan.org/
BugReports: https://github.com/stan-dev/posterior/issues
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.0
VignetteBuilder: knitr
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -141,9 +141,15 @@ S3method(ess_tail,default)
S3method(ess_tail,rvar)
S3method(extract_variable,default)
S3method(extract_variable,draws)
S3method(extract_variable,draws_df)
S3method(extract_variable,draws_list)
S3method(extract_variable,draws_rvars)
S3method(extract_variable_array,default)
S3method(extract_variable_array,draws)
S3method(extract_variable_matrix,default)
S3method(extract_variable_matrix,draws)
S3method(extract_variable_matrix,draws_df)
S3method(extract_variable_matrix,draws_list)
S3method(extract_variable_matrix,draws_rvars)
S3method(format,rvar)
S3method(format_glimpse,rvar)
Expand Down Expand Up @@ -435,6 +441,7 @@ export(ess_sd)
export(ess_tail)
export(example_draws)
export(extract_variable)
export(extract_variable_array)
export(extract_variable_matrix)
export(for_each_draw)
export(is_draws)
Expand Down
9 changes: 9 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@
weights before adding to a draws object.
* Matrix multiplication of `rvar`s can now be done with the base matrix
multiplication operator (`%*%`) instead of `%**%` in R >= 4.3.
* `variables()`, `variables<-()`, `set_variables()`, and `nvariables()` now
support a `with_indices` argument, which determines whether variable names
are retrieved/set with (`"x[1]"`, `"x[2]"` ...) or without (`"x"`) indices
(#208).
* Add `extract_variable_array()` function to extract variables with indices
into arrays of iterations x chains x any remaining dimensions (#340).
* For types that support `factor` variables (`draws_df`, `draws_list`, and
`draws_rvars`), `extract_variable()` and `extract_variable_matrix()` can
now return `factor`s.

# posterior 1.5.0

Expand Down
58 changes: 21 additions & 37 deletions R/as_draws_rvars.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,36 +54,28 @@ as_draws_rvars.draws_matrix <- function(x, ...) {
}

#' Helper for as_draws_rvars.draws_matrix and as_draws_rvars.draws_df()
#' @param x_at A function taking a logical vector along variables(x) and returning a matrix of draws
#' @param x_at A function taking a numeric vector of indices along variables(x) and returning a matrix of draws
#' @noRd
.as_draws_rvars.draws_matrix <- function(x, ..., x_at = function(var_i) unclass(x[, var_i, drop = FALSE])) {
.variables <- variables(x, reserved = TRUE)
.nchains <- nchains(x)
if (ndraws(x) == 0) {
return(empty_draws_rvars(.variables))
}

# split x[y,z] names into base name and indices
#
# ----- base name -> vars_indices[[i]][[2]]
# ||||| lazy-matched (.*? not .*) so that indices match as much as they can
# |||||
# ||||| ---- optional indices -> vars_indices[[i]][[3]]
# ||||| ||||
matches <- regexec("^(.*?)(?:\\[(.*)\\])?$", .variables)
vars_indices <- regmatches(.variables, matches)
vars <- vapply(vars_indices, `[[`, i = 2, character(1))
vars <- split_variable_names(.variables)
vars$i <- seq_along(.variables)

# pull out each var into its own rvar
var_names <- unique(vars)
rvars_list <- lapply(var_names, function(var) {
var_i <- vars == var
var_matrix <- x_at(var_i)
vars_by_base_name <- vctrs::vec_split(vars, vars$base_name)
rvars_list <- lapply(vars_by_base_name$val, function(var) {
var_matrix <- x_at(var$i)
attr(var_matrix, "nchains") <- NULL
var_indices <- vars_indices[var_i]

if (ncol(var_matrix) == 1 && nchar(var_indices[[1]][[3]]) == 0) {
if (ncol(var_matrix) == 1 && nchar(var$indices[[1]]) == 0) {
# single variable, no indices
out <- rvar(var_matrix)
out <- rvar(var_matrix, nchains = .nchains)
dimnames(out) <- NULL
} else {
# variable with indices => we need to reshape the array
Expand All @@ -92,8 +84,7 @@ as_draws_rvars.draws_matrix <- function(x, ...) {

# first, pull out the list of indices into a data frame
# where each column is an index variable
indices <- vapply(var_indices, `[[`, i = 3, character(1))
indices <- as.data.frame(do.call(rbind, strsplit(indices, ",")),
indices <- as.data.frame(do.call(rbind, split_indices(var$indices)),
stringsAsFactors = FALSE)
unique_indices <- vector("list", length(indices))
.dimnames <- vector("list", length(indices))
Expand Down Expand Up @@ -131,35 +122,31 @@ as_draws_rvars.draws_matrix <- function(x, ...) {
# (2) if some combination of indices is missing (say x[2,1] isn't
# in the input) that cell in the array gets an NA

# Use expand.grid to get all cells in output array. We reverse indices
# here because it helps us do the sort after the merge, where
# we need to sort in reverse order of the indices (because
# the value of the last index should move slowest)
all_indices <- expand.grid(rev(unique_indices))
# Use expand.grid to get all cells in output array in the appropriate
# order (value of the last index should move slowest), and save that order
# in $order so we can restore it after the merge
all_indices <- expand.grid(unique_indices, KEEP.OUT.ATTRS = FALSE, stringsAsFactors = FALSE)
all_indices$order <- seq_len(nrow(all_indices))
# merge with all.x = TRUE (left join) to fill in missing cells with NA
indices <- merge(all_indices, cbind(indices, index = seq_len(nrow(indices))),
all.x = TRUE, sort = FALSE)
# need to do the sort manually after merge because when sort = TRUE, merge
# sorts factors as if they were strings, and we need factors to be sorted as factors
indices <- indices[do.call(order, as.list(indices[, -ncol(indices), drop = FALSE])),]
# (and merge does not guarantee it keeps the original order in `x`)
indices <- indices[order(indices$order), ]

# re-sort the array and fill in missing cells with NA
var_matrix <- var_matrix[, indices$index, drop = FALSE]

# convert to rvar and adjust dimensions
out <- rvar(var_matrix)
out <- rvar(var_matrix, nchains = .nchains)
dim(out) <- unname(lengths(unique_indices))
dimnames(out) <- .dimnames
}
out
})
names(rvars_list) <- var_names
out <- .as_draws_rvars(rvars_list, ...)
.nchains <- nchains(x)
for (i in seq_along(out)) {
nchains_rvar(out[[i]]) <- .nchains
}
out
names(rvars_list) <- vars_by_base_name$key
.as_draws_rvars(rvars_list, ...)
}

#' @rdname draws_rvars
Expand All @@ -175,10 +162,7 @@ as_draws_rvars.draws_df <- function(x, ...) {
data_frame_to_matrix <- function(df) {
if (any(vapply(df, is.factor, logical(1)))) {
# as.matrix() does not convert factor columns correctly, must do this ourselves
while_preserving_dims(
function(df) do.call(function(...) vctrs::vec_c(..., .name_spec = rlang::zap()), df),
df
)
copy_dims(df, vctrs::vec_c(!!!df, .name_spec = rlang::zap()))
} else {
as.matrix(df)
}
Expand Down
Loading

0 comments on commit c312846

Please sign in to comment.