Skip to content

Commit

Permalink
preparing
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jan 13, 2024
1 parent d223981 commit 18f8d3e
Show file tree
Hide file tree
Showing 6 changed files with 31 additions and 12 deletions.
4 changes: 2 additions & 2 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -338,7 +338,7 @@ PipeOp = R6Class("PipeOp",
id = function(val) {
if (!missing(val)) {
private$.id = val
if (!is.null(private$.param_set)) {
if (paradox_info$is_old && !is.null(private$.param_set)) {
# private$.param_set may be NULL if it is constructed dynamically by active binding
private$.param_set$set_id = val
}
Expand All @@ -353,7 +353,7 @@ PipeOp = R6Class("PipeOp",
} else {
private$.param_set = sourcelist[[1]]
}
if (!is.null(self$id)) {
if (paradox_info$is_old && !is.null(self$id)) {
private$.param_set$set_id = self$id
}
}
Expand Down
16 changes: 12 additions & 4 deletions R/PipeOpFilter.R
Original file line number Diff line number Diff line change
Expand Up @@ -114,16 +114,24 @@ PipeOpFilter = R6Class("PipeOpFilter",
initialize = function(filter, id = filter$id, param_vals = list()) {
assert_class(filter, "Filter")
self$filter = filter$clone(deep = TRUE)
self$filter$param_set$set_id = ""
map(self$filter$param_set$params, function(p) p$tags = union(p$tags, "train"))
if (paradox_info$is_old) {
self$filter$param_set$set_id = ""
map(self$filter$param_set$params, function(p) p$tags = union(p$tags, "train"))
} else {
for (pn in self$filter$param_set$ids()) {
self$filter$param_set$tags[[pn]] = union(self$filter$param_set$tags[[pn]] , "train")
}
}
private$.outer_param_set = ParamSet$new(list(
ParamInt$new("nfeat", lower = 0, tags = "train"),
ParamDbl$new("frac", lower = 0, upper = 1, tags = "train"),
ParamDbl$new("cutoff", tags = "train"),
ParamInt$new("permuted", lower = 1, tags = "train")
))
private$.outer_param_set$set_id = "filter"
super$initialize(id, alist(private$.outer_param_set, self$filter$param_set), param_vals = param_vals, tags = "feature selection")
if (paradox_info$is_old) {
private$.outer_param_set$set_id = "filter"
}
super$initialize(id, alist(filter = private$.outer_param_set, self$filter$param_set), param_vals = param_vals, tags = "feature selection")
}
),
private = list(
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpImputeLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ PipeOpImputeLearner = R6Class("PipeOpImputeLearner",
public = list(
initialize = function(learner, id = "imputelearner", param_vals = list()) {
private$.learner = as_learner(learner, clone = TRUE)
private$.learner$param_set$set_id = ""
if (paradox_info$is_old) {
private$.learner$param_set$set_id = ""
}
id = id %??% private$.learner$id
feature_types = switch(private$.learner$task_type,
regr = c("integer", "numeric"),
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,9 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
id = function(val) {
if (!missing(val)) {
private$.id = val
private$.learner$param_set$set_id = val
if (paradox_info$is_old) {
private$.learner$param_set$set_id = val
}
}
private$.id
},
Expand Down
10 changes: 7 additions & 3 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,9 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
public = list(
initialize = function(learner, id = NULL, param_vals = list()) {
private$.learner = as_learner(learner, clone = TRUE)
private$.learner$param_set$set_id = ""
if (paradox_info$is_old) {
private$.learner$param_set$set_id = ""
}
id = id %??% private$.learner$id
# FIXME: can be changed when mlr-org/mlr3#470 has an answer
type = private$.learner$task_type
Expand All @@ -128,7 +130,9 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
ParamLgl$new("keep_response", tags = c("train", "required"))
))
private$.crossval_param_set$values = list(method = "cv", folds = 3, keep_response = FALSE)
private$.crossval_param_set$set_id = "resampling"
if (paradox_info$is_old) {
private$.crossval_param_set$set_id = "resampling"
}
# Dependencies in paradox have been broken from the start and this is known since at least a year:
# https://github.com/mlr-org/paradox/issues/216
# The following would make it _impossible_ to set "method" to "insample", because then "folds"
Expand All @@ -137,7 +141,7 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
# in PipeOp ParamSets.
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.

super$initialize(id, alist(private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
}

),
Expand Down
5 changes: 4 additions & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ register_mlr3 = function() {
"multiplicity")))
}

paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())

.onLoad = function(libname, pkgname) { # nocov start
register_mlr3()
setHook(packageEvent("mlr3", "onLoad"), function(...) register_mlr3(), action = "append")
Expand All @@ -27,6 +29,7 @@ register_mlr3 = function() {
if (Sys.getenv("IN_PKGDOWN") == "true") {
lg$set_threshold("warn")
}
paradox_info$is_old = !is.null(ps()$set_id)
} # nocov end

.onUnload = function(libpath) { # nocov start
Expand All @@ -39,4 +42,4 @@ register_mlr3 = function() {
# static code checks should not complain about commonly used data.table columns
utils::globalVariables(c("src_id", "dst_id", "name", "op.id", "response", "truth"))

leanify_package()
# leanify_package()

0 comments on commit 18f8d3e

Please sign in to comment.