Skip to content

Commit

Permalink
select.mdl_df() and transmute.mdl_df() now keep the key by default
Browse files Browse the repository at this point in the history
Resolves #170
Ref #192
  • Loading branch information
mitchelloharawild committed May 19, 2020
1 parent b735f0a commit 91c75fe
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 23 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ S3method(tidy,mdl_ts)
S3method(tidy,null_mdl)
S3method(transmute,fbl_ts)
S3method(transmute,grouped_fbl)
S3method(transmute,mdl_df)
S3method(ungroup,fbl_ts)
S3method(ungroup,grouped_fbl)
S3method(unique,fcdist)
Expand Down
5 changes: 3 additions & 2 deletions R/dplyr-mable.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ dplyr_row_slice.mdl_df <- function(data, i, ..., preserve = FALSE) {
#' @export
dplyr_col_modify.mdl_df <- function(data, cols) {
res <- dplyr_col_modify(as_tibble(data), cols)

val_key <- any(key_vars(data) %in% cols)
if (val_key) {
key_vars <- setdiff(names(res), measured_vars(data))
Expand All @@ -19,5 +18,7 @@ dplyr_col_modify.mdl_df <- function(data, cols) {
#' @export
dplyr_reconstruct.mdl_df <- function(data, template) {
res <- NextMethod()
build_mable(data, key_data = key_data(template), model = !!mable_vars(template))
build_mable(data,
key = !!key_vars(template),
model = !!intersect(mable_vars(template), colnames(res)))
}
50 changes: 31 additions & 19 deletions R/mable.R
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,12 @@ build_mable <- function (x, key = NULL, key_data = NULL, model) {
abort("The result is not a valid mable. The key variables must uniquely identify each row.")
}

build_mable_meta(x, key_data, model)
}

build_mable_meta <- function(x, key_data, model){
tibble::new_tibble(x, key = key_data, model = model,
nrow = NROW(x), class = "mdl_df", subclass = "mdl_df")
nrow = NROW(x), class = "mdl_df", subclass = "mdl_df")
}

#' @export
Expand All @@ -89,6 +93,24 @@ tbl_sum.mdl_df <- function(x){
out
}

restore_mable <- function(data, template){
data <- as_tibble(data)
data_cols <- names(data)

# key_vars <- setdiff(key_vars(template), data_cols)
# key_data <- select(key_data(template), key_vars)
# if (vec_size(key_data) == 1) {
# template <- remove_key(template, setdiff(key_vars(template), key_vars))
# }

model_vars <- intersect(mable_vars(template), data_cols)
# Variables to keep
mbl_vars <- setdiff(key_vars(template), data_cols)
res <- dplyr::bind_cols(template[mbl_vars], data)

build_mable(res, key = !!key_vars(template), model = !!model_vars)
}

#' @export
gather.mdl_df <- function(data, key = "key", value = "value", ..., na.rm = FALSE,
convert = FALSE, factor_key = FALSE){
Expand All @@ -97,28 +119,18 @@ gather.mdl_df <- function(data, key = "key", value = "value", ..., na.rm = FALSE
..., na.rm = na.rm, convert = convert, factor_key = factor_key)
mdls <- names(which(map_lgl(tbl, inherits, "lst_mdl")))
kv <- c(key_vars(data), key)
as_mable(tbl, key = kv, model = mdls)
build_mable(tbl, key = !!kv, model = !!mdls)
}

# Adapted from tsibble:::select_tsibble
#' @export
select.mdl_df <- function (.data, ...){
sel_data <- NextMethod()
sel_vars <- names(sel_data)

kv <- key_vars(.data)
key_vars <- intersect(sel_vars, kv)
key_nochange <- all(is.element(kv, key_vars))

mdls <- names(which(map_lgl(sel_data, inherits, "lst_mdl")))
if(is_empty(mdls)){
abort("A mable must contain at least one model. To remove all models, first convert to a tibble with `as_tibble()`.")
}
if(key_nochange) {
build_mable(sel_data, key_data = key_data(.data), model = mdls)
} else {
build_mable(sel_data, key = !!key_vars, model = mdls)
}
res <- select(as_tibble(.data), ...)
restore_mable(res, .data)
}
#' @export
transmute.mdl_df <- function (.data, ...){
res <- transmute(as_tibble(.data), ...)
restore_mable(res, .data)
}

#' @export
Expand Down
6 changes: 4 additions & 2 deletions R/vctrs-mable.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,17 @@ vec_ptype2.mdl_df.tbl_df <- vec_ptype2.mdl_df.mdl_df

mable_ptype2 <- function(x, y, ...) {
key_x <- key_vars(x)
resp_x <- response_vars(x)
mdl_x <- mable_vars(x)
if (is_mable(y)) {
resp_x <- response_vars(x)
if (!identical(resp_x, response_vars(y))) {
abort("Objects with different response variables cannot be combined.")
}
key_x <- union(key_x, key_vars(y))
mdl_x <- union(mdl_x, mable_vars(y))
}
out <- df_ptype2(x, y, ...)
build_mable(out, !!key_x)
build_mable_meta(out, key_data = key_x, model = mdl_x)
}

#' @rdname mable-vctrs
Expand Down

0 comments on commit 91c75fe

Please sign in to comment.