Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable multi_predict._coxnet() for all types #282

Merged
merged 4 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ Depends:
R (>= 3.5.0),
survival (>= 3.3-1)
Imports:
cli,
dials,
dplyr (>= 0.8.0.1),
generics,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@

* `survival_time_coxnet()` and `survival_prob_coxnet()` gain a `multi` argument to allow multiple values for `penalty` (#278, #279).

* `multi_predict()` is now available for all prediction types for `proportional_hazards()` models with the `"glmnet"` engine, so newly also for `type = "time"` and `type = "raw"` (#277, #282).

* Bug fix for `multi_predict(type = "survival")` for `proportional_hazards(engine = "glmnet")` models: when used with a single `penalty` value, this value is now included in the results. It was previously omitted (#267, #282).


# censored 0.2.0

Expand Down
215 changes: 215 additions & 0 deletions R/parsnip-utils.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
# utilities copied from parsnip

# Prediction types that `check_pred_type()` accepts as a value for `type`.
pred_types <- c(
  "raw",
  "numeric",
  "class",
  "prob",
  "conf_int",
  "pred_int",
  "quantile",
  "time",
  "survival",
  "linear_pred",
  "hazard"
)

# used directly, probably export?
# Resolve and validate the prediction `type` for a fitted model.
#
# Arguments:
#   object  A fitted model; its mode is read from `object$spec$mode`.
#   type    The requested prediction type, or `NULL` to use the default for
#           the mode ("numeric" for regression, "class" for classification,
#           "time" for censored regression).
#   ...     Not used by this function.
#
# Returns the resolved `type`. Aborts when `type` is `NULL` and the mode is
# not one of the three known modes, when `type` is not in `pred_types`, or
# when `type` requires a different mode than the model has. Types without an
# entry in the mode check below (e.g. "raw") are not checked against the mode.
check_pred_type <- function(object, type, ...) {
  if (is.null(type)) {
    type <-
      switch(object$spec$mode,
        regression = "numeric",
        classification = "class",
        "censored regression" = "time",
        # Fallback: no default can be derived because the *mode* is unknown,
        # so the message names the mode rather than `type`.
        rlang::abort("The model `mode` should be 'regression', 'censored regression', or 'classification'.")
      )
  }
  if (!(type %in% pred_types)) {
    rlang::abort(
      glue::glue(
        "`type` should be one of: ",
        glue::glue_collapse(pred_types, sep = ", ", last = " and ")
      )
    )
  }

  # Each listed type is only valid for one mode.
  switch(type,
    "numeric" = if (object$spec$mode != "regression") {
      rlang::abort("For numeric predictions, the object should be a regression model.")
    },
    "class" = if (object$spec$mode != "classification") {
      rlang::abort("For class predictions, the object should be a classification model.")
    },
    "prob" = if (object$spec$mode != "classification") {
      rlang::abort("For probability predictions, the object should be a classification model.")
    },
    "time" = if (object$spec$mode != "censored regression") {
      rlang::abort("For event time predictions, the object should be a censored regression.")
    },
    "survival" = if (object$spec$mode != "censored regression") {
      rlang::abort("For survival probability predictions, the object should be a censored regression.")
    },
    "hazard" = if (object$spec$mode != "censored regression") {
      rlang::abort("For hazard predictions, the object should be a censored regression.")
    },
    "linear_pred" = if (object$spec$mode != "censored regression") {
      rlang::abort("For the linear predictor, the object should be a censored regression.")
    }
  )

  # TODO check for ... options when not the correct type
  type
}

# used directly, maybe export?
# Abort unless the model specification offers a prediction method for `type`.
#
# The available methods are the names of `object$spec$method$pred`; they are
# listed in the error message. Returns `NULL` invisibly when `type` is
# available.
check_spec_pred_type <- function(object, type) {
  if (spec_has_pred_type(object, type)) {
    return(invisible(NULL))
  }

  available <- names(object$spec$method$pred)
  quoted <- glue::glue("'{available}'")
  headline <- glue::glue("No {type} prediction method available for this model.")
  hint <- glue::glue(
    "Value for `type` should be one of: ",
    glue::glue_collapse(quoted, sep = ", ")
  )
  rlang::abort(c(headline, hint))
}

# Does the model specification list a prediction method named `type`?
#
# Compares `type` against the names of `object$spec$method$pred` and returns
# `TRUE` if any of them is equal to it.
spec_has_pred_type <- function(object, type) {
  available <- names(object$spec$method$pred)
  is_match <- available == type
  any(is_match)
}

# used directly, probably export
# Validate the arguments passed through `...` to a prediction call.
#
# Arguments:
#   object  A fitted model; `object$spec$mode` is consulted for the
#           `increasing` check.
#   type    The prediction type (a single string).
#   ...     The extra arguments to validate; only their names are inspected
#           here.
#   call    The environment reported as the origin of any error.
#
# Returns `TRUE` invisibly. Aborts when
#   * `newdata` is supplied (via `check_for_newdata()`),
#   * an argument name is not one of the recognised prediction arguments,
#   * `eval_time` or `time` is supplied for a type other than "survival" or
#     "hazard",
#   * neither `eval_time` nor `time` is supplied for "survival" or "hazard",
#   * `increasing` is supplied for anything other than "linear_pred"
#     predictions of a censored regression model.
check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) {
  the_dots <- list(...)
  nms <- names(the_dots)

  # ----------------------------------------------------------------------------

  check_for_newdata(..., call = call)

  # ----------------------------------------------------------------------------

  other_args <- c(
    "interval", "level", "std_error", "quantile",
    "time", "eval_time", "increasing"
  )
  is_pred_arg <- nms %in% other_args
  if (!all(is_pred_arg)) {
    bad_args <- paste0("`", nms[!is_pred_arg], "`", collapse = ", ")
    rlang::abort(
      glue::glue(
        "The ellipses are not used to pass args to the model function's ",
        "predict function. These arguments cannot be used: {bad_args}"
      ),
      call = call
    )
  }

  # ----------------------------------------------------------------------------
  # places where eval_time should not be given
  eval_time_types <- c("survival", "hazard")
  uses_eval_time <- type %in% eval_time_types
  quoted_types <- paste0("'", eval_time_types, "'", collapse = ", ")

  if ("eval_time" %in% nms && !uses_eval_time) {
    rlang::abort(
      paste(
        "`eval_time` should only be passed to `predict()` when `type` is one of:",
        quoted_types
      ),
      call = call
    )
  }
  if ("time" %in% nms && !uses_eval_time) {
    rlang::abort(
      paste(
        "'time' should only be passed to `predict()` when 'type' is one of:",
        quoted_types
      ),
      call = call
    )
  }
  # when eval_time should be passed
  if (!any(nms %in% c("eval_time", "time")) && uses_eval_time) {
    rlang::abort(
      paste(
        "When using `type` values of 'survival' or 'hazard',",
        "a numeric vector `eval_time` should also be given."
      ),
      call = call
    )
  }

  # `increasing` only applies to linear_pred for censored regression.
  # `isTRUE()` keeps the condition a scalar even if the mode is missing.
  allows_increasing <- type == "linear_pred" &&
    isTRUE(object$spec$mode == "censored regression")
  if ("increasing" %in% nms && !allows_increasing) {
    rlang::abort(
      paste(
        "The 'increasing' argument only applies to predictions of",
        "type 'linear_pred' for the mode censored regression."
      ),
      call = call
    )
  }

  invisible(TRUE)
}

# Abort when an argument named `newdata` was passed through `...`, pointing
# the user to `new_data` instead. `call` is the environment reported as the
# origin of the error. Returns `NULL` invisibly otherwise.
check_for_newdata <- function(..., call = rlang::caller_env()) {
  dot_names <- names(list(...))
  if (!("newdata" %in% dot_names)) {
    return(invisible(NULL))
  }
  rlang::abort(
    "Please use `new_data` instead of `newdata`.",
    call = call
  )
}

# used directly, maybe export?
# Abort when any package listed in `x$method$libs` cannot be loaded.
#
# `x` is a model specification. Each entry of `x$method$libs` is tested with
# `is_installed()`; the packages that fail are listed in the error message,
# each one quoted separately. Returns `NULL` invisibly when nothing is
# missing or when no packages are listed.
check_installs <- function(x) {
  libs <- x$method$libs
  if (length(libs) == 0) {
    return(invisible(NULL))
  }

  is_inst <- purrr::map_lgl(libs, is_installed)
  if (!all(is_inst)) {
    # Keep this a vector so every package gets its own pair of quotes.
    missing_pkg <- libs[!is_inst]
    rlang::abort(
      glue::glue(
        "This engine requires some package installs: ",
        glue::glue_collapse(glue::glue("'{missing_pkg}'"), sep = ", ")
      )
    )
  }

  invisible(NULL)
}

# Quietly try to load the namespace of package `x`, muffling any package
# startup messages. Returns the result of `requireNamespace()`.
shhhh <- function(x) {
  suppressPackageStartupMessages(
    requireNamespace(package = x, quietly = TRUE)
  )
}

# Can the namespace of package `pkg` be loaded?
#
# Always returns a single `TRUE` or `FALSE`: an error raised while checking
# is reported as `FALSE` rather than being returned as a "try-error" object,
# so the result is safe to collect with `purrr::map_lgl()`.
is_installed <- function(pkg) {
  tryCatch(
    isTRUE(shhhh(pkg)),
    error = function(cnd) FALSE
  )
}

# used directly, maybe export?
# Load, or optionally attach, every package listed in `x$method$libs`.
#
# Arguments:
#   x       A model specification.
#   quiet   Passed as `quietly` to `requireNamespace()` / `library()`.
#   attach  If `TRUE`, attach each package with `library()`; otherwise only
#           load its namespace.
#
# Returns `x` invisibly.
load_libs <- function(x, quiet, attach = FALSE) {
  for (lib_name in x$method$libs) {
    if (attach) {
      library(lib_name, character.only = TRUE, quietly = quiet)
    } else {
      suppressPackageStartupMessages(
        requireNamespace(lib_name, quietly = quiet)
      )
    }
  }
  invisible(x)
}

# used directly, from parsnip's standalone file
# Reduce `eval_time` to its usable values.
#
# Arguments:
#   eval_time  Evaluation times, or `NULL`. Non-`NULL` input is converted
#              with `as.numeric()`.
#   fail       If `TRUE`, abort when no usable value remains.
#
# Returns the unique values that are non-missing, finite and >= 0, in their
# original order; `NULL` is returned unchanged. A warning is issued whenever
# entries were dropped, reporting how many entries were removed (missing,
# negative, infinite, or duplicated).
.filter_eval_time <- function(eval_time, fail = TRUE) {
  if (!is.null(eval_time)) {
    eval_time <- as.numeric(eval_time)
  }
  eval_time_0 <- eval_time
  # will still propagate nulls:
  eval_time <- eval_time[!is.na(eval_time)]
  eval_time <- eval_time[eval_time >= 0 & is.finite(eval_time)]
  eval_time <- unique(eval_time)
  if (fail && identical(eval_time, numeric(0))) {
    cli::cli_abort(
      "There were no usable evaluation times (finite, non-missing, and >= 0).",
      call = NULL
    )
  }
  # Count dropped entries directly: `setdiff()` only yields distinct values,
  # which under-reports repeated bad values and gives 0 for duplicates.
  n_removed <- length(eval_time_0) - length(eval_time)
  if (n_removed > 0) {
    cli::cli_warn("There {?was/were} {n_removed} inappropriate evaluation
                  time point{?s} that {?was/were} removed.", call = NULL)
  }
  eval_time
}
96 changes: 82 additions & 14 deletions R/proportional_hazards-glmnet.R
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ check_strata_remaining <- function(expr, call = rlang::caller_env()) {
call = call
)
} else if (is_call(expr)) {
#lapply() instead of map() to avoid map() reporting the index of where it errors
# lapply() instead of map() to avoid map() reporting the index of where it errors
expr[-1] <- lapply(as.list(expr[-1]), check_strata_remaining, call = call)
expr
} else {
Expand Down Expand Up @@ -307,14 +307,11 @@ predict_raw._coxnet <- function(object, new_data, opts = list(), ...) {
multi_predict._coxnet <- function(object,
new_data,
type = NULL,
opts = list(),
penalty = NULL,
...) {
dots <- list(...)

if (any(names(dots) == "newdata")) {
rlang::abort("Please use `new_data` instead of `newdata`.")
}

object$spec <- eval_args(object$spec)

if (is.null(penalty)) {
Expand All @@ -326,29 +323,100 @@ multi_predict._coxnet <- function(object,
}
}

if (type == "linear_pred") {
pred <- multi_predict_coxnet_linear_pred(
# from predict._coxnet()
object$spec$args$penalty <- parsnip::.check_glmnet_penalty_predict(
penalty,
object,
multi = TRUE
)

# from predict.model_fit()
check_installs(object$spec)
load_libs(object$spec, quiet = TRUE)

type <- check_pred_type(object, type)
check_spec_pred_type(object, type) # added from predict_<type>()
if (type != "raw" && length(opts) > 0) {
rlang::warn("`opts` is only used with `type = 'raw'` and was ignored.")
}
check_pred_type_dots(object, type, ...)

pred <- switch(
type,
"time" = multi_predict_coxnet_time(
object,
new_data = new_data,
penalty = penalty
),
"survival" = multi_predict_coxnet_survival(
object,
new_data = new_data,
penalty = penalty,
... # contains eval_time
),
"linear_pred" = multi_predict_coxnet_linear_pred(
object,
new_data = new_data,
opts = dots,
penalty = penalty
)
} else {
pred <- predict(
),
"raw" = predict(
object,
new_data = new_data,
type = type,
...,
type = "raw",
opts = opts,
penalty = penalty,
multi = TRUE
)
}
)

pred
}

multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
# Event time predictions for multiple `penalty` values.
#
# Follows predict_time.model_fit(): `new_data` is run through
# `parsnip::prepare_data()` and then handed to `survival_time_coxnet()` with
# `multi = TRUE`. This engine has no pre- or post-hooks.
multi_predict_coxnet_time <- function(object, new_data, penalty) {
  prepared_data <- parsnip::prepare_data(object, new_data)

  predicted_times <- survival_time_coxnet(
    object,
    new_data = prepared_data,
    penalty = penalty,
    multi = TRUE
  )

  predicted_times
}

# Survival probability predictions for multiple `penalty` values.
#
# Follows predict_survival.model_fit(). `...` is expected to carry
# `eval_time`; the deprecated `time` argument is still accepted, triggers a
# deprecation warning, and takes the place of `eval_time`. The evaluation
# times are cleaned with `.filter_eval_time()` before `new_data` is prepared
# and passed to `survival_prob_coxnet()` with `multi = TRUE`. This engine has
# no pre- or post-hooks.
multi_predict_coxnet_survival <- function(object, new_data, penalty, ...) {
  extra_args <- list(...)

  if ("time" %in% names(extra_args)) {
    lifecycle::deprecate_warn(
      "0.2.0",
      "multi_predict(time)",
      "multi_predict(eval_time)"
    )
    extra_args$eval_time <- extra_args$time
  }
  usable_eval_time <- .filter_eval_time(extra_args$eval_time)

  prepared_data <- parsnip::prepare_data(object, new_data)

  survival_probs <- survival_prob_coxnet(
    object,
    new_data = prepared_data,
    penalty = penalty,
    multi = TRUE,
    eval_time = usable_eval_time
  )

  survival_probs
}

multi_predict_coxnet_linear_pred <- function(object, new_data, opts, penalty) {
if ("increasing" %in% names(opts)) {
increasing <- opts$increasing
opts$increasing <- NULL
Expand Down
Loading
Loading