Skip to content

Commit

Permalink
Updates for working with multiple quantile predictions
Browse files Browse the repository at this point in the history
  • Loading branch information
herbps10 committed Jun 15, 2023
1 parent 74f7538 commit e3c167d
Show file tree
Hide file tree
Showing 6 changed files with 42 additions and 12 deletions.
9 changes: 6 additions & 3 deletions R/Lrnr_cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,7 @@ Lrnr_cv <- R6Class(
.properties = c("wrapper", "cv"),
.train_sublearners = function(task) {
verbose <- getOption("sl3.verbose")

# if we get a delayed task, evaluate it
# TODO: this is a kludge -- ideally we'd have Lrnr_cv work on delayed tasks like other learners
if (inherits(task, "Delayed")) {
Expand Down Expand Up @@ -287,7 +287,7 @@ Lrnr_cv <- R6Class(

learner <- self$params$learner
ever_error <- NULL

if (inherits(learner, "Stack")) {
# if we're cross-validating a stack, check for learner errors in any
# folds and then drop for all folds
Expand Down Expand Up @@ -375,7 +375,10 @@ Lrnr_cv <- R6Class(
predictions <- aorder(preds, order(results$index, results$fold_index))

# don't convert to vector if learner is stack, as stack won't
if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) {
if(class(predictions[[1]][[1]]) == "packed_predictions") {
predictions <- as.matrix(predictions)
}
else if ((ncol(predictions) == 1) && !inherits(self$params$learner, "Stack")) {
predictions <- unlist(predictions)
}
return(predictions)
Expand Down
2 changes: 1 addition & 1 deletion R/Lrnr_ga.R
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ Lrnr_ga <- R6Class(
)
)

coefs <- as.vector(GA1@bestSol[[1]])
coefs <- as.vector(GA1@bestSol[[1]][1,])
names(coefs) <- colnames(task$X)

fit_object <- list(ga_fit <- GA1)
Expand Down
32 changes: 26 additions & 6 deletions R/Lrnr_gbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,35 @@ Lrnr_gbm <- R6Class(
args$verbose <- getOption("sl3.verbose")
}

fit_object <- call_with_args(gbm::gbm.fit, args)
if(is.null(args$tau)) {
fit_object <- call_with_args(gbm::gbm.fit, args)
}
else {
fit_object <- Map(function(tau) {
args$distribution <- list(name = "quantile", alpha = tau)
call_with_args(gbm::gbm.fit, args)
}, args$tau)
}

return(fit_object)
},
.predict = function(task) {
preds <- stats::predict(
object = private$.fit_object, newdata = task$X,
n.trees = self$params$n.trees, type = "response"
)
return(preds)
if(!is.null(self$params$tau)) {
preds <- matrix(unlist(Map(function(fit_object) {
stats::predict(
object = fit_object, newdata = task$X,
n.trees = self$params$n.trees, type = "response"
)
}, private$.fit_object)), ncol = length(self$params$tau), byrow = FALSE)
return(pack_predictions(preds))
}
else {
preds <- stats::predict(
object = private$.fit_object, newdata = task$X,
n.trees = self$params$n.trees, type = "response"
)
return(preds)
}
},
.required_packages = c("gbm")
)
Expand Down
7 changes: 6 additions & 1 deletion R/Lrnr_grf.R
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,12 @@ Lrnr_grf <- R6Class(
newdata = data.frame(task$X),
quantiles = quantiles_pred
)
predictions <- as.numeric(predictions_list$predictions)
if(length(quantiles_pred) > 1) {
predictions <- pack_predictions(predictions_list$predictions)
}
else {
predictions <- as.numeric(predictions_list$predictions)
}
return(predictions)
},
.required_packages = c("grf")
Expand Down
2 changes: 2 additions & 0 deletions R/Lrnr_lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ Lrnr_lightgbm <- R6Class(
}
args$verbose <- as.integer(verbose)

args$params <- args$lightgbm_params

# set up outcome
outcome_type <- self$get_outcome_type(task)
Y <- outcome_type$format(task$Y)
Expand Down
2 changes: 1 addition & 1 deletion R/loss_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ cv_risk <- function(learner, eval_fun = NULL, coefs = NULL) {
preds <- learner$predict_fold(task, "validation")
if (!is.data.table(preds)) {
preds <- data.table::data.table(preds)
data.table::setnames(preds, names(preds), learner$name)
#data.table::setnames(preds, names(preds), learner$name)
}

get_obsdata <- function(fold, task) {
Expand Down

0 comments on commit e3c167d

Please sign in to comment.