Skip to content

Commit

Permalink
switch to get_state/transform for easier handling of task
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Dec 20, 2024
1 parent 68360d7 commit 930b713
Showing 1 changed file with 25 additions and 17 deletions.
42 changes: 25 additions & 17 deletions R/PipeOpEncodePL.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL",
public = list(
initialize = function(id = "encodepl", param_vals = list()) {
private$.reg_tree = LearnerRegrRpart$new()
# This would only work for regr tasks. How do we handle classif tasks, especially since we don't know the task type in init?

private$.encodepl_param_set = ps(
method = p_fct(levels = c("quantiles", "regtree"), tags = c("train", "predict", "required")),
Expand All @@ -56,51 +57,58 @@ PipeOpEncodePL = R6Class("PipeOpEncodePL",
private$.encodepl_param_set$values = list(method = "quantiles")

super$initialize(id, param_set = alist(encodepl = private$.encodepl_param_set, private$.reg_tree$param_set),
param_vals = param_vals, packages = private$.reg_tree$packages, tags = "encode", feature_types = c("numeric", "integer"))
param_vals = param_vals, packages = c("stats", private$.reg_tree$packages), tags = "encode", feature_types = c("numeric", "integer"))
}
),
private = list(

.reg_tree = NULL,
.encodepl_param_set = NULL,

.get_state_dt = function(dt, levels, target) {
.get_state = function(task) {
cols = private$.select_cols(task)
# do we need early exit if there are no cols?

pv = private$.encodepl_param_set$values
numsplits = pv$quantiles_numsplits %??% 2

if (pv$method == "quantiles") {
bins = lapply(dt, function(d)
unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d))))
# check that min / max is correct here (according to paper / implementation)
# TODO: check that min / max is correct here (according to paper / implementation)
bins = lapply(task$data(cols = cols), function(d) {
unique(c(min(d), stats::quantile(d, seq(1, numsplits - 1) / numsplits, na.rm = TRUE), max(d)))
})
} else {
learner = private$.reg_tree
cols = colnames(dt)

bins = list()
for (col in cols) {
t = TaskRegr$new(id = "binning", backend = dt[, ..col], target = task$target_names)
t = task$clone(deep = TRUE)$select(col)
splits = learner$train(t)$model$splits
rules = unname(sort(splits[, which(colnames(splits) == "index")]))
bins[[col]] = c(min(dt[[col]]), rules, max(dt[[col]]))
# Get column "index" in model splits
boundaries = unname(sort(splits[, which(colnames(splits) == "index")]))

d = task$data(cols = col)
bins[[col]] = c(min(d), boundaries, max(d))
}
}

list(bins = bins)
},

.transform_dt = function(dt, levels) {
.transform = function(task) {
bins = self$state$bins
cols = names(bins)
if (!length(cols)) {
return(task) # early exit
}

cols = colnames(dt)

dt = data.table()
for (col in cols) {
dt = cbind(dt, ple(dt[, ..col], bins[[col]]))
# do name checking ...
dt = cbind(dt, ple(task$data(cols = col), bins[[col]]))
}

# Drop old columns
dt[, (cols) := NULL]
dt
# TODO: handle name collision
task$cbind(dt)
}
)
)
Expand Down

0 comments on commit 930b713

Please sign in to comment.