Merge pull request #770 from mlr-org/feat/inner_valid
Feat/inner valid
mb706 authored Jun 30, 2024
2 parents d0a5495 + ac09cae commit 3de86cd
Showing 22 changed files with 621 additions and 18 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -28,7 +28,7 @@ Authors@R:
comment = c(ORCID = "0000-0001-9754-0393")),
person(given = "Sebastian",
family = "Fischer",
role = "ctb",
role = "aut",
email = "[email protected]",
comment = c(ORCID = "0000-0002-9609-3197")),
person(given = "Susanne",
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -27,6 +27,8 @@ S3method(pos,list)
S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(set_validate,GraphLearner)
S3method(set_validate,PipeOpLearner)
S3method(unmarshal_model,Multiplicity_marshaled)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
1 change: 1 addition & 0 deletions NEWS.md
@@ -2,6 +2,7 @@

* Compatibility with new `bbotk` release.
* Added marshaling support to `GraphLearner`.
* Added support for internal tuning and validation.

# mlr3pipelines 0.5.2

174 changes: 169 additions & 5 deletions R/GraphLearner.R
@@ -47,6 +47,17 @@
#' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
#' * `graph_model` :: [`Graph`]\cr
#' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
#' The internal tuned parameter values collected from all `PipeOp`s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
#' The internal validation scores as retrieved from the `PipeOp`s.
#' The names are prefixed with the respective IDs of the `PipeOp`s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
#' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
#' `PipeOpLearner`, see [`set_validate.GraphLearner`].
#' For more details on the possible values, see [`mlr3::Learner`].
#' * `marshaled` :: `logical(1)`\cr
#' Whether the learner is marshaled.
#'
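
A minimal sketch of how these fields can be queried after training (assuming the `classif.debug` learner, which supports validation, and the iris task; the exact score names depend on the wrapped learner):

library(mlr3)
library(mlr3pipelines)

glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
set_validate(glrn, validate = 0.3)  # hold out 30% of the training data for validation
glrn$train(tsk("iris"))
glrn$internal_valid_scores  # named list, names prefixed with the PipeOp ID, e.g. "classif.debug.acc"
glrn$internal_tuned_values  # internally tuned values collected from the PipeOps (may be empty)
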
@@ -110,11 +121,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
assert_subset(task_type, mlr_reflections$task_types$type)

private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)

properties = setdiff(mlr_reflections$learner_properties[[task_type]],
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])

super$initialize(id = id, task_type = task_type,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = mlr_reflections$learner_properties[[task_type]],
properties = properties,
man = "mlr3pipelines::GraphLearner"
)

@@ -123,8 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
if (!is.null(predict_type)) self$predict_type = predict_type
},
base_learner = function(recursive = Inf) {
base_learner = function(recursive = Inf, return_po = FALSE) {
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
assert_flag(return_po)
if (recursive <= 0) return(self)
gm = self$graph_model
gm_output = gm$output
@@ -143,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_learner(recursive - 1)
if (return_po) {
last_pipeop
} else {
learner_model$base_learner(recursive - 1)
}
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
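
The new `return_po` argument retrieves the `PipeOpLearner` wrapping the base learner rather than the wrapped learner itself; a brief sketch of the intended usage:

glrn = as_learner(po("pca") %>>% lrn("classif.rpart"))
glrn$base_learner()                  # the wrapped Learner, classif.rpart
glrn$base_learner(return_po = TRUE)  # the PipeOpLearner containing it
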
@@ -153,15 +175,32 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
),
active = list(
internal_valid_scores = function(rhs) {
assert_ro_binding(rhs)
self$state$internal_valid_scores
},
internal_tuned_values = function(rhs) {
assert_ro_binding(rhs)
self$state$internal_tuned_values
},
validate = function(rhs) {
if (!missing(rhs)) {
if (!private$.can_validate) {
stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
}
private$.validate = assert_validate(rhs)
}
private$.validate
},
marshaled = function() {
learner_marshaled(self)
},
hash = function() {
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
},
phash = function() {
digest(list(class(self), self$id, self$graph$phash, private$.predict_type,
digest(list(class(self), self$id, self$graph$phash, private$.predict_type, private$.validate,
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
},
predict_type = function(rhs) {
@@ -195,6 +234,21 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
),
private = list(
.graph = NULL,
.validate = NULL,
.can_validate = NULL,
.can_internal_tuning = NULL,
.extract_internal_tuned_values = function() {
if (!private$.can_internal_tuning) return(NULL)
itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
if (!length(itvs)) return(named_list())
itvs
},
.extract_internal_valid_scores = function() {
if (!private$.can_validate) return(NULL)
ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
if (!length(ivs)) return(named_list())
ivs
},
deep_clone = function(name, value) {
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
@@ -207,6 +261,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
},

.train = function(task) {
if (!is.null(get0("validate", self))) {
some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
if (!some_pipeops_validate) {
lg$warn("GraphLearner '%s' specifies a validation set, but none of its PipeOps use it.", self$id)
}
}

on.exit({self$graph$state = NULL})
self$graph$train(task)
state = self$graph$state
@@ -255,6 +316,109 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
)
)

#' @title Configure Validation for a GraphLearner
#'
#' @description
#' Configure validation for a graph learner.
#'
#' In a [`GraphLearner`], validation can be configured on two levels:
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
#' 2. On the level of the individual `PipeOp`s (such as `PipeOpLearner`), which specifies
#' which `PipeOp`s actually make use of the validation data (their `$validate` field is set to `"predefined"`) or not (set to `NULL`).
#' This can be specified via the argument `ids`.
#'
#' @param learner ([`GraphLearner`])\cr
#' The graph learner to configure.
#' @param validate (`numeric(1)`, `"predefined"`, `"test"`, or `NULL`)\cr
#' How to set the `$validate` field of the learner.
#' If set to `NULL`, validation is disabled both on the graph learner level and for all `PipeOp`s.
#' @param ids (`NULL` or `character()`)\cr
#' The `PipeOp`s for which to enable validation.
#' This parameter is ignored when `validate` is set to `NULL`.
#' By default, validation is enabled for the final `PipeOp` in the `Graph`.
#' @param args_all (`list()`)\cr
#' Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
#' `PipeOp`s.
#' @param args (named `list()`)\cr
#' Rarely needed.
#' A named list of lists, specifying additional arguments to be passed to [`set_validate()`] when calling it on the individual
#' `PipeOp`s.
#' @param ... (any)\cr
#' Currently unused.
#'
#' @export
#' @examples
#' library(mlr3)
#'
#' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
#' set_validate(glrn, 0.3)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#'
#' set_validate(glrn, NULL)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#'
#' set_validate(glrn, 0.2, ids = "classif.debug")
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
prev_validate = learner$validate
on.exit({
iwalk(prev_validate_pos, function(prev_val, poid) {
# We don't call set_validate() here: it would not guarantee that the configuration is correctly
# reset to its previous state, it is less transparent, and it might fail again.
# The error message informs the user about this though via the calling handlers below
learner$graph$pipeops[[poid]]$validate = prev_val
})
learner$validate = prev_validate
}, add = TRUE)

if (is.null(validate)) {
learner$validate = NULL
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
})
on.exit()
return(invisible(learner))
}

if (is.null(ids)) {
ids = learner$base_learner(return_po = TRUE)$id
} else {
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
}

assert_list(args, types = "list")
assert_list(args_all)
assert_subset(names(args), ids)

learner$validate = validate

walk(ids, function(poid) {
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
withCallingHandlers({
args = insert_named(insert_named(list(validate = "predefined"), args_all), args[[poid]])
invoke(set_validate, learner$graph$pipeops[[poid]], .args = args)
}, error = function(e) {
e$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, e$message)
stop(e)
}, warning = function(w) {
w$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, w$message)
warning(w)
invokeRestart("muffleWarning")
})
})
on.exit()

invisible(learner)
}
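
As a further sketch beyond the roxygen examples above: the default for `ids` targets the `PipeOp` returned by `base_learner(return_po = TRUE)`, so for the graph from the examples the following calls are equivalent:

set_validate(glrn, validate = 0.2)                         # validation enabled for "classif.debug" only
set_validate(glrn, validate = 0.2, ids = "classif.debug")  # same, with the target named explicitly

For nested structures (e.g. a GraphLearner wrapping another GraphLearner), `args` and `args_all` are forwarded to the inner set_validate() calls.
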

#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
Expand Down
19 changes: 17 additions & 2 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
@@ -135,6 +135,20 @@
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties of the `PipeOp`.
#' Currently supported values are:
#' * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`].
#' This is for example used for `PipeOpLearner`s that wrap a `Learner` with this property, see [`mlr3::Learner`].
#' `PipeOp`s that have this property also have a `$validate` field, which controls whether to use the validation task,
#' as well as an `$internal_valid_scores` field, which allows access to the internal validation scores after training.
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters.
#' This works analogously to the internal tuning implementation for [`mlr3::Learner`].
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values` and have at least one
#' parameter tagged with `"internal_tuning"`.
#' An example of such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
#'
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
#'
#' @section Methods:
#' * `train(input)`\cr
@@ -235,8 +249,9 @@ PipeOp = R6Class("PipeOp",
output = NULL,
.result = NULL,
tags = NULL,
properties = NULL,

initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
@@ -246,6 +261,7 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
@@ -601,4 +617,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
map(transpose_list(map(result, "output")), as.Multiplicity)
}
}
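
A short sketch of how the new `properties` field can be inspected; which properties a `PipeOpLearner` reports depends on the wrapped learner:

library(mlr3)
library(mlr3pipelines)

po("learner", learner = lrn("classif.debug"))$properties  # e.g. "validation", "internal_tuning"
po("pca")$properties                                      # character(0): a plain preprocessing step
mlr_reflections$pipeops$properties                        # all properties known to mlr3pipelines
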

4 changes: 4 additions & 0 deletions R/PipeOpImpute.R
@@ -195,6 +195,10 @@ PipeOpImpute = R6Class("PipeOpImpute",

self$state$outtasklayout = copy(intask$feature_types)

if (!is.null(intask$internal_valid_task)) {
intask$internal_valid_task = private$.predict(list(intask$internal_valid_task))[[1L]]
}

list(intask)
},
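
This change applies the already-trained imputation to the task's `$internal_valid_task`, so the validation data is imputed consistently with the training data, using statistics learned on the training portion only. A sketch of the resulting behavior (assuming the pima task, which has missing values):

library(mlr3)
library(mlr3pipelines)

task = tsk("pima")
glrn = as_learner(po("imputemean") %>>% lrn("classif.debug"))
set_validate(glrn, validate = 0.3)
glrn$train(task)  # the validation split is imputed with means computed on the training split
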
