Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat/inner valid #770

Merged
merged 28 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ Authors@R:
comment = c(ORCID = "0000-0001-9754-0393")),
person(given = "Sebastian",
family = "Fischer",
role = "ctb",
role = "aut",
email = "[email protected]",
comment = c(ORCID = "0000-0002-9609-3197")),
person(given = "Susanne",
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ S3method(pos,list)
S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(set_validate,GraphLearner)
S3method(set_validate,PipeOpLearner)
S3method(unmarshal_model,Multiplicity_marshaled)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

* Compatibility with new `bbotk` release.
* Added marshaling support to `GraphLearner`
* Support internal tuning and validation

# mlr3pipelines 0.5.2

Expand Down
174 changes: 169 additions & 5 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,17 @@
#' contain the model. Use `graph_model` to access the trained [`Graph`] after `$train()`. Read-only.
#' * `graph_model` :: [`Learner`][mlr3::Learner]\cr
#' [`Graph`] that is being wrapped. This [`Graph`] contains a trained state after `$train()`. Read-only.
#' * `internal_tuned_values` :: named `list()` or `NULL`\cr
#' The internal tuned parameter values collected from all `PipeOp`s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
#' The internal validation scores as retrieved from the `PipeOps`.
#' The names are prefixed with the respective IDs of the `PipeOp`s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
#' How to construct the validation data. This also has to be configured for the individual `PipeOp`s such as
#' `PipeOpLearner`, see [`set_validate.GraphLearner`].
#' For more details on the possible values, see [`mlr3::Learner`].
#' * `marshaled` :: `logical(1)`\cr
#' Whether the learner is marshaled.
#'
Expand Down Expand Up @@ -110,11 +121,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
assert_subset(task_type, mlr_reflections$task_types$type)

private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)

properties = setdiff(mlr_reflections$learner_properties[[task_type]],
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])

super$initialize(id = id, task_type = task_type,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = mlr_reflections$learner_properties[[task_type]],
properties = properties,
man = "mlr3pipelines::GraphLearner"
)

