From 66320400513af2235a686283a691a00e3d8176b6 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 17:12:18 -0700 Subject: [PATCH 01/10] Update Lrnr_cv.R Previously, the following code, which contains a Stack of one learner, outputted a data.table containing NULLs with some entries a list of some subset of predictions. This bug is fixed here by replacing data.table(preds) with as.data.table(preds) n <- 500 W <- runif(n, -1 , 1) Y <- rbinom(n, 1, plogis(W)) task <- sl3_Task$new(data.table(W,Y), covariates = "W", outcome = "Y") Lrnr_cv$new(Stack$new(Lrnr_glm$new() ))$train(task)$predict(task) --- R/Lrnr_cv.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Lrnr_cv.R b/R/Lrnr_cv.R index c626278e..2a27c755 100644 --- a/R/Lrnr_cv.R +++ b/R/Lrnr_cv.R @@ -371,7 +371,7 @@ Lrnr_cv <- R6Class( list( index = index, fold_index = rep(fold_index(), length(index)), - predictions = data.table(predictions) + predictions = as.data.table(predictions) ) } From 140fa1559dd441ec79756a63460db3266aea9f84 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 17:24:29 -0700 Subject: [PATCH 02/10] Update Lrnr_base.R --- R/Lrnr_base.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 101b8eae..04e28f9e 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -191,7 +191,7 @@ Lrnr_base <- R6Class( ncols <- ncol(predictions) if (!is.null(ncols) && (ncols == 1)) { - predictions <- as.vector(predictions) + predictions <- as.vector(unlist(predictions)) } return(predictions) }, From 74db1f39531eeb628cbc0f0b1a722276912c2956 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 17:56:58 -0700 Subject: [PATCH 03/10] Update survival_utils.R --- R/survival_utils.R | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/R/survival_utils.R b/R/survival_utils.R index dd11429c..eb36e1ef 100644 --- a/R/survival_utils.R +++ b/R/survival_utils.R @@ -32,11 +32,16 @@ pooled_hazard_task <- function(task, trim = TRUE) { repeated_data <- underlying_data[index, ] new_folds <- origami::id_folds_to_folds(task$folds, index) - repeated_task <- task$next_in_chain( - column_names = column_names, - data = repeated_data, id = "id", - folds = new_folds - ) + nodes <- task$nodes + nodes$id <- "id" + repeated_task <- sl3_Task$new(repeated_data, column_names = column_names, nodes = task$nodes, folds = new_folds, outcome_levels = outcome_levels, outcome_type = task$outcome_type$type) + + # The below errors when used in CV due to the stored row index not being reset in next_in_chain. + #repeated_task <- task$next_in_chain( + # column_names = column_names, + #data = repeated_data, id = "id", + #folds = new_folds + #) # make bin indicators bin_number <- rep(level_index, each = task$nrow) From 31993b70e3461f404e1c461238e190174553095c Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 18:11:07 -0700 Subject: [PATCH 04/10] Update Lrnr_base.R --- R/Lrnr_base.R | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 04e28f9e..9eb2df0b 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -177,21 +177,25 @@ Lrnr_base <- R6Class( )) } }, - base_predict = function(task = NULL) { + base_predict = function(task = NULL) { self$assert_trained() if (is.null(task)) { task <- private$.training_task } - + assert_that(is(task, "sl3_Task")) task <- self$subset_covariates(task) task <- self$process_formula(task) - + predictions <- private$.predict(task) - ncols <- ncol(predictions) - if (!is.null(ncols) && (ncols == 1)) { - predictions <- as.vector(unlist(predictions)) + if(inherits(predictions, "packed_predictions")) { + # if packed and data.table, as.vector(predictions) retains list structure. + if(is.data.table(predictions)) predictions <- as.vector(predictions) + return(predictions) + } else if(!is.null(ncols) && (ncols == 1)) { + # otherwise return vector. + predictions <- as.vector(unlist(predictions) } return(predictions) }, From 5809b335ffbe0bafe6b50fefbd0099c519c56880 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 18:12:35 -0700 Subject: [PATCH 05/10] Update Lrnr_base.R --- R/Lrnr_base.R | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 9eb2df0b..bf804afd 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -189,16 +189,16 @@ Lrnr_base <- R6Class( predictions <- private$.predict(task) ncols <- ncol(predictions) - if(inherits(predictions, "packed_predictions")) { + if(inherits(predictions, "packed_predictions") & is.data.table(predictions)) { # if packed and data.table, as.vector(predictions) retains list structure. - if(is.data.table(predictions)) predictions <- as.vector(predictions) - return(predictions) + predictions <- as.vector(predictions) } else if(!is.null(ncols) && (ncols == 1)) { # otherwise return vector. - predictions <- as.vector(unlist(predictions) + predictions <- as.vector(unlist(predictions)) } return(predictions) }, + }, base_chain = function(task = NULL) { self$assert_trained() if (is.null(task)) { From 385519d216c96c4906d4445b386cbd3d67259cd2 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 18:16:29 -0700 Subject: [PATCH 06/10] Update Lrnr_base.R --- R/Lrnr_base.R | 1 - 1 file changed, 1 deletion(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index bf804afd..0c367863 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -198,7 +198,6 @@ Lrnr_base <- R6Class( } return(predictions) }, - }, base_chain = function(task = NULL) { self$assert_trained() if (is.null(task)) { From f5b4e49f1539cf822f2881dc604cdfc72807e34f Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 18:32:24 -0700 Subject: [PATCH 07/10] Update Lrnr_base.R --- R/Lrnr_base.R | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/R/Lrnr_base.R b/R/Lrnr_base.R index 0c367863..ebe07a02 100644 --- a/R/Lrnr_base.R +++ b/R/Lrnr_base.R @@ -189,12 +189,15 @@ Lrnr_base <- R6Class( predictions <- private$.predict(task) ncols <- ncol(predictions) - if(inherits(predictions, "packed_predictions") & is.data.table(predictions)) { - # if packed and data.table, as.vector(predictions) retains list structure. - predictions <- as.vector(predictions) - } else if(!is.null(ncols) && (ncols == 1)) { - # otherwise return vector. - predictions <- as.vector(unlist(predictions)) + if(!is.null(ncols) && (ncols == 1)) { + if(is.data.table(predictions)) { + # if a data.table of packed predictions, return a matrix. + predictions <- as.matrix(predictions) + } + # if not packed predictions, return vector + if(!inherits(predictions[[1]], "packed_predictions")) { + predictions <- unlist(predictions) + } } return(predictions) }, From 3c1c25e93c734f9e775914a72aa6f537114bc182 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 11 Aug 2023 19:31:03 -0700 Subject: [PATCH 08/10] Update Lrnr_cv.R --- R/Lrnr_cv.R | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/R/Lrnr_cv.R b/R/Lrnr_cv.R index 2a27c755..4e058596 100644 --- a/R/Lrnr_cv.R +++ b/R/Lrnr_cv.R @@ -392,9 +392,14 @@ Lrnr_cv <- R6Class( predictions <- aorder(preds, order(results$index, results$fold_index)) - # don't convert to vector if learner is stack, as stack won't + + # don't convert to vector if learner is stack, as stack won't if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) { - predictions <- unlist(predictions) + # if packed_predictions dont unlist + if(is.data.table(predictions)) predictions <- as.matrix(predictions) + if(!inherits(predictions[[1]], "packed_predictions")) { + predictions <- as.vector(predictions) + } } return(predictions) }, From aaf8024ddef256e12c7f5a7ab08f6599e2fe15c2 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Sat, 12 Aug 2023 08:36:51 -0700 Subject: [PATCH 09/10] Update Lrnr_gam.R Added weights for GAM --- R/Lrnr_gam.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/R/Lrnr_gam.R b/R/Lrnr_gam.R index c52ec44e..24515ca3 100644 --- a/R/Lrnr_gam.R +++ b/R/Lrnr_gam.R @@ -77,7 +77,7 @@ Lrnr_gam <- R6Class( } ), private = list( - .properties = c("continuous", "binomial"), + .properties = c("continuous", "binomial", "weights"), .train = function(task) { # load args args <- self$params @@ -87,6 +87,7 @@ Lrnr_gam <- R6Class( Y <- data.frame(outcome_type$format(task$Y)) colnames(Y) <- task$nodes$outcome args$data <- cbind(task$X, Y) + args$weights <- task$weights ## family if (is.null(args$family)) { if (outcome_type$type == "continuous") { From b75bdb74d612d3b94860d7fa50dff1b1b8b3dd53 Mon Sep 17 00:00:00 2001 From: Lars van der Laan Date: Fri, 18 Aug 2023 09:28:48 -0700 Subject: [PATCH 10/10] Update Lrnr_cv.R --- R/Lrnr_cv.R | 2 ++ 1 file changed, 2 insertions(+) diff --git a/R/Lrnr_cv.R b/R/Lrnr_cv.R index 4e058596..282f8d9e 100644 --- a/R/Lrnr_cv.R +++ b/R/Lrnr_cv.R @@ -174,6 +174,8 @@ Lrnr_cv <- R6Class( predictions <- self$predict_fold(revere_task, fold_number) + # This might not be a matrix + predictions <- as.data.table(predictions) # TODO: make same fixes made to chain here if (nrow(revere_task$data) != nrow(predictions)) { # Gather validation indexes: