From 4c06919c9a286c2b7d193d34259452de1cb07d8b Mon Sep 17 00:00:00 2001
From: Sebastian Fischer
Date: Tue, 25 Jun 2024 17:28:44 +0200
Subject: [PATCH] Add PipeOp properties and rework validation/internal tuning

---
 R/GraphLearner.R                 | 70 ++++++++++++++------------------
 R/PipeOp.R                       | 15 ++++++-
 R/PipeOpLearner.R                | 52 +++++++++++++++++++++++-
 R/utils.R                        | 15 +++++++
 R/zzz.R                          |  1 +
 inst/testthat/helper_functions.R |  8 ++++
 tests/testthat/test_PipeOp.R     | 15 +++++++
 7 files changed, 133 insertions(+), 43 deletions(-)

diff --git a/R/GraphLearner.R b/R/GraphLearner.R
index 1cefa61a4..2eb07b1f0 100644
--- a/R/GraphLearner.R
+++ b/R/GraphLearner.R
@@ -56,6 +56,7 @@
 #' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
 #'   How to construct the validation data. This also has to be configured in the individual learners wrapped by
 #'   `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
+#'   For more details on the possible values, see [`mlr3::Learner`].
 #' * `marshaled` :: `logical(1)`\cr
 #'   Whether the learner is marshaled.
 #'
@@ -119,8 +120,8 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       }
       assert_subset(task_type, mlr_reflections$task_types$type)

-      private$.can_validate = some(learner_wrapping_pipeops(graph), function(po) "validation" %in% po$learner$properties)
-      private$.can_internal_tuning = some(learner_wrapping_pipeops(graph), function(po) "internal_tuning" %in% po$learner$properties)
+      private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
+      private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)

       properties = setdiff(mlr_reflections$learner_properties[[task_type]],
         c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])
@@ -139,6 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
       if (!is.null(predict_type)) self$predict_type = predict_type
     },
     base_learner = function(recursive = Inf) {
+      self$base_pipeop(recursive = recursive)$learner_model
+    },
+    base_pipeop = function(recursive = Inf) {
       assert(check_numeric(recursive, lower = Inf), check_int(recursive))
       if (recursive <= 0) return(self)
       gm = self$graph_model
@@ -158,7 +162,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
         if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
         if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
       }
-      learner_model$base_learner(recursive - 1)
+      learner_model$base_pipeop(recursive - 1)
     },
     marshal = function(...) {
       learner_marshal(.learner = self, ...)
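
[Review note, not part of the patch: a quick usage sketch of the new `base_pipeop()` accessor and its relationship to `base_learner()`. Assumes mlr3 and mlr3pipelines are attached and that the patched methods behave as documented; `base_pipeop()` only exists with this patch applied.]

    library(mlr3)
    library(mlr3pipelines)

    glrn = as_learner(po("pca") %>>% lrn("classif.rpart"))
    glrn$train(tsk("iris"))

    # base_pipeop() walks the graph to the unique Learner PipeOp;
    # base_learner() now simply returns the trained Learner stored inside it.
    po_final = glrn$base_pipeop()   # the PipeOpLearner wrapping classif.rpart
    glrn$base_learner()$id          # "classif.rpart"

[End review note]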
@@ -179,7 +183,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
     validate = function(rhs) {
       if (!missing(rhs)) {
         if (!private$.can_validate) {
-          stopf("None of the Learners wrapped by GraphLearner '%s' support validation.", self$id)
+          stopf("None of the PipeOps in GraphLearner '%s' support validation.", self$id)
         }
         private$.validate = assert_validate(rhs)
       }
@@ -232,30 +236,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
     .can_internal_tuning = NULL,
     .extract_internal_tuned_values = function() {
-      if (!private$.can_validate) return(NULL)
-      itvs = unlist(map(
-        learner_wrapping_pipeops(self$graph_model), function(po) {
-          if (exists("internal_tuned_values", po$learner)) {
-            po$learner_model$internal_tuned_values
-          }
-        }
-      ), recursive = FALSE)
-      if (is.null(itvs) || !length(itvs)) return(named_list())
+      if (!private$.can_internal_tuning) return(NULL)
+      itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
+      if (!length(itvs)) return(named_list())
       itvs
     },
     .extract_internal_valid_scores = function() {
-      if (!private$.can_internal_tuning) return(NULL)
-      ivs = unlist(map(
-        learner_wrapping_pipeops(self$graph_model), function(po) {
-          if (exists("internal_valid_scores", po$learner)) {
-            po$learner_model$internal_valid_scores
-          }
-        }
-      ), recursive = FALSE)
+      if (!private$.can_validate) return(NULL)
+      ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
       if (is.null(ivs) || !length(ivs)) return(named_list())
       ivs
     },
     deep_clone = function(name, value) {
-      private$.param_set = NULL # FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
       if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
         return(value$clone(deep = TRUE))
@@ -268,17 +259,10 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
     .train = function(task) {
       if (!is.null(get0("validate", self))) {
-        some_pipeops_validate = some(learner_wrapping_pipeops(self), function(po) !is.null(get0("validate", po$learner)))
+        some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
         if (!some_pipeops_validate) {
           lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
         }
-      } else {
-        # otherwise the pipeops will preprocess this unnecessarily
-        if (!is.null(task$internal_valid_task)) {
-          prev_itv = task$internal_valid_task
-          on.exit({task$internal_valid_task = prev_itv}, add = TRUE)
-          task$internal_valid_task = NULL
-        }
       }
       on.exit({self$graph$state = NULL})
@@ -350,6 +334,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
 #'   For which pipeops to enable validation.
 #'   This parameter is ignored when `validate` is set to `NULL`.
 #'   By default, validation is enabled for the base learner.
+#' @param args_all (`list()`)\cr
+#'   Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
+#'   `PipeOp`s.
 #' @param args (named `list()`)\cr
 #'   Rarely needed.
 #'   A named list of lists, specifying additional arguments to be passed to [`set_validate()`] for the respective learners.
@@ -376,31 +363,35 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
 #' glrn$validate
 #' glrn$graph$pipeops$classif.debug$learner$validate
 #' glrn$graph$pipeops$final$learner$validate
-set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(), ...) {
+set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
   if (is.null(validate)) {
     learner$validate = NULL
-    walk(learner_wrapping_pipeops(learner), function(po) {
-      po$learner$validate = NULL
+    walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
+      # disabling needs no extra arguments
+      set_validate(po, validate = NULL)
     })
     return(invisible(learner))
   }

   if (is.null(ids)) {
-    ids = base_pipeop(learner)$id
+    ids = learner$base_pipeop(recursive = 1)$id
   } else {
-    assert_subset(ids, ids(keep(learner_wrapping_pipeops(learner), function(po) "validation" %in% po$learner$properties)))
+    assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
   }

   assert_list(args, types = "list")
+  assert_list(args_all)
   assert_subset(names(args), ids)

-  prev_validate_pos = discard(map(learner_wrapping_pipeops(learner), function(po) get0("validate", po$learner, ifnotfound = NA)),
-    function(x) identical(x, NA))
-
+  prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
   prev_validate = learner$validate

   on.exit({
-    iwalk(prev_validate_pos, function(val, poid) learner$graph$pipeops[[poid]]$learner$validate = val)
+    iwalk(prev_validate_pos, function(val, poid) {
+      # passing the args here is just a heuristic that can in principle fail, but this should be extremely rare
+      set_validate(learner$graph$pipeops[[poid]], validate = val, args = args[[poid]] %??% list(), args_all = args_all)
+    })
     learner$validate = prev_validate
   }, add = TRUE)
@@ -409,7 +400,8 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
   walk(ids, function(poid) {
     # learner might be another GraphLearner / AutoTuner so we call into set_validate() again
     withCallingHandlers({
-      invoke(set_validate, learner = learner$graph$pipeops[[poid]]$learner, .args = insert_named(list(validate = "predefined"), args[[poid]]))
+      args_po = c(args[[poid]], args_all) %??% list()
+      invoke(set_validate, learner$graph$pipeops[[poid]], .args = insert_named(list(validate = "predefined"), args_po))
     }, error = function(e) {
       e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
       stop(e)
diff --git a/R/PipeOp.R b/R/PipeOp.R
index 9c8821f57..884304ef0 100644
--- a/R/PipeOp.R
+++ b/R/PipeOp.R
@@ -135,6 +135,16 @@
 #'   and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
 #' * `man` :: `character(1)`\cr
 #'   Identifying string of the help page that shows with `help()`.
+#' * `properties` :: `character()`\cr
+#'   The properties of the `PipeOp`.
+#'   Currently supported values are:
+#'   * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`], see [`mlr3::Learner`] for more information.
+#'     `PipeOp`s that have this property also have a `$validate` field, which controls whether to use the validation task,
+#'     as well as an `$internal_valid_scores` field, which provides access to the internal validation scores after training.
+#'   * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters, see [`mlr3::Learner`] for an explanation.
+#'     `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values`.
+#'
+#'   Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
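
[Review note, not part of the patch: a small sketch of how the new `properties` field is meant to surface, assuming the patched `PipeOpLearner` derives it from the wrapped Learner and the `mlr_reflections$pipeops$properties` entry registered in R/zzz.R below.]

    library(mlr3)
    library(mlr3pipelines)

    # classif.debug supports both validation and internal tuning in mlr3,
    # so the wrapping PipeOp should report both properties
    po_lrn = po(lrn("classif.debug"))
    po_lrn$properties
    #> [1] "validation"      "internal_tuning"

    # all recognized values, as registered in R/zzz.R below
    mlr_reflections$pipeops$properties

[End review note]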
 #'
 #' @section Methods:
 #' * `train(input)`\cr
@@ -235,8 +245,9 @@ PipeOp = R6Class("PipeOp",
     output = NULL,
     .result = NULL,
     tags = NULL,
+    properties = NULL,

-    initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
+    initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
       if (inherits(param_set, "ParamSet")) {
         private$.param_set = assert_param_set(param_set)
         private$.param_set_source = NULL
@@ -246,6 +257,7 @@ PipeOp = R6Class("PipeOp",
       }

       self$id = assert_string(id)
+      self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
       self$param_set$values = insert_named(self$param_set$values, param_vals)
       self$input = assert_connection_table(input)
       self$output = assert_connection_table(output)
@@ -596,4 +608,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
     map(transpose_list(map(result, "output")), as.Multiplicity)
   }
 }
-
diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R
index 5894dff94..ba38c0b3a 100644
--- a/R/PipeOpLearner.R
+++ b/R/PipeOpLearner.R
@@ -62,7 +62,17 @@
 #'   [`Learner`][mlr3::Learner] that is being wrapped. Read-only.
 #' * `learner_model` :: [`Learner`][mlr3::Learner]\cr
 #'   [`Learner`][mlr3::Learner] that is being wrapped. This learner contains the model if the `PipeOp` is trained. Read-only.
-#'
+#' * `validate` :: `"predefined"` or `NULL`\cr
+#'   This field can only be set for `Learner`s that have the `"validation"` property.
+#'   Setting the field to `"predefined"` means that the wrapped `Learner` will use the internal validation task,
+#'   otherwise it will be ignored.
+#'   Note that specifying *how* the validation data is created is possible via the `$validate` field of the [`GraphLearner`].
+#'   For each `PipeOp` it is then only possible to either use it (`"predefined"`) or not use it (`NULL`).
+#'   See the usage sketch below and [`set_validate.GraphLearner`] for more information.
+#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
+#'   The internally tuned values if the wrapped `Learner` supports internal tuning, `NULL` otherwise.
+#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
+#'   The internal validation scores if the wrapped `Learner` supports internal validation, `NULL` otherwise.
+#'
 #' @section Methods:
 #' Methods inherited from [`PipeOp`].
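
[Review note, not part of the patch: a usage sketch for the per-PipeOp `$validate` field documented above, assuming the patched active bindings. Only `"predefined"` or `NULL` are accepted at the PipeOp level, while the GraphLearner decides how the validation data is constructed.]

    library(mlr3)
    library(mlr3pipelines)

    glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
    glrn$validate = 0.3                                        # how validation data is built
    glrn$graph$pipeops$classif.debug$validate = "predefined"   # this PipeOp opts in

    glrn$train(tsk("iris"))
    glrn$graph_model$pipeops$classif.debug$internal_valid_scores

[End review note]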
 #'
@@ -91,13 +101,38 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
       type = private$.learner$task_type
       task_type = mlr_reflections$task_types[type, mult = "first"]$task
       out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
+      properties = intersect(c("validation", "internal_tuning"), learner$properties)
       super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
         input = data.table(name = "input", train = task_type, predict = task_type),
         output = data.table(name = "output", train = "NULL", predict = out_type),
-        tags = "learner", packages = learner$packages)
+        tags = "learner", packages = learner$packages, properties = properties)
     }
   ),
   active = list(
+    internal_tuned_values = function(rhs) {
+      assert_ro_binding(rhs)
+      if ("internal_tuning" %nin% self$properties) return(NULL)
+      self$learner_model$internal_tuned_values
+    },
+    internal_valid_scores = function(rhs) {
+      assert_ro_binding(rhs)
+      if ("validation" %nin% self$properties) return(NULL)
+      self$learner_model$internal_valid_scores
+    },
+    validate = function(rhs) {
+      if ("validation" %nin% self$properties) {
+        if (!missing(rhs)) {
+          stopf("PipeOp '%s' does not support validation because the wrapped Learner doesn't.", self$id)
+        }
+        return(NULL)
+      }
+      if (!missing(rhs)) {
+        private$.validate = assert_po_validate(rhs)
+        self$learner$validate = rhs
+      }
+      private$.learner$validate
+    },
     id = function(val) {
       if (!missing(val)) {
         private$.id = val
@@ -137,6 +172,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
   ),
   private = list(
     .learner = NULL,
+    .validate = NULL,

     .train = function(inputs) {
       on.exit({private$.learner$state = NULL})
@@ -157,3 +193,15 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
 )

 mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
+
+#' @export
+set_validate.PipeOpLearner = function(learner, validate, ...) {
+  assert_po_validate(validate)
+  prev_validate = learner$validate
+  on.exit({learner$validate = prev_validate}, add = TRUE)
+  # delegate to the wrapped Learner, which may itself be a GraphLearner or an AutoTuner
+  set_validate(learner$learner, validate = validate, ...)
+  on.exit()
+  learner$validate = validate
+  invisible(learner)
+}
diff --git a/R/utils.R b/R/utils.R
index 8836b6770..a900b9781 100644
--- a/R/utils.R
+++ b/R/utils.R
@@ -170,3 +170,18 @@ base_pipeop = function(self) {
   # New movie idea: "The Last PipeOp"
   last_pipeop
 }
+
+pos_with_property = function(x, property) {
+  x = if (test_class(x, "GraphLearner")) {
+    x$graph$pipeops
+  } else if (test_class(x, "Graph")) {
+    x$pipeops
+  } else {
+    x
+  }
+  keep(x, function(po) property %in% po$properties)
+}
+
+assert_po_validate = function(rhs) {
+  assert_choice(rhs, "predefined", null.ok = TRUE)
+}
diff --git a/R/zzz.R b/R/zzz.R
index c6054af1c..a4333c8a8 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -16,6 +16,7 @@ register_mlr3 = function() {
     c("abstract", "meta", "missings", "feature selection", "imbalanced data",
       "data transform", "target transform", "ensemble", "robustify", "learner",
       "encode", "multiplicity")))
+  x$pipeops$properties = c("validation", "internal_tuning")
 }

 paradox_info <- list2env(list(is_old = FALSE), parent = emptyenv())
diff --git a/inst/testthat/helper_functions.R b/inst/testthat/helper_functions.R
index ef480769e..80dfd76ae 100644
--- a/inst/testthat/helper_functions.R
+++ b/inst/testthat/helper_functions.R
@@ -108,6 +108,14 @@ expect_pipeop = function(po, check_ps_default_values = TRUE) {
   expect_int(po$innum, lower = 1)
   expect_int(po$outnum, lower = 1)
   expect_valid_pipeop_param_set(po, check_ps_default_values = check_ps_default_values)
+  if ("validation" %in% po$properties) {
+    testthat::expect_true(exists("validate", envir = po))
+    testthat::expect_true(exists("internal_valid_scores", envir = po))
+    checkmate::expect_function(mlr3misc::get_private(po)$.extract_internal_valid_scores)
+  }
+  if ("internal_tuning" %in% po$properties) {
+    testthat::expect_true(exists("internal_tuned_values", envir = po))
+  }
 }

 # autotest for the parmset of a pipeop
diff --git a/tests/testthat/test_PipeOp.R b/tests/testthat/test_PipeOp.R
index b56d3ba56..6a030929b 100644
--- a/tests/testthat/test_PipeOp.R
+++ b/tests/testthat/test_PipeOp.R
@@ -123,3 +123,18 @@ test_that("Informative error and warning messages", {

   expect_warning(potest$predict(list(1)), NA)
 })
+
+test_that("properties", {
+  f = function(properties) {
+    PipeOp$new(
+      id = "potest",
+      input = data.table(name = "input", train = "*", predict = "*"),
+      output = data.table(name = "output", train = "*", predict = "*"),
+      properties = properties
+    )
+  }
+
+  expect_error(f("abc"))
+  po1 = f("validation")
+  expect_equal(po1$properties, "validation")
+})
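
[Review note, not part of the patch: an end-to-end sketch of the reworked `set_validate()` path over the pieces above, assuming the patched `set_validate.GraphLearner()` and `set_validate.PipeOpLearner()`. `classif.debug` carries the `"validation"` property, so its PipeOp is the default target.]

    library(mlr3)
    library(mlr3pipelines)

    glrn = as_learner(po("pca") %>>% lrn("classif.debug"))

    # configures the GraphLearner and, via "predefined", the base learner's PipeOp
    set_validate(glrn, validate = 0.2)
    glrn$validate                               # 0.2
    glrn$graph$pipeops$classif.debug$validate   # "predefined"

    glrn$train(tsk("iris"))
    glrn$internal_valid_scores                  # names prefixed by the PipeOp id

[End review note]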