Commit

wip

sebffischer committed Jun 25, 2024
1 parent 0170eac commit 4c06919
Showing 7 changed files with 133 additions and 43 deletions.
70 changes: 31 additions & 39 deletions R/GraphLearner.R
@@ -56,6 +56,7 @@
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
#' How to construct the validation data. Validation also has to be enabled in the individual `Learner`s wrapped by
#' `PipeOpLearner`; see [`set_validate.GraphLearner`] for how to configure this.
#' For more details on the possible values, see [`mlr3::Learner`].
#' * `marshaled` :: `logical(1)`\cr
#' Whether the learner is marshaled.
#'
@@ -119,8 +120,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
assert_subset(task_type, mlr_reflections$task_types$type)

private$.can_validate = some(learner_wrapping_pipeops(graph), function(po) "validation" %in% po$learner$properties)
private$.can_internal_tuning = some(learner_wrapping_pipeops(graph), function(po) "internal_tuning" %in% po$learner$properties)
private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
private$.can_validate = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)

properties = setdiff(mlr_reflections$learner_properties[[task_type]],
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])
@@ -139,6 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (!is.null(predict_type)) self$predict_type = predict_type
},
base_learner = function(recursive = Inf) {
self$base_pipeop(recursive = recursive)$learner_model
},
base_pipeop = function(recursive = Inf) {
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
if (recursive <= 0) return(self)
gm = self$graph_model
@@ -158,7 +162,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_learner(recursive - 1)
learner_model$base_pipeop(recursive - 1)
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
@@ -179,7 +183,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
validate = function(rhs) {
if (!missing(rhs)) {
if (!private$.can_validate) {
stopf("None of the Learners wrapped by GraphLearner '%s' support validation.", self$id)
stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
}
private$.validate = assert_validate(rhs)
}
@@ -232,30 +236,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
.can_internal_tuning = NULL,
.extract_internal_tuned_values = function() {
if (!private$.can_internal_tuning) return(NULL)
itvs = unlist(map(
learner_wrapping_pipeops(self$graph_model), function(po) {
if (exists("internal_tuned_values", po$learner)) {
po$learner_model$internal_tuned_values
}
}
), recursive = FALSE)
if (is.null(itvs) || !length(itvs)) return(named_list())
itvs = unlist(map(pos_with_property(self, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
if (!length(itvs)) return(named_list())
itvs
},
.extract_internal_valid_scores = function() {
if (!private$.can_validate) return(NULL)
ivs = unlist(map(
learner_wrapping_pipeops(self$graph_model), function(po) {
if (exists("internal_valid_scores", po$learner)) {
po$learner_model$internal_valid_scores
}
}
), recursive = FALSE)
ivs = unlist(map(pos_with_property(self, "validation"), "internal_valid_scores"), recursive = FALSE)
if (is.null(ivs) || !length(ivs)) return(named_list())
ivs
},
deep_clone = function(name, value) {
private$.param_set = NULL
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
return(value$clone(deep = TRUE))
@@ -268,17 +259,10 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,

.train = function(task) {
if (!is.null(get0("validate", self))) {
some_pipeops_validate = some(learner_wrapping_pipeops(self), function(po) !is.null(get0("validate", po$learner)))
some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
if (!some_pipeops_validate) {
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
}
} else {
# otherwise the pipeops will preprocess this unnecessarily
if (!is.null(task$internal_valid_task)) {
prev_itv = task$internal_valid_task
on.exit({task$internal_valid_task = prev_itv}, add = TRUE)
task$internal_valid_task = NULL
}
}

on.exit({self$graph$state = NULL})
@@ -350,6 +334,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
#' For which `PipeOp`s to enable validation.
#' This parameter is ignored when `validate` is set to `NULL`.
#' By default, validation is enabled for the base learner.
#' @param args_all (`list()`)\cr
#' Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
#' `PipeOp`s.
#' @param args (named `list()`)\cr
#' Rarely needed.
#' A named list of lists, specifying additional arguments to be passed to [`set_validate()`] for the respective learners.
@@ -376,31 +363,35 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#' glrn$graph$pipeops$final$learner$validate
set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(), ...) {
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
if (is.null(validate)) {
learner$validate = NULL
walk(learner_wrapping_pipeops(learner), function(po) {
po$learner$validate = NULL
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
# also disable validation in the wrapped learners, forwarding any per-PipeOp arguments
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
})
return(invisible(learner))
}

if (is.null(ids)) {
ids = base_pipeop(learner)$id
ids = learner$base_pipeop(recursive = 1)$id
} else {
assert_subset(ids, ids(keep(learner_wrapping_pipeops(learner), function(po) "validation" %in% po$learner$properties)))
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
}

assert_list(args, types = "list")
assert_list(args_all, types = "list")
assert_subset(names(args), ids)

prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner, ifnotfound = NA)),
function(x) identical(x, NA))

prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
prev_validate = learner$validate

on.exit({
iwalk(prev_validate_pos, function(val, poid) learner$graph$pipeops[[poid]]$learner$validate = val)
iwalk(prev_validate_pos, function(val, poid) {
# passing the args when rolling back is a heuristic that can in principle fail,
# but such cases should be extremely rare
args = args[[poid]] %??% list()
set_validate(learner$graph$pipeops[[poid]], validate = val, args = args, args_all = args_all)
})
learner$validate = prev_validate
}, add = TRUE)

@@ -409,7 +400,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
walk(ids, function(poid) {
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
withCallingHandlers({
invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, .args = insert_named(list(validate = "predefined"), args[[poid]]))
args = c(args[[poid]], args_all) %??% list()
invoke(set_validate, learner = learner$graph$pipeops[[poid]], .args = insert_named(list(validate = "predefined"), args))
}, error = function(e) {
e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
stop(e)
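For illustration, a minimal usage sketch of the reworked set_validate.GraphLearner() (assumes mlr3 and mlr3pipelines are attached, and that classif.debug carries the "validation" property, as the roxygen example above suggests):

library(mlr3)
library(mlr3pipelines)

# a pipeline whose final PipeOpLearner wraps a validation-capable Learner
glrn = as_learner(po("pca") %>>% po("learner", lrn("classif.debug"), id = "final"))

# use 30% of the training data as validation data; by default this enables
# validation only for the base PipeOp (here "final")
set_validate(glrn, validate = 0.3)
glrn$validate                              # 0.3
glrn$graph$pipeops$final$learner$validate  # "predefined"

# disable validation everywhere again
set_validate(glrn, validate = NULL)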
15 changes: 13 additions & 2 deletions R/PipeOp.R
@@ -135,6 +135,16 @@
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties of the `PipeOp`.
#' Currently supported values are:
#' * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`], see [`mlr3::Learner`] for more information.
#' `PipeOp`s that have this property also have a `$validate` field, which controls whether to use the validation task,
#' as well as an `$internal_valid_scores` field, which provides access to the internal validation scores after training.
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters, see [`mlr3::Learner`] for an explanation.
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values`.
#'
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
#'
#' @section Methods:
#' * `train(input)`\cr
@@ -235,8 +245,9 @@ PipeOp = R6Class("PipeOp",
output = NULL,
.result = NULL,
tags = NULL,
properties = NULL,

initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
@@ -246,6 +257,7 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
@@ -596,4 +608,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
map(transpose_list(map(result, "output")), as.Multiplicity)
}
}

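For illustration, a sketch of how the new $properties field can be queried (with mlr3 and mlr3pipelines attached as above); the property values themselves are registered in mlr_reflections$pipeops$properties via the R/zzz.R change below, and classif.debug supporting validation is an assumption:

# all property values a PipeOp may declare
mlr_reflections$pipeops$properties
#> [1] "validation"      "internal_tuning"

# a PipeOpLearner inherits the relevant properties from the wrapped Learner
po("pca")$properties                            # character(0)
po("learner", lrn("classif.debug"))$properties  # "validation" (and "internal_tuning" if supported)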
52 changes: 50 additions & 2 deletions R/PipeOpLearner.R
@@ -62,7 +62,17 @@
#' [`Learner`][mlr3::Learner] that is being wrapped. Read-only.
#' * `learner_model` :: [`Learner`][mlr3::Learner]\cr
#' [`Learner`][mlr3::Learner] that is being wrapped. This learner contains the model if the `PipeOp` is trained. Read-only.
#'
#' * `validate` :: `"predefined"` or `NULL`\cr
#' This field can only be set for `Learner`s that have the `"validation"` property.
#' Setting the field to `"predefined"` means that the wrapped `Learner` will use the internal validation task,
#' otherwise it will be ignored.
#' Note that specifying *how* the validation data is created is possible via the `$validate` field of the [`GraphLearner`].
#' For each `PipeOp` it is then only possible to either use it (`"predefined"`) or not use it (`NULL`).
#' Also see [`set_validate.GraphLearner`] for more information.
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
#' The internally tuned values if the wrapped `Learner` supports internal tuning, `NULL` otherwise.
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
#' The internal validation scores if the wrapped `Learner` supports internal validation, `NULL` otherwise.
#' @section Methods:
#' Methods inherited from [`PipeOp`].
#'
@@ -91,13 +101,38 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
properties = c("validation", "internal_tuning")
properties = properties[properties %in% learner$properties]
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = "NULL", predict = out_type),
tags = "learner", packages = learner$packages)
tags = "learner", packages = learner$packages, properties = properties)
}
),
active = list(
internal_tuned_values = function(rhs) {
assert_ro_binding(rhs)
if ("validate" %nin% self$properties) return(NULL)
self$learner$internal_tuned_values
},
internal_valid_scores = function(rhs) {
assert_ro_binding(rhs)
if ("internal_tuning" %nin% self$properties) return(NULL)
self$learner$internal_valid_scores
},
validate = function(rhs) {
if ("validation" %nin% self$properties) {
if (!missing(rhs)) {
stopf("PipeOp '%s' does not support validation, because the wrapped Learner doesn't.", self$id)
}
return(NULL)
}
if (!missing(rhs)) {
private$.validate = assert_po_validate(rhs)
self$learner$validate = rhs
}
private$.learner$validate
},
id = function(val) {
if (!missing(val)) {
private$.id = val
@@ -137,6 +172,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
),
private = list(
.learner = NULL,
.validate = NULL,

.train = function(inputs) {
on.exit({private$.learner$state = NULL})
@@ -157,3 +193,15 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
)

mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))

#' @export
set_validate.PipeOpLearner = function(learner, validate, ...) {
assert_po_validate(validate)
prev_validate = learner$validate
# roll back if configuring the wrapped Learner errors
on.exit({learner$validate = prev_validate})
# dispatch on the wrapped Learner (which may itself be a GraphLearner or AutoTuner);
# calling set_validate() on the PipeOpLearner itself would recurse infinitely
set_validate(learner$learner, validate = validate, ...)
learner$validate = validate
on.exit()
invisible(learner)
}
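For illustration, a sketch of the new PipeOpLearner accessors and its set_validate() method (again assuming a validation-capable classif.debug):

po_lrn = po("learner", lrn("classif.debug"))

po_lrn$validate                     # NULL: validation is disabled by default
set_validate(po_lrn, "predefined")  # only "predefined" or NULL are valid here
po_lrn$learner$validate             # "predefined" -- forwarded to the wrapped Learner

# standardized accessors; NULL until the PipeOp has been trained
po_lrn$internal_valid_scores
po_lrn$internal_tuned_values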
15 changes: 15 additions & 0 deletions R/utils.R
@@ -170,3 +170,18 @@ base_pipeop = function(self) {
# New movie idea: "The Last PipeOp"
last_pipeop
}

pos_with_property = function(x, property) {
x = if (test_class(x, "GraphLearner")) {
x$graph$pipeops
} else if(test_class(x, "Graph")) {
x$pipeops
} else {
x
}
keep(x, function(po) property %in% po$properties)
}

assert_po_validate = function(rhs) {
assert_choice(rhs, "predefined", null.ok = TRUE)
}
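For illustration, a sketch of how these internal (non-exported) helpers compose; pos_with_property() accepts a GraphLearner, a Graph, or a plain list of PipeOps:

graph = po("pca") %>>% po("learner", lrn("classif.debug"))

# named list of all validation-capable PipeOps in a Graph (or GraphLearner)
pos_with_property(graph, "validation")

# a PipeOp's validate setting may only be "predefined" or NULL
assert_po_validate("predefined")  # ok
assert_po_validate(NULL)          # ok
assert_po_validate("test")        # error: must be "predefined" or NULL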
1 change: 1 addition & 0 deletions R/zzz.R
@@ -16,6 +16,7 @@ register_mlr3 = function() {
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
"data transform", "target transform", "ensemble", "robustify", "learner", "encode",
"multiplicity")))
x$pipeops$properties = c("validation", "internal_tuning")
}

paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())
8 changes: 8 additions & 0 deletions inst/testthat/helper_functions.R
@@ -108,6 +108,14 @@ expect_pipeop = function(po, check_ps_default_values = TRUE) {
expect_int(po$innum, lower = 1)
expect_int(po$outnum, lower = 1)
expect_valid_pipeop_param_set(po, check_ps_default_values = check_ps_default_values)
if ("validation" %in% po$properties) {
testthat::expect_true(exists("validate", po))
testthat::expect_true(exists("internal_valid_scores", envir = po))
checkmate::expect_function(mlr3misc::get_private(po)$.extract_internal_valid_scores)
}
if ("internal_tuning" %in% po$properties) {
checkmate::assert_false(exists("internal_tuning", po))
}
}

# autotest for the parmset of a pipeop
15 changes: 15 additions & 0 deletions tests/testthat/test_PipeOp.R
@@ -123,3 +123,18 @@ test_that("Informative error and warning messages", {
expect_warning(potest$predict(list(1)), NA)

})

test_that("properties", {
f = function(properties) {
PipeOp$new(
id = "potest",
input = data.table(name = "input", train = "*", predict = "*"),
output = data.table(name = "input", train = "*", predict = "*"),
properties = properties
)
}

expect_error(f("abc"))
po1 = f("validation")
expect_equal(po1$properties, "validation")
})
