Skip to content

Commit

Permalink
WIP split PipeOpEncodePL into two PipeOps, one for each method
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Jan 19, 2025
1 parent 8a5e162 commit 7d7a826
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 56 deletions.
138 changes: 82 additions & 56 deletions R/PipeOpEncodePL.R
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#' @title Factor Encoding
#' @title Piecewise Linear Encoding Base Class
#'
#' @usage NULL
#' @name mlr_pipeops_encode
#' @format [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
#' @name mlr_pipeops_encodepl
#' @format Abstract [`R6Class`][R6::R6Class] object inheriting from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
#'
#' @description
#' Abstract base class for piecewise linear encoding.
#'
#' Encodes columns of type `numeric` and `integer`.
#'
#'
Expand Down Expand Up @@ -37,78 +39,39 @@
#' Initialized to `""`. One of:
#'
#' @section Methods:
#' Only methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`].
#' Methods inherited from [`PipeOpTaskPreprocSimple`]/[`PipeOpTaskPreproc`]/[`PipeOp`], as well as
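#' * `.get_bins(task, cols)`\cr
#'   ([`Task`][mlr3::Task], `character`) -> `list` \cr
#'   Abstract method that inheriting classes must implement. Expected to return a named `list` containing, for each
#'   feature in `cols`, a numeric vector of bin boundaries used for the piecewise linear encoding.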
#'
#'
#' @references
#' `r format_bib("gorishniy_2022")`
#'
#' @family PipeOps
#' @family PipeOpsPLE
#' @template seealso_pipeopslist
#' @include PipeOpTaskPreproc.R
#' @export
#' @examples
#' library("mlr3")
#'
PipeOpEncodePL = R6Class("PipeOpEncodePL",
inherit = PipeOpTaskPreprocSimple,
public = list(
# Constructor of the abstract piecewise linear encoding base class.
#
# * id: identifier of the PipeOp.
# * param_set: parameter set supplied by the concrete subclass.
# * param_vals: list of hyperparameter values to set on construction.
# * packages: packages required by the concrete subclass; forwarded to the parent constructor.
# * task_type: task class the PipeOp operates on; forwarded to the parent constructor.
#
# `packages` and `task_type` are explicit arguments because the subclasses in this file pass them to
# `super$initialize()`; without them the call fails with "unused argument", and the previously
# forwarded `task_type` was an undefined variable.
initialize = function(id = "encodepl", param_set = ps(), param_vals = list(), packages = character(0), task_type = "Task") {
  super$initialize(id, param_set = param_set, param_vals = param_vals, packages = packages,
    task_type = task_type, tags = "encode", feature_types = c("numeric", "integer"))
}
),
private = list(

.tree_learner = NULL,
.encodepl_param_set = NULL,
# Abstract hook for computing the bin boundaries; always raises an error in this base class.
# * task: the task passed on from `.get_state()`
# * cols: names of the columns to compute bins for
# Concrete subclasses are expected to override this method.
.get_bins = function(task, cols) {
stop("Abstract.")
},

# Computes the state stored during training: the bin boundaries of every selected column.
# * task: the task to compute the state from
# Returns a list with a single element `bins`.
.get_state = function(task) {
  cols = private$.select_cols(task)
  if (!length(cols)) {
    return(list(bins = numeric(0))) # early exit
  }

  # `.get_bins()` is a private method and is not visible as a bare name inside an R6 method,
  # so it has to be called through `private$`.
  list(bins = private$.get_bins(task, cols))
},

.transform = function(task) {
Expand All @@ -126,8 +89,6 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL",
)
)

mlr_pipeops$add("encodepl", PipeOpEncodePL, list(task_type = "TaskRegr"))

# Helper function to implement piecewise linear encoding.
# * column: numeric vector
# * colname: name of `column`
Expand All @@ -149,3 +110,68 @@ encode_piecewise_linear = function(column, colname, bins) {

dt
}

#' PipeOpEncodePLQuantiles
#'
#' Piecewise linear encoding that uses the empirical quantiles of each feature as bin boundaries.
#'
#' Hyperparameter `numsplits` (integer, at least 2, initialized to 2) is the number of bins; the
#' boundaries are the feature minimum, the `1 / numsplits, ..., (numsplits - 1) / numsplits`
#' quantiles and the feature maximum, with duplicate boundaries removed.
PipeOpEncodePLQuantiles = R6Class("PipeOpEncodePLQuantiles",
  inherit = PipeOpEncodePL,
  public = list(
    initialize = function(id = "encodeplquantiles", param_vals = list()) {
      # `numsplits` is tagged "required", so it needs an actual initial value: a `default` alone
      # never sets the value and a required parameter must not declare a default.
      # The local is not called `ps` to avoid shadowing the `ps()` constructor.
      param_set = ps(
        numsplits = p_int(lower = 2, tags = c("train", "predict", "required"))
      )
      param_set$values = list(numsplits = 2)
      super$initialize(id, param_set = param_set, param_vals = param_vals, packages = "stats")
    }
  ),
  private = list(

    # Computes quantile-based bin boundaries.
    # * task: task holding the feature data
    # * cols: names of the columns to compute bins for
    # Returns a named list with one numeric vector of boundaries per column.
    .get_bins = function(task, cols) {
      numsplits = self$param_set$values$numsplits %??% 2
      probs = seq_len(numsplits - 1) / numsplits
      lapply(task$data(cols = cols), function(d) {
        # Missing values are ignored for the outer boundaries as well, consistent with the
        # quantile computation; otherwise a single NA would turn min / max into NA.
        unique(c(min(d, na.rm = TRUE), stats::quantile(d, probs, na.rm = TRUE), max(d, na.rm = TRUE)))
      })
    }
  )
)

mlr_pipeops$add("encodeplquantiles", PipeOpEncodePLQuantiles)

#' PipeOpEncodePLTree
#'
#' Piecewise linear encoding that takes the bin boundaries from the split points of a decision tree
#' fitted separately on each feature.
#'
#' `task_type` selects the tree learner: `"TaskRegr"` uses `LearnerRegrRpart`, `"TaskClassif"` uses
#' `LearnerClassifRpart`; any other task type raises an error. The hyperparameters are those of
#' the tree learner.
PipeOpEncodePLTree = R6Class("PipeOpEncodePLTree",
  inherit = PipeOpEncodePL,
  public = list(
    initialize = function(task_type, id = "encodepltree", param_vals = list()) {
      assert_choice(task_type, mlr_reflections$task_types$task)
      private$.tree_learner = switch(task_type,
        TaskRegr = LearnerRegrRpart$new(),
        TaskClassif = LearnerClassifRpart$new(),
        stopf("Task type %s not supported.", task_type)
      )

      super$initialize(id, param_set = alist(private$.tree_learner$param_set), param_vals = param_vals,
        packages = private$.tree_learner$packages, task_type = task_type)
    }
  ),
  private = list(

    .tree_learner = NULL,

    # Computes tree-based bin boundaries.
    # * task: task holding the feature data and the target
    # * cols: names of the columns to compute bins for
    # Returns a named list with one numeric vector per column: the feature minimum, the sorted
    # split points of the fitted tree and the feature maximum.
    .get_bins = function(task, cols) {
      learner = private$.tree_learner

      bins = list()
      for (col in cols) {
        # Fit the tree on a copy of the task restricted to the current feature.
        t = task$clone(deep = TRUE)$select(col)
        splits = learner$train(t)$model$splits
        # Column "index" of the model's split matrix holds the split points; a tree that did
        # not split at all has no split matrix, which yields no inner boundaries.
        boundaries = if (is.null(splits)) numeric(0) else unname(sort(splits[, "index"]))
        d = task$data(cols = col)
        # Ignore missing values so the outer boundaries do not become NA.
        bins[[col]] = c(min(d, na.rm = TRUE), boundaries, max(d, na.rm = TRUE))
      }
      bins
    }
  )
)

# Registering with "TaskRegr", however both "TaskRegr" and "TaskClassif" are acceptable, see issue ...
mlr_pipeops$add("encodepltree", PipeOpEncodePLTree, list(task_type = "TaskRegr"))
1 change: 1 addition & 0 deletions tests/testthat/test_pipeop_encodepl.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ test_that("PipeOpEncodePL - basic properties", {
# - different methods
# - with params (not all for regtree, hopefully)
# - test on tasks with simple data that behaviour is as expected (compare dts)
# - for different task types
# - TODO: decide how to handle NAs in feature columns and test that

0 comments on commit 7d7a826

Please sign in to comment.