diff --git a/NEWS.md b/NEWS.md index 813303cd..185324ab 100644 --- a/NEWS.md +++ b/NEWS.md @@ -8,4 +8,6 @@ * Internal re-organisation of code (#206, 209). +* Added internal `survfit_summary_*()` helper functions (#216). + * Added a `NEWS.md` file to track changes to the package. diff --git a/R/aaa_survival_prob.R b/R/aaa_survival_prob.R index 5e471871..1f8e0ef0 100644 --- a/R/aaa_survival_prob.R +++ b/R/aaa_survival_prob.R @@ -145,3 +145,191 @@ matrix_to_nested_tibbles_survival <- function(x, time) { dplyr::group_nest(res, .row, .key = ".pred")$.pred } + + +# summary_survfit helpers ------------------------------------------------- + + +survfit_summary_typestable <- function(object){ + # make matrix of dimension n_times x n_obs + sanitize_element <- function(x, n_obs) { + if (!is.matrix(x)) { + x <- matrix(x, ncol = n_obs) + } + x + } + # sanitize elements we care about + elements <- available_survfit_summary_elements(object) + for (i in elements) { + object[[i]] <- sanitize_element(object[[i]], n_obs = length(object$n)) + } + + object +} + +available_survfit_summary_elements <- function(object) { + intersect( + names(object), + c("surv", "std.err", "lower", "upper", "cumhaz", "std.chaz") + ) +} + +survfit_summary_patch_infinite_time <- function(object, time) { + + time_neg_inf <- is.infinite(time) & (time < 0) + time_inf <- is.infinite(time) & (time > 0) + + patch_neg_inf <- function(x, value, n_patch) { + rbind( + matrix(value, nrow = n_patch, ncol = ncol(x)), + x + ) + } + patch_inf <- function(x, value, n_patch) { + rbind( + x, + matrix(value, nrow = n_patch, ncol = ncol(x)) + ) + } + + # glmnet does not provide standard errors etc + has_std_error <- "std.err" %in% names(object) + + if (any(time_neg_inf)) { + object$surv <- patch_neg_inf( + object$surv, + value = 1, + n_patch = sum(time_neg_inf) + ) + object$cumhaz <- patch_neg_inf( + object$cumhaz, + value = 0, + n_patch = sum(time_neg_inf) + ) + if (has_std_error) { + object$std.err <- patch_neg_inf( + object$std.err, + value = NA_real_, + n_patch = sum(time_neg_inf) + ) + object$lower <- patch_neg_inf( + object$lower, + value = NA_real_, + n_patch = sum(time_neg_inf) + ) + object$upper <- patch_neg_inf( + object$upper, + value = NA_real_, + n_patch = sum(time_neg_inf) + ) + object$std.chaz <- patch_neg_inf( + object$std.chaz, + value = NA_real_, + n_patch = sum(time_neg_inf) + ) + } + } + if (any(time_inf)) { + object$surv <- patch_inf(object$surv, value = 0, n_patch = sum(time_inf)) + object$cumhaz <- patch_inf( + object$cumhaz, + value = 1, + n_patch = sum(time_inf) + ) + if (has_std_error) { + object$std.err <- patch_inf( + object$std.err, + value = NA_real_, + n_patch = sum(time_inf) + ) + object$lower <- patch_inf( + object$lower, + value = NA_real_, + n_patch = sum(time_inf) + ) + object$upper <- patch_inf( + object$upper, + value = NA_real_, + n_patch = sum(time_inf) + ) + object$std.chaz <- patch_inf( + object$std.chaz, + value = NA_real_, + n_patch = sum(time_inf) + ) + } + } + + object +} + +survfit_summary_restore_time_order <- function(object, time) { + # preserve original order of `time` because `summary()` returns a result for + # an ordered vector of finite time + # Note that this requires a survfit summary object which has already been + # patched for infinite time points + original_order_time <- match(time, sort(time)) + + elements <- available_survfit_summary_elements(object) + + # restore original order of prediction time points + for (i in elements) { + object[[i]] <- object[[i]][original_order_time, , drop = FALSE] + } + + object +} + +survfit_summary_patch_missings <- function(object, index_missing, time, n_obs) { + if (is.null(index_missing)) { + return(object) + } + + patch_element <- function(x, time, n_obs, index_missing) { + full_matrix <- matrix(NA, nrow = length(time), ncol = n_obs) + full_matrix[, -index_missing] <- x + full_matrix + } + + elements <- available_survfit_summary_elements(object) + + for (i in elements) { + object[[i]] <- patch_element( + object[[i]], + time = time, + n_obs = n_obs, + index_missing = index_missing + ) + } + + object +} + +survfit_summary_to_tibble <- function(object, time, n_obs) { + ret <- tibble::tibble( + .row = rep(seq_len(n_obs), each = length(time)), + .time = rep(time, times = n_obs), + .pred_survival = as.vector(object$surv), + # TODO standard error + .pred_lower = as.vector(object$lower), + .pred_upper = as.vector(object$upper), + .pred_hazard_cumulative = as.vector(object$cumhaz) + # TODO standard error for cumulative hazard + ) + ret +} + +survfit_summary_to_patched_tibble <- function(object, index_missing, time, n_obs) { + object %>% + summary(times = time, extend = TRUE) %>% + survfit_summary_typestable() %>% + survfit_summary_patch_infinite_time(time = time) %>% + survfit_summary_restore_time_order(time = time) %>% + survfit_summary_patch_missings( + index_missing = index_missing, + time = time, + n_obs = n_obs + ) %>% + survfit_summary_to_tibble(time = time, n_obs = n_obs) +} + diff --git a/tests/testthat/test-aaa_survival_prob.R b/tests/testthat/test-aaa_survival_prob.R new file mode 100644 index 00000000..e1173a66 --- /dev/null +++ b/tests/testthat/test-aaa_survival_prob.R @@ -0,0 +1,256 @@ +test_that("survfit_summary_typestable() works for survival prob - unstratified (coxph)", { + lung_pred <- tidyr::drop_na(lung) + mod <- coxph(Surv(time, status) ~ ., data = lung) + + # multiple observations + surv_fit <- survfit(mod, newdata = lung_pred) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + + # single observation + surv_fit <- survfit(mod, newdata = lung_pred[1,]) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) +}) + +test_that("survfit_summary_typestable() works for survival prob - stratified (coxph)", { + lung_pred <- tidyr::drop_na(lung) + mod <- coxph(Surv(time, status) ~ age + ph.ecog + strata(sex), data = lung) + + # multiple observations + surv_fit <- survfit(mod, newdata = lung_pred) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + + # single observation + surv_fit <- survfit(mod, newdata = lung_pred[1,]) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) +}) + +test_that("survfit_summary_typestable() works for survival prob - unstratified (coxnet)", { + lung2 <- lung[-14, ] + lung_x = as.matrix(lung2[, c("age", "ph.ecog")]) + lung_y = Surv(lung2$time, lung2$status) + lung_pred <- lung_x[1:5, ] + + mod <- suppressWarnings( + glmnet::glmnet(x = lung_x, y = lung_y, family = "cox") + ) + + # multiple observations + surv_fit <- survfit(mod, newx = lung_pred, s = 0.1, x = lung_x, y = lung_y) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + + # single observation + surv_fit <- survfit(mod, newx = lung_pred[1,, drop = FALSE], s = 0.1, x = lung_x, y = lung_y) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) +}) + +test_that("survfit_summary_typestable() works for survival prob - stratified (coxnet)", { + lung2 <- lung[-14, ] + lung_x = as.matrix(lung2[, c("age", "ph.ecog")]) + lung_y = glmnet::stratifySurv(Surv(lung2$time, lung2$status), lung2$sex) + lung_pred <- lung_x[1:5, ] + lung_pred_strata <- lung2$sex[1:5] + + mod <- suppressWarnings(glmnet::glmnet(x = lung_x, y = lung_y, family = "cox")) + + # multiple observations + surv_fit <- survfit(mod, newx = lung_pred, newstrata = lung_pred_strata, + s = 0.1, x = lung_x, y = lung_y) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), nrow(lung_pred))) + + # single observation + surv_fit <- survfit(mod, newx = lung_pred[1,], newstrata = lung_pred_strata[1], + s = 0.1, x = lung_x, y = lung_y) + + pred_time <- c(100, 200) + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) + expect_true(all(prob[1,] > prob[2,])) + + pred_time <- 100 + surv_fit_summary <- summary(surv_fit, times = pred_time) %>% + survfit_summary_typestable() + prob <- surv_fit_summary$surv + expect_equal(dim(prob), c(length(pred_time), 1)) +}) + +test_that("survfit_summary_patch_infinite_time() works (coxph)", { + lung_pred <- tidyr::drop_na(lung) + pred_time <- c(-Inf, 0, Inf, 1022, -Inf) + + mod <- coxph(Surv(time, status) ~ ., data = lung) + surv_fit <- survfit(mod, newdata = lung_pred) + surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE) + + surv_fit_summary_patched <- surv_fit_summary %>% + survfit_summary_typestable() %>% + survfit_summary_patch_infinite_time(time = pred_time) + + prob <- surv_fit_summary_patched$surv + exp_prob <- surv_fit_summary$surv + + expect_equal(prob[c(3,4),], exp_prob) + expect_equal( + prob[c(1,2),], + matrix(1, nrow = 2, ncol = nrow(lung_pred)), + ignore_attr = "dimnames" + ) + expect_equal(unname(prob[5,]), rep(0, nrow(lung_pred))) +}) + +test_that("survfit_summary_patch_infinite_time() works (coxnet)", { + pred_time <- c(-Inf, 0, Inf, 1022, -Inf) + + lung2 <- lung[-14, ] + lung_x = as.matrix(lung2[, c("age", "ph.ecog")]) + lung_y = Surv(lung2$time, lung2$status) + lung_pred <- lung_x[1:5, ] + + mod <- suppressWarnings( + glmnet::glmnet(x = lung_x, y = lung_y, family = "cox") + ) + surv_fit <- survfit(mod, newx = lung_pred, s = 0.1, x = lung_x, y = lung_y) + surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE) + + surv_fit_summary_patched <- surv_fit_summary %>% + survfit_summary_typestable() %>% + survfit_summary_patch_infinite_time(time = pred_time) + + prob <- surv_fit_summary_patched$surv + exp_prob <- surv_fit_summary$surv + + expect_equal(prob[c(3,4),], exp_prob) + expect_equal( + prob[c(1,2),], + matrix(1, nrow = 2, ncol = nrow(lung_pred)), + ignore_attr = "dimnames" + ) + expect_equal(unname(prob[5,]), rep(0, nrow(lung_pred))) +}) + +test_that("survfit_summary_restore_time_order() works", { + lung_pred <- tidyr::drop_na(lung) + pred_time <- c(300, 100, 200) + + mod <- coxph(Surv(time, status) ~ ., data = lung) + surv_fit <- survfit(mod, newdata = lung_pred) + surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE) + + surv_fit_summary_patched <- surv_fit_summary %>% + survfit_summary_typestable() %>% + survfit_summary_patch_infinite_time(time = pred_time) %>% + survfit_summary_restore_time_order(time = pred_time) + + prob <- surv_fit_summary_patched$surv + exp_prob <- surv_fit_summary$surv + + expect_equal(prob, exp_prob[c(3,1:2),]) +}) + +test_that("survfit_summary_patch_missings() works", { + pred_time <- c(100, 200) + mod <- coxph(Surv(time, status) ~ age + ph.ecog, data = lung) + + lung_pred <- lung[13:14, ] + surv_fit <- survfit(mod, newdata = lung_pred) + surv_fit_summary <- summary(surv_fit, times = pred_time, extend = TRUE) + + surv_fit_summary_patched <- surv_fit_summary %>% + survfit_summary_typestable() %>% + survfit_summary_patch_missings( + time = pred_time, + index_missing = 2, + n_obs = 2 + ) + + prob <- surv_fit_summary_patched$surv + + expect_equal(ncol(prob), nrow(lung_pred)) + expect_equal(prob[,2], rep(NA_real_, length(pred_time))) +}) +