From e3c167d8c8a557028442c2f281d1c953347bc123 Mon Sep 17 00:00:00 2001 From: Herb Susmann Date: Thu, 15 Jun 2023 21:36:09 +0200 Subject: [PATCH] Updates for working with multiple quantile predictions --- R/Lrnr_cv.R | 9 ++++++--- R/Lrnr_ga.R | 2 +- R/Lrnr_gbm.R | 32 ++++++++++++++++++++++++++------ R/Lrnr_grf.R | 7 ++++++- R/Lrnr_lightgbm.R | 2 ++ R/loss_functions.R | 2 +- 6 files changed, 42 insertions(+), 12 deletions(-) diff --git a/R/Lrnr_cv.R b/R/Lrnr_cv.R index 9aa85837..8dccc2ef 100644 --- a/R/Lrnr_cv.R +++ b/R/Lrnr_cv.R @@ -234,7 +234,7 @@ Lrnr_cv <- R6Class( .properties = c("wrapper", "cv"), .train_sublearners = function(task) { verbose <- getOption("sl3.verbose") - + # if we get a delayed task, evaluate it # TODO: this is a kludge -- ideally we'd have Lrnr_cv work on delayed tasks like other learners if (inherits(task, "Delayed")) { @@ -287,7 +287,7 @@ Lrnr_cv <- R6Class( learner <- self$params$learner ever_error <- NULL - + if (inherits(learner, "Stack")) { # if we're cross-validating a stack, check for learner errors in any # folds and then drop for all folds @@ -375,7 +375,10 @@ Lrnr_cv <- R6Class( predictions <- aorder(preds, order(results$index, results$fold_index)) # don't convert to vector if learner is stack, as stack won't - if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) { + if(class(predictions[[1]][[1]]) == "packed_predictions") { + predictions <- as.matrix(predictions) + } + else if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) { predictions <- unlist(predictions) } return(predictions) diff --git a/R/Lrnr_ga.R b/R/Lrnr_ga.R index c0f484a6..17219dba 100644 --- a/R/Lrnr_ga.R +++ b/R/Lrnr_ga.R @@ -161,7 +161,7 @@ Lrnr_ga <- R6Class( ) ) - coefs <- as.vector(GA1@bestSol[[1]]) + coefs <- as.vector(GA1@bestSol[[1]][1,]) names(coefs) <- colnames(task$X) fit_object <- list(ga_fit <- GA1) diff --git a/R/Lrnr_gbm.R b/R/Lrnr_gbm.R index 3457f02d..996af60c 100644 --- a/R/Lrnr_gbm.R +++ b/R/Lrnr_gbm.R @@ -101,15 +101,35 @@ Lrnr_gbm <- R6Class( args$verbose <- getOption("sl3.verbose") } - fit_object <- call_with_args(gbm::gbm.fit, args) + if(is.null(args$tau)) { + fit_object <- call_with_args(gbm::gbm.fit, args) + } + else { + fit_object <- Map(function(tau) { + args$distribution <- list(name = "quantile", alpha = tau) + call_with_args(gbm::gbm.fit, args) + }, args$tau) + } + return(fit_object) }, .predict = function(task) { - preds <- stats::predict( - object = private$.fit_object, newdata = task$X, - n.trees = self$params$n.trees, type = "response" - ) - return(preds) + if(!is.null(self$params$tau)) { + preds <- matrix(unlist(Map(function(fit_object) { + stats::predict( + object = fit_object, newdata = task$X, + n.trees = self$params$n.trees, type = "response" + ) + }, private$.fit_object)), ncol = length(self$params$tau), byrow = FALSE) + return(pack_predictions(preds)) + } + else { + preds <- stats::predict( + object = private$.fit_object, newdata = task$X, + n.trees = self$params$n.trees, type = "response" + ) + return(preds) + } }, .required_packages = c("gbm") ) diff --git a/R/Lrnr_grf.R b/R/Lrnr_grf.R index a7633b7e..c3724f57 100644 --- a/R/Lrnr_grf.R +++ b/R/Lrnr_grf.R @@ -128,7 +128,12 @@ Lrnr_grf <- R6Class( newdata = data.frame(task$X), quantiles = quantiles_pred ) - predictions <- as.numeric(predictions_list$predictions) + if(length(quantiles_pred) > 1) { + predictions <- pack_predictions(predictions_list$predictions) + } + else { + predictions <- as.numeric(predictions_list$predictions) + } return(predictions) }, .required_packages = c("grf") diff --git a/R/Lrnr_lightgbm.R b/R/Lrnr_lightgbm.R index 0b6aa556..5ae515ac 100644 --- a/R/Lrnr_lightgbm.R +++ b/R/Lrnr_lightgbm.R @@ -94,6 +94,8 @@ Lrnr_lightgbm <- R6Class( } args$verbose <- as.integer(verbose) + args$params <- args$lightgbm_params + # set up outcome outcome_type <- self$get_outcome_type(task) Y <- outcome_type$format(task$Y) diff --git a/R/loss_functions.R b/R/loss_functions.R index 168ecc8a..3bda53a8 100644 --- a/R/loss_functions.R +++ b/R/loss_functions.R @@ -198,7 +198,7 @@ cv_risk <- function(learner, eval_fun = NULL, coefs = NULL) { preds <- learner$predict_fold(task, "validation") if (!is.data.table(preds)) { preds <- data.table::data.table(preds) - data.table::setnames(preds, names(preds), learner$name) + #data.table::setnames(preds, names(preds), learner$name) } get_obsdata <- function(fold, task) {