From 7d7a826522239f2ab99a449973db3858d9f6f0c8 Mon Sep 17 00:00:00 2001 From: advieser Date: Sun, 19 Jan 2025 17:46:11 +0100 Subject: [PATCH] WIP split PipeOpEncodePL into two PipeOps, one for each method --- R/PipeOpEncodePL.R | 138 +++++++++++++++----------- tests/testthat/test_pipeop_encodepl.R | 1 + 2 files changed, 83 insertions(+), 56 deletions(-) diff --git a/R/PipeOpEncodePL.R b/R/PipeOpEncodePL.R index 4557466c3..d8e1201a3 100644 --- a/R/PipeOpEncodePL.R +++ b/R/PipeOpEncodePL.R @@ -1,10 +1,12 @@ -#' @title Factor Encoding +#' @title Piecewise Linear Encoding Base Class #' #' @usage NULL -#' @name mlr_pipeops_encode -#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. +#' @name mlr_pipeops_encodepl +#' @format Abstract [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. #' #' @description +#' Abstract base class for piecewise linear encoding. +#' #' Encodes columns of type `numeric` and `integer`. #' #' @@ -37,78 +39,39 @@ #' Initialized to `""`. One of: #' #' @section Methods: -#' Only methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`]. +#' Methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`], as well as +#' * `.get_bins(task, cols)`\cr +#' ([`Task`][mlr3::Task], `character`) -> `list` \cr +#' #' #' @references #' `r format_bib("gorishniy_2022")` #' #' @family PipeOps +#' @family PipeOpsPLE #' @template seealso_pipeopslist #' @include PipeOpTaskPreproc.R #' @export -#' @examples -#' library("mlr3") -#' PipeOpEncodePL = R6Class("PipeOpEncodePL", inherit = PipeOpTaskPreprocSimple, public = list( - initialize = function(task_type, id = "encodepl", param_vals = list()) { - # NOTE: Might use different name, change assert, and conditions - assert_choice(task_type, mlr_reflections$task_types$task) - if (task_type == "TaskRegr") { - private$.tree_learner = LearnerRegrRpart$new() - } else if (task_type == "TaskClassif") { - private$.tree_learner = LearnerClassifRpart$new() - } else { - stopf("Task type %s not supported", task_type) - } - - private$.encodepl_param_set = ps( - method = p_fct(levels = c("quantiles", "tree"), tags = c("train", "predict", "required")), - quantiles_numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict"), depends = quote(method == "quantiles")) - ) - private$.encodepl_param_set$values = list(method = "quantiles") - - super$initialize(id, param_set = alist(encodepl = private$.encodepl_param_set, private$.tree_learner$param_set), - param_vals = param_vals, packages = c("stats", private$.tree_learner$packages), + initialize = function(id = "encodepl", param_set = ps(), param_vals = list()) { + super$initialize(id, param_set = param_set, param_vals = param_vals, task_type = task_type, tags = "encode", feature_types = c("numeric", "integer")) } ), private = list( - .tree_learner = NULL, - .encodepl_param_set = NULL, + .get_bins = function(task, cols) { + stop("Abstract.") + }, .get_state = function(task) { cols = private$.select_cols(task) if (!length(cols)) { - return(task) # early exit + return(list(bins = numeric(0))) # early exit } - - pv = private$.encodepl_param_set$values - numsplits = pv$quantiles_numsplits %??% 2 - - if (pv$method == "quantiles") { - # TODO: check that min / max is correct here (according to paper / implementation) - bins = lapply(task$data(cols = cols), function(d) { - unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d))) - }) - } else { - learner = private$.tree_learner - - bins = list() - for (col in cols) { - t = task$clone(deep = TRUE)$select(col) - splits = learner$train(t)$model$splits - # Get column "index" in model splits - boundaries = unname(sort(splits[, "index"])) - - d = task$data(cols = col) - bins[[col]] = c(min(d), boundaries, max(d)) - } - } - - list(bins = bins) + list(bins = .get_bins(task, cols)) }, .transform = function(task) { @@ -126,8 +89,6 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL", ) ) -mlr_pipeops$add("encodepl", PipeOpEncodePL, list(task_type = "TaskRegr")) - # Helper function to implement piecewise linear encoding. # * column: numeric vector # * colname: name of `column` @@ -149,3 +110,68 @@ encode_piecewise_linear = function(column, colname, bins) { dt } + +#' PipeOpEncodePLQuantiles +PipeOpEncodePLQuantiles = R6Class("PipeOpEncodePLQuantiles", + inherit = PipeOpEncodePL, + public = list( + initialize = function(id = "encodeplquantiles", param_vals = list()) { + ps = ps( + numsplits = p_int(lower = 2, default = 2, tags = c("train", "predict", "required")) + ) + super$initialize(id, param_set = ps, param_vals = param_vals, packages = "stats") + } + ), + private = list( + + .get_bins = function(task, cols) { + numsplits = self$param_set$values$numsplits %??% 2 + lapply(task$data(cols = cols), function(d) { + unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d))) + }) + } + ) +) + +mlr_pipeops$add("encodeplquantiles", PipeOpEncodePLQuantiles) + +#' PipeOpEncodePLTree +PipeOpEncodePLTree = R6Class("PipeOpEncodePLTree", + inherit = PipeOpEncodePL, + public = list( + initialize = function(task_type, id = "encodepltree", param_vals = list()) { + assert_choice(task_type, mlr_reflections$task_types$task) + if (task_type == "TaskRegr") { + private$.tree_learner = LearnerRegrRpart$new() + } else if (task_type == "TaskClassif") { + private$.tree_learner = LearnerClassifRpart$new() + } else { + stopf("Task type %s not supported.", task_type) + } + + super$initialize(id, param_set = alist(private$.tree_learner$param_set), param_vals = param_vals, + packages = private$.tree_learner$packages, task_type = task_type) + } + ), + private = list( + + .tree_learner = NULL, + + .get_bins = function(task, cols) { + learner = private$.tree_learner + + bins = list() + for (col in cols) { + t = task$clone(deep = TRUE)$select(col) + # Get column "index" in model splits + boundaries = unname(sort(learner$train(t)$model$splits[, "index"])) + d = task$data(cols = col) + bins[[col]] = c(min(d), boundaries, max(d)) + } + bins + } + ) +) + +# Registering with "TaskRegr", however both "TaskRegr" and "TaskClassif" are acceptable, see issue ... +mlr_pipeops$add("encodepltree", PipeOpEncodePLTree, list(task_type = "TaskRegr")) diff --git a/tests/testthat/test_pipeop_encodepl.R b/tests/testthat/test_pipeop_encodepl.R index ceb52e064..30ce1acfd 100644 --- a/tests/testthat/test_pipeop_encodepl.R +++ b/tests/testthat/test_pipeop_encodepl.R @@ -13,4 +13,5 @@ test_that("PipeOpEncodePL - basic properties", { # - different methods # - with params (not all for regtree, hopefully) # - test on tasks with simple data that behaviour is as expected (compare dts) +# - for different task types # - TODO: decide how to handle NAs in feature columns and test that