diff --git a/NAMESPACE b/NAMESPACE index 1abcc8f0..75e83033 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -145,6 +145,7 @@ S3method(tidy,mdl_ts) S3method(tidy,null_mdl) S3method(transmute,fbl_ts) S3method(transmute,grouped_fbl) +S3method(transmute,mdl_df) S3method(ungroup,fbl_ts) S3method(ungroup,grouped_fbl) S3method(unique,fcdist) diff --git a/R/dplyr-mable.R b/R/dplyr-mable.R index 8cb9d803..9cef05a1 100644 --- a/R/dplyr-mable.R +++ b/R/dplyr-mable.R @@ -7,7 +7,6 @@ dplyr_row_slice.mdl_df <- function(data, i, ..., preserve = FALSE) { #' @export dplyr_col_modify.mdl_df <- function(data, cols) { res <- dplyr_col_modify(as_tibble(data), cols) - val_key <- any(key_vars(data) %in% cols) if (val_key) { key_vars <- setdiff(names(res), measured_vars(data)) @@ -19,5 +18,7 @@ dplyr_col_modify.mdl_df <- function(data, cols) { #' @export dplyr_reconstruct.mdl_df <- function(data, template) { res <- NextMethod() - build_mable(data, key_data = key_data(template), model = !!mable_vars(template)) + build_mable(data, + key = !!key_vars(template), + model = !!intersect(mable_vars(template), colnames(res))) } diff --git a/R/mable.R b/R/mable.R index 60f22ee9..aa6a80c8 100644 --- a/R/mable.R +++ b/R/mable.R @@ -66,8 +66,12 @@ build_mable <- function (x, key = NULL, key_data = NULL, model) { abort("The result is not a valid mable. The key variables must uniquely identify each row.") } + build_mable_meta(x, key_data, model) +} + +build_mable_meta <- function(x, key_data, model){ tibble::new_tibble(x, key = key_data, model = model, - nrow = NROW(x), class = "mdl_df", subclass = "mdl_df") + nrow = NROW(x), class = "mdl_df", subclass = "mdl_df") } #' @export @@ -89,6 +93,24 @@ tbl_sum.mdl_df <- function(x){ out } +restore_mable <- function(data, template){ + data <- as_tibble(data) + data_cols <- names(data) + + # key_vars <- setdiff(key_vars(template), data_cols) + # key_data <- select(key_data(template), key_vars) + # if (vec_size(key_data) == 1) { + # template <- remove_key(template, setdiff(key_vars(template), key_vars)) + # } + + model_vars <- intersect(mable_vars(template), data_cols) + # Variables to keep + mbl_vars <- setdiff(key_vars(template), data_cols) + res <- dplyr::bind_cols(template[mbl_vars], data) + + build_mable(res, key = !!key_vars(template), model = !!model_vars) +} + #' @export gather.mdl_df <- function(data, key = "key", value = "value", ..., na.rm = FALSE, convert = FALSE, factor_key = FALSE){ @@ -97,28 +119,18 @@ gather.mdl_df <- function(data, key = "key", value = "value", ..., na.rm = FALSE ..., na.rm = na.rm, convert = convert, factor_key = factor_key) mdls <- names(which(map_lgl(tbl, inherits, "lst_mdl"))) kv <- c(key_vars(data), key) - as_mable(tbl, key = kv, model = mdls) + build_mable(tbl, key = !!kv, model = !!mdls) } -# Adapted from tsibble:::select_tsibble #' @export select.mdl_df <- function (.data, ...){ - sel_data <- NextMethod() - sel_vars <- names(sel_data) - - kv <- key_vars(.data) - key_vars <- intersect(sel_vars, kv) - key_nochange <- all(is.element(kv, key_vars)) - - mdls <- names(which(map_lgl(sel_data, inherits, "lst_mdl"))) - if(is_empty(mdls)){ - abort("A mable must contain at least one model. To remove all models, first convert to a tibble with `as_tibble()`.") - } - if(key_nochange) { - build_mable(sel_data, key_data = key_data(.data), model = mdls) - } else { - build_mable(sel_data, key = !!key_vars, model = mdls) - } + res <- select(as_tibble(.data), ...) + restore_mable(res, .data) +} +#' @export +transmute.mdl_df <- function (.data, ...){ + res <- transmute(as_tibble(.data), ...) + restore_mable(res, .data) } #' @export diff --git a/R/vctrs-mable.R b/R/vctrs-mable.R index 70502ea4..324df34d 100644 --- a/R/vctrs-mable.R +++ b/R/vctrs-mable.R @@ -34,15 +34,17 @@ vec_ptype2.mdl_df.tbl_df <- vec_ptype2.mdl_df.mdl_df mable_ptype2 <- function(x, y, ...) { key_x <- key_vars(x) - resp_x <- response_vars(x) + mdl_x <- mable_vars(x) if (is_mable(y)) { + resp_x <- response_vars(x) if (!identical(resp_x, response_vars(y))) { abort("Objects with different response variables cannot be combined.") } key_x <- union(key_x, key_vars(y)) + mdl_x <- union(mdl_x, mable_vars(y)) } out <- df_ptype2(x, y, ...) - build_mable(out, !!key_x) + build_mable_meta(out, key_data = key_x, model = mdl_x) } #' @rdname mable-vctrs