Skip to content

Commit

Permalink
Merge pull request #339 from tidymodels:update-quantile
Browse files Browse the repository at this point in the history
Update-quantile
  • Loading branch information
hfrick authored Feb 14, 2025
2 parents 002ebc7 + 39d1a01 commit ed3aba2
Show file tree
Hide file tree
Showing 10 changed files with 205 additions and 149 deletions.
7 changes: 4 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ URL: https://github.com/tidymodels/censored,
https://censored.tidymodels.org
BugReports: https://github.com/tidymodels/censored/issues
Depends:
parsnip (>= 1.1.0),
parsnip (>= 1.3.0),
R (>= 3.5.0),
survival (>= 3.7-0)
Imports:
Expand All @@ -27,15 +27,16 @@ Imports:
dplyr (>= 0.8.0.1),
generics,
glue,
hardhat (>= 1.1.0),
hardhat (>= 1.4.1),
lifecycle,
mboost,
prodlim (>= 2023.03.31),
purrr,
rlang (>= 1.0.0),
stats,
tibble (>= 3.1.3),
tidyr (>= 1.0.0)
tidyr (>= 1.0.0),
vctrs
Suggests:
aorsf (>= 0.1.2),
coin,
Expand Down
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# censored (development version)

## Breaking change

* The format of quantile predictions now follows the new requirements in parsnip (#339, tidymodels/parsnip/#1209).


# censored 0.3.2

* censored now depends on survival >= 3.7-0 which allows us to use it also for predictions of survival probabilities at infinite evaluation time points. This means that: Survival probabilities at `eval_time = Inf` are now not always set to 0 and confidence intervals at infinite evaluation times are now not always set to `NA`. This applies to `proportional_hazards()` and `bag_tree()` models as well as models with the `partykit` engine, `decision_tree()` and `rand_forest()` (#320).
Expand Down
30 changes: 25 additions & 5 deletions R/censored-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,31 @@ NULL

utils::globalVariables(
c(
"eval_time", ".time", "object", "new_data", ".label", ".pred", ".cuts",
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group",
".pred_quantile", ".quantile", "interval", "level", ".pred_linear_pred",
".pred_link", ".pred_time", ".pred_survival", "next_event_time",
"sum_component", "time_interval"
"eval_time",
".time",
"object",
"new_data",
".label",
".pred",
".cuts",
".id",
".tmp",
"engine",
"predictor_indicators",
".strata",
"group",
".pred_quantile",
".quantile",
"interval",
"level",
".pred_linear_pred",
".pred_link",
".pred_time",
".pred_survival",
"next_event_time",
"sum_component",
"time_interval",
"quantile_levels"
)
)

Expand Down
51 changes: 24 additions & 27 deletions R/survival_reg-data.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,13 +88,12 @@ make_survival_reg_survival <- function() {
pre = NULL,
post = survreg_quant,
func = c(fun = "predict"),
args =
list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile)
)
args = list(
object = expr(object$fit),
newdata = expr(new_data),
type = "quantile",
p = expr(quantile_levels)
)
)
)

Expand Down Expand Up @@ -236,17 +235,16 @@ make_survival_reg_flexsurv <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_post_quantile,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
)
)

Expand Down Expand Up @@ -393,17 +391,16 @@ make_survival_reg_flexsurvspline <- function() {
type = "quantile",
value = list(
pre = NULL,
post = NULL,
post = flexsurv_post_quantile,
func = c(fun = "predict"),
args =
list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
args = list(
object = rlang::expr(object$fit),
newdata = rlang::expr(new_data),
type = "quantile",
p = rlang::expr(quantile_levels),
conf.int = rlang::expr(interval == "confidence"),
conf.level = rlang::expr(level)
)
)
)

Expand Down
43 changes: 43 additions & 0 deletions R/survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){
dplyr::rename(.eval_time = .time)
}
}

# Conversion of quantile predictions to the vctrs format

# For single quantile levels, flexsurv returns a data frame with column
# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"

# With multiple quantile levels, flexsurv returns a data frame with a ".pred"
# list column whose elements have the columns ".quantile" and ".pred_quantile"
# and perhaps ".pred_lower" and ".pred_upper"
# Post-processing hook for flexsurv quantile predictions: turns the data frame
# in `pred` into a tibble with one quantile_pred column for each of
# ".pred_quantile", ".pred_lower" and ".pred_upper" that is present.
#
# `pred`: either a data frame whose only column is ".pred" (nested
#   predictions) or a flat data frame with one row per observation.
# `object`: not referenced in the body; presumably required by the signature
#   of post-processing hooks -- confirm against parsnip.
flexsurv_post_quantile <- function(pred, object) {
  # A flat result (anything that is not exactly a ".pred" column) is re-nested
  # so that both layouts share the code path below.
  already_nested <- identical(names(pred), ".pred")
  if (!already_nested) {
    pred <- re_nest(pred)
  }

  # Prediction columns found in the first nested element, in the fixed order
  # quantile, lower, upper.
  candidate_cols <- c(".pred_quantile", ".pred_lower", ".pred_upper")
  found_cols <- intersect(candidate_cols, names(pred$.pred[[1]]))

  # Convert column by column; each nested element contributes one entry.
  converted <- list()
  for (nm in found_cols) {
    converted[[nm]] <- purrr::map_vec(
      pred$.pred,
      col_to_quantile_pred,
      col = nm
    )
  }
  tibble::new_tibble(converted)
}

# Split `df` into one group per row and return a data frame with a single
# ".pred" column holding those per-row pieces.
re_nest <- function(df) {
  row_id <- seq_len(nrow(df))
  # Every row gets its own key, so each group contains exactly one row.
  nested <- vctrs::vec_split(df, by = row_id)
  # Drop the grouping key; only the split pieces are kept.
  nested[["key"]] <- NULL
  names(nested) <- ".pred"
  nested
}

# Build a quantile_pred object from column `col` of `df`, using the
# ".quantile" column of `df` as the quantile levels.
col_to_quantile_pred <- function(df, col) {
  # Lay the values out as a single row, one matrix column per value.
  values_row <- matrix(df[[col]], nrow = 1)
  hardhat::quantile_pred(values_row, quantile_levels = df$.quantile)
}
42 changes: 10 additions & 32 deletions R/survival_reg-survival.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,41 +3,19 @@
# ------------------------------------------------------------------------------

survreg_quant <- function(results, object) {
quantile_levels <- object$spec$quantile_levels

if (!is.matrix(results)) {
results <- matrix(results, nrow = 1)
if (length(quantile_levels) < 2) {
results <- matrix(results, ncol = 1)
} else {
results <- matrix(results, nrow = 1)
}
}

pctl <- object$spec$method$pred$quantile$args$p
n <- nrow(results)
p <- ncol(results)
colnames(results) <- names0(p)

res <- results %>%
tibble::as_tibble(results) %>%
dplyr::mutate(.row = 1:n) %>%
tidyr::pivot_longer(
-.row,
names_to = ".label",
values_to = ".pred_quantile"
) %>%
dplyr::arrange(.row, .label) %>%
dplyr::mutate(.quantile = rep(pctl, n)) %>%
dplyr::select(.row, .quantile, .pred_quantile) %>%
tidyr::nest(.pred = c(-.row)) %>%
dplyr::select(-.row)

res
}

# copied from recipes
names0 <- function(num, prefix = "x", ..., call = caller_env()) {
check_dots_empty()
if (num < 1) {
cli::cli_abort("{.arg num} should be > 0.", call = call)
}
ind <- format(1:num)
ind <- gsub(" ", "0", ind)
paste0(prefix, ind)
tibble::new_tibble(
x = list(.pred_quantile = hardhat::quantile_pred(results, quantile_levels))
)
}

# ------------------------------------------------------------------------------
Expand Down
67 changes: 33 additions & 34 deletions tests/testthat/test-survival_reg-flexsurv.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,45 +236,44 @@ test_that("quantile predictions", {
)

expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), ".pred")
expect_equal(names(pred), ".pred_quantile")
expect_equal(nrow(pred), 3)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(dim(.x) == c(9, 2))
))
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(".quantile", ".pred_quantile"))
))
)
expect_equal(
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
do.call(rbind, exp_pred)$est
)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))

for (.row in 1:nrow(pred)) {
expect_equal(
unclass(pred$.pred_quantile[.row])[[1]],
exp_pred[[.row]]$est
)
}

# add confidence interval
pred <- predict(fit_s,
new_data = bladder[1:3, ], type = "quantile",
interval = "confidence", level = 0.7
)
expect_true(
all(purrr::map_lgl(
pred$.pred,
~ all(names(.x) == c(
".quantile",
".pred_quantile",
".pred_lower",
".pred_upper"
))
))
pred <- predict(
fit_s,
new_data = bladder[1:3, ],
type = "quantile",
interval = "confidence",
level = 0.7
)
expect_s3_class(pred, "tbl_df")
expect_equal(names(pred), c(".pred_quantile", ".pred_lower", ".pred_upper"))
expect_equal(nrow(pred), 3)
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
expect_s3_class(pred$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))

# single observation
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
f_pred_1 <- predict(fit_s, bladder[2, ], type = "quantile")
expect_identical(nrow(f_pred_1), 1L)

# single quantile
f_pred_2 <- predict(
fit_s,
bladder[1:2, ],
type = "quantile",
quantile_levels = 0.5
)
expect_identical(nrow(f_pred_2), 2L)
})

# prediction: hazard ------------------------------------------------------
Expand Down Expand Up @@ -401,13 +400,13 @@ test_that("`fix_xy()` works", {
f_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
xy_pred_quantile <- predict(
xy_fit,
new_data = lung_pred,
type = "quantile",
quantile = c(0.2, 0.8)
quantile_levels = c(0.2, 0.8)
)
expect_equal(f_pred_quantile, xy_pred_quantile)

Expand Down
Loading

0 comments on commit ed3aba2

Please sign in to comment.