Expand All @@ -123,8 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
if (!is.null(predict_type)) self$predict_type = predict_type
},
base_learner = function(recursive = Inf) {
base_learner = function(recursive = Inf, return_po = FALSE) {
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
assert_flag(return_po)
if (recursive <= 0) return(self)
gm = self$graph_model
gm_output = gm$output
Expand All @@ -143,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_learner(recursive - 1)
if (return_po) {
last_pipeop
} else {
learner_model$base_learner(recursive - 1)
}
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
Expand All @@ -153,15 +175,32 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
),
active = list(
internal_valid_scores = function(rhs) {
  # Accessor for the internal validation scores stored in the learner state.
  # `rhs` is handed to assert_ro_binding() first (presumably rejecting any
  # assignment to this field -- confirm against mlr3misc).
  assert_ro_binding(rhs)
  learner_state = self$state
  learner_state$internal_valid_scores
},
internal_tuned_values = function(rhs) {
  # Accessor for the internally tuned parameter values stored in the learner
  # state. `rhs` is handed to assert_ro_binding() first (presumably rejecting
  # any assignment to this field -- confirm against mlr3misc).
  assert_ro_binding(rhs)
  learner_state = self$state
  learner_state$internal_tuned_values
},
validate = function(rhs) {
  # Getter: no value supplied, so just report the stored configuration.
  if (missing(rhs)) {
    return(private$.validate)
  }
  # Setter: refuse when no PipeOp in the graph has the "validation" property
  # (this flag is computed once in the constructor).
  if (!private$.can_validate) {
    stopf("None of the PipeOps in Graph '%s' supports validation.", self$id)
  }
  # Check the new value with assert_validate() before storing it.
  private$.validate = assert_validate(rhs)
  private$.validate
},
# Active binding: whether the learner is marshaled; the answer is delegated to learner_marshaled().
marshaled = function() {
learner_marshaled(self)
},
hash = function() {
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
},
phash = function() {
digest(list(class(self), self$id, self$graph$phash, private$.predict_type,
digest(list(class(self), self$id, self$graph$phash, private$.predict_type, private$.validate,
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
},
predict_type = function(rhs) {
Expand Down Expand Up @@ -195,6 +234,21 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
),
private = list(
.graph = NULL,
.validate = NULL,
.can_validate = NULL,
.can_internal_tuning = NULL,
# Collect the internally tuned parameter values from every PipeOp of the
# trained graph that has the "internal_tuning" property.
#
# Returns:
#   * NULL if no PipeOp of the graph supports internal tuning,
#   * an empty named list if such PipeOps exist but report no values,
#   * otherwise the values of all such PipeOps, flattened by one level.
.extract_internal_tuned_values = function() {
  # Guard on the internal-tuning capability. The previous guard checked
  # `.can_validate`, which belongs to the validation-score extractor: a graph
  # that supports validation but not internal tuning then returned an empty
  # list instead of NULL, and a graph that only supports internal tuning
  # returned NULL and lost its tuned values.
  if (!private$.can_internal_tuning) return(NULL)
  itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize that to an empty named list.
  if (!length(itvs)) return(named_list())
  itvs
},
# Collect the internal validation scores from every PipeOp of the trained
# graph that has the "validation" property.
#
# Returns:
#   * NULL if no PipeOp of the graph supports validation,
#   * an empty named list if such PipeOps exist but report no scores,
#   * otherwise the scores of all such PipeOps, flattened by one level.
.extract_internal_valid_scores = function() {
  # Guard on the validation capability. The previous guard checked
  # `.can_internal_tuning`, which belongs to the tuned-values extractor: a
  # graph that supports validation but not internal tuning then returned NULL
  # and lost its validation scores.
  if (!private$.can_validate) return(NULL)
  ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
  # unlist() of an empty list yields NULL; normalize that to an empty named list.
  if (!length(ivs)) return(named_list())
  ivs
},
deep_clone = function(name, value) {
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
if (is.environment(value) && !is.null(value[[".__enclos_env__"]])) {
Expand All @@ -207,6 +261,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
},

.train = function(task) {
if (!is.null(get0("validate", self))) {
some_pipeops_validate = some(pos_with_property(self, "validation"), function(po) !is.null(po$validate))
if (!some_pipeops_validate) {
lg$warn("GraphLearner '%s' specifies a validation set, but none of its PipeOps use it.", self$id)
}
}

on.exit({self$graph$state = NULL})
self$graph$train(task)
state = self$graph$state
Expand Down Expand Up @@ -255,6 +316,109 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
)
)

#' @title Configure Validation for a GraphLearner
#'
#' @description
#' Configure validation for a graph learner.
#'
#' In a [`GraphLearner`], validation can be configured on two levels:
#' 1. On the [`GraphLearner`] level, which specifies **how** the validation set is constructed before entering the graph.
#' 2. On the level of the individual `PipeOp`s (such as `PipeOpLearner`), which specifies
#'    which pipeops actually make use of the validation data (set its `$validate` field to `"predefined"`) or not (set it to `NULL`).
#'    This can be specified via the argument `ids`.
#'
#' If configuring one of the `PipeOp`s fails, the function tries to restore the `$validate` fields
#' of the [`GraphLearner`] and of its `PipeOp`s to the values they had before the call.
#'
#' @param learner ([`GraphLearner`])\cr
#'   The graph learner to configure.
#' @param validate (`numeric(1)`, `"predefined"`, `"test"`, or `NULL`)\cr
#'   How to set the `$validate` field of the learner.
#'   If set to `NULL`, all validation is disabled, both on the graph learner level and for all pipeops.
#' @param ids (`NULL` or `character()`)\cr
#'   For which pipeops to enable validation.
#'   This parameter is ignored when `validate` is set to `NULL`.
#'   By default, validation is enabled for the final `PipeOp` in the `Graph`.
#' @param args_all (`list()`)\cr
#'   Rarely needed. A named list of parameter values that are passed to all subsequent [`set_validate()`] calls on the individual
#'   `PipeOp`s.
#' @param args (named `list()`)\cr
#'   Rarely needed.
#'   A named list of lists, specifying additional arguments to be passed to [`set_validate()`] when calling it on the individual
#'   `PipeOp`s. The names must be a subset of the selected `ids`.
#' @param ... (any)\cr
#'   Currently unused.
#'
#' @return The modified [`GraphLearner`], invisibly.
#'
#' @export
#' @examples
#' library(mlr3)
#'
#' glrn = as_learner(po("pca") %>>% lrn("classif.debug"))
#' set_validate(glrn, 0.3)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#'
#' set_validate(glrn, NULL)
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
#'
#' set_validate(glrn, 0.2, ids = "classif.debug")
#' glrn$validate
#' glrn$graph$pipeops$classif.debug$learner$validate
set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = list(), args = list(), ...) {
# Snapshot the current `$validate` value of every PipeOp with the "validation" property and of the
# GraphLearner itself, so that the rollback handler below can restore them.
prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
prev_validate = learner$validate
# Rollback handler: it stays registered until one of the bare `on.exit()` calls below clears it,
# so it only runs when the function exits early, i.e. on failure.
on.exit({
iwalk(prev_validate_pos, function(prev_val, poid) {
# The previous values are assigned directly instead of going through set_validate(): calling
# set_validate() again would not guarantee that the previous state is restored correctly,
# is less transparent, and might fail again.
# The calling handlers below tell the user about this reset attempt in the error message.
learner$graph$pipeops[[poid]]$validate = prev_val
})
learner$validate = prev_validate
}, add = TRUE)

# `validate = NULL`: switch validation off on the GraphLearner and on every PipeOp that has the
# "validation" property; `ids` is not consulted on this path.
if (is.null(validate)) {
learner$validate = NULL
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
})
# Success: clear the rollback handler so the new configuration is kept.
on.exit()
return(invisible(learner))
}

if (is.null(ids)) {
# Default target: the PipeOp returned by `base_learner(return_po = TRUE)`.
ids = learner$base_learner(return_po = TRUE)$id
} else {
# Explicitly given ids must refer to PipeOps that have the "validation" property.
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
}

assert_list(args, types = "list")
assert_list(args_all)
# Per-PipeOp arguments may only be given for PipeOps that are actually being configured.
assert_subset(names(args), ids)

learner$validate = validate

walk(ids, function(poid) {
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
withCallingHandlers({
# Arguments for the nested call: start from `validate = "predefined"`, then insert `args_all`,
# then the PipeOp-specific `args[[poid]]`, layered in that order via insert_named().
args = insert_named(insert_named(list(validate = "predefined"), args_all), args[[poid]])
invoke(set_validate, learner$graph$pipeops[[poid]], .args = args)
}, error = function(e) {
# Prefix the error with the failing PipeOp id and a note about the reset attempt, then re-raise it.
e$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, e$message)
stop(e)
}, warning = function(w) {
# Same for warnings: re-signal the rewritten warning and muffle the original one.
w$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, w$message)
warning(w)
invokeRestart("muffleWarning")
})
})
# Success: clear the rollback handler so the new configuration is kept.
on.exit()

invisible(learner)
}

#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
Expand Down
19 changes: 17 additions & 2 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,20 @@
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties of the pipeop.
#' Currently supported values are:
#' * `"validation"`: the `PipeOp` can make use of the `$internal_valid_task` of an [`mlr3::Task`].
#' This is for example used for `PipeOpLearner`s that wrap a `Learner` with this property, see [`mlr3::Learner`].
#' `PipeOp`s that have this property also have a `$validate` field, which controls whether to use the validation task,
#' as well as an `$internal_valid_scores` field, which allows access to the internal validation scores after training.
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters.
#' This works analogously to the internal tuning implementation for [`mlr3::Learner`].
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values` and have at least one
#' parameter tagged with `"internal_tuning"`.
#' An example for such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
#'
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
#'
#' @section Methods:
#' * `train(input)`\cr
Expand Down Expand Up @@ -235,8 +249,9 @@ PipeOp = R6Class("PipeOp",
output = NULL,
.result = NULL,
tags = NULL,
properties = NULL,

initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
initialize = function(id, param_set = ps(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
Expand All @@ -246,6 +261,7 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
Expand Down Expand Up @@ -601,4 +617,3 @@ evaluate_multiplicities = function(self, unpacked, evalcall, instate) {
map(transpose_list(map(result, "output")), as.Multiplicity)
}
}

4 changes: 4 additions & 0 deletions R/PipeOpImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,10 @@ PipeOpImpute = R6Class("PipeOpImpute",

self$state$outtasklayout = copy(intask$feature_types)

if (!is.null(intask$internal_valid_task)) {
intask$internal_valid_task = private$.predict(list(intask$internal_valid_task))[[1L]]
}

list(intask)
},

Expand Down
Loading