Skip to content

Commit

Permalink
tests hopefully pass
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Jun 26, 2024
1 parent 4c06919 commit 71003d7
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 68 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(set_validate,GraphLearner)
S3method(set_validate,PipeOpLearner)
S3method(unmarshal_model,Multiplicity_marshaled)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
Expand Down
51 changes: 28 additions & 23 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,11 @@
#' The internal tuned parameter values.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal tuning.
#' * `internal_valid_scores` :: named `list()` or `NULL`\cr
#' The internal validation scores as retrieved from the `PipeOp`s.
#' The names are prefixed with the respective IDs of the `PipeOp`s.
#' `NULL` is returned if the learner is not trained or none of the wrapped learners supports internal validation.
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
#' How to construct the validation data. This also has to be configured in the individual `PipeOp`s such as
#' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
#' For more details on the possible values, see [`mlr3::Learner`].
#' * `marshaled` :: `logical(1)`\cr
Expand Down Expand Up @@ -121,7 +122,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
assert_subset(task_type, mlr_reflections$task_types$type)

private$.can_validate = some(graph$pipeops, function(po) "validation" %in% po$properties)
private$.can_validate = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)
private$.can_internal_tuning = some(graph$pipeops, function(po) "internal_tuning" %in% po$properties)

properties = setdiff(mlr_reflections$learner_properties[[task_type]],
c("validation", "internal_tuning")[!c(private$.can_validate, private$.can_internal_tuning)])
Expand All @@ -139,11 +140,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
if (!is.null(predict_type)) self$predict_type = predict_type
},
base_learner = function(recursive = Inf) {
self$base_pipeop(recursive = recursive)$learner_model
},
base_pipeop = function(recursive = Inf) {
base_learner = function(recursive = Inf, return_po = FALSE) {
assert(check_numeric(recursive, lower = Inf), check_int(recursive))
assert_flag(return_po)
if (recursive <= 0) return(self)
gm = self$graph_model
gm_output = gm$output
Expand All @@ -162,7 +161,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_pipeop(recursive - 1)
if (return_po) {
last_pipeop
} else {
learner_model$base_learner(recursive - 1)
}
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
Expand Down Expand Up @@ -236,13 +239,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
.can_internal_tuning = NULL,
.extract_internal_tuned_values = function() {
  # Collects the `$internal_tuned_values` of all PipeOps that have the "internal_tuning" property.
  # Returns NULL if no PipeOp of the graph has that property, otherwise a (possibly empty) named list.
  if (!private$.can_internal_tuning) return(NULL)
  # query the PipeOps of `$graph_model` (not of `self`, for which `pos_with_property()` would use `$graph`)
  itvs = unlist(map(pos_with_property(self$graph_model, "internal_tuning"), "internal_tuned_values"), recursive = FALSE)
  # unlist() yields NULL when nothing was collected; normalize to an empty named list
  if (!length(itvs)) return(named_list())
  itvs
},
.extract_internal_valid_scores = function() {
  # Collects the `$internal_valid_scores` of all PipeOps that have the "validation" property.
  # Returns NULL if no PipeOp of the graph has that property, otherwise a (possibly empty) named list.
  if (!private$.can_validate) return(NULL)
  # query the PipeOps of `$graph_model` (not of `self`, for which `pos_with_property()` would use `$graph`)
  ivs = unlist(map(pos_with_property(self$graph_model, "validation"), "internal_valid_scores"), recursive = FALSE)
  # length(NULL) is 0, so this also covers the case where unlist() returned NULL
  if (!length(ivs)) return(named_list())
  ivs
},
Expand Down Expand Up @@ -367,30 +370,28 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
if (is.null(validate)) {
learner$validate = NULL
walk(pos_with_property(learner$graph$pipeops, "validation"), function(po) {
# disabling needs no extra arguments
invoke(set_validate, po, validate = NULL, args_all = args_all, args = args[[po$id]] %??% list())
})
return(invisible(learner))
}

if (is.null(ids)) {
ids = learner$base_pipeop(recursive = 1)$id
ids = learner$base_learner(recursive = 1, return_po = TRUE)$id
} else {
assert_subset(ids, ids(pos_with_property(learner$graph$pipeops, "validation")))
}

assert_list(args, types = "list")
assert_list(args_all, types = "list")
assert_list(args_all)
assert_subset(names(args), ids)

prev_validate_pos = map(pos_with_property(learner$graph$pipeops, "validation"), "validate")
prev_validate = learner$validate
on.exit({
iwalk(prev_validate_pos, function(val, poid) {
# passing the args here is just a heuristic that can in principle fail, but this should be extremely
# rare
args = args[[poid]] %??% list()
set_validate(learner$graph$pipeops[[poid]], validate = val, args = args, args_all = args_all)
iwalk(prev_validate_pos, function(prev_val, poid) {
# Here we don't call into set_validate() as this also does not ensure that we are able to correctly
# reset the configuration to the previous state (e.g. for AutoTuner) and is less transparent
learner$graph$pipeops[[poid]]$validate = prev_val
})
learner$validate = prev_validate
}, add = TRUE)
Expand All @@ -400,13 +401,17 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args_all = l
walk(ids, function(poid) {
# learner might be another GraphLearner / AutoTuner so we call into set_validate() again
withCallingHandlers({
args = c(args[[poid]], args_all) %??% list()
set_validate(learner$graph$pipeops[[poid]], .args = insert_named(list(validate = "predefined"), args))
args = insert_named(c(list(validate = "predefined"), args_all), args[[poid]])
invoke(set_validate, learner$graph$pipeops[[poid]], .args = args)
}, error = function(e) {
e$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", poid, e$message)
e$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, e$message)
stop(e)
}, warning = function(w) {
w$message = sprintf("Failed to set validate for PipeOp '%s':\n%s", po$id, w$message)
w$message = sprintf(paste0(
"Failed to set validate for PipeOp '%s':\n%s\n",
"Trying to heuristically reset validation to its previous state, please check the results"), poid, w$message)
warning(w)
invokeRestart("muffleWarning")
})
Expand Down Expand Up @@ -487,4 +492,4 @@ infer_task_type = function(graph) {
task_type = get_po_task_type(graph$pipeops[[graph$rhs]])
}
c(task_type, "classif")[[1]] # "classif" as final fallback
}
}
3 changes: 2 additions & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,8 @@
#' as well as a `$internal_valid_scores` field, which allows to access the internal validation scores after training.
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters, see [`mlr3::Learner`] for an explanation.
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values`.
#'
#' An example for such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
#'
#' Programmatic access to all available properties is possible via `mlr_reflections$pipeops$properties`.
#'
#' @section Methods:
Expand Down
52 changes: 39 additions & 13 deletions R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,12 +97,18 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
initialize = function(learner, id = NULL, param_vals = list()) {
private$.learner = as_learner(learner, clone = TRUE)
id = id %??% private$.learner$id
if (!test_po_validate(get0("validate", private$.learner))) {
stopf(paste0(
"Validate field of PipeOp '%s' must either be NULL or 'predefined'.\nTo configure how ",
"the validation data is created, set the $validate field of the GraphLearner, e.g. using set_validate()."
), id) # nolint
}
# FIXME: can be changed when mlr-org/mlr3#470 has an answer
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
properties = c("validation", "internal_tuning")
properties = properties[properties %in% learner$properties]
properties = properties[properties %in% private$.learner$properties]
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = "NULL", predict = out_type),
Expand All @@ -112,13 +118,13 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
active = list(
internal_tuned_values = function(rhs) {
assert_ro_binding(rhs)
if ("validate" %nin% self$properties) return(NULL)
self$learner$internal_tuned_values
if ("internal_tuning" %nin% self$properties) return(NULL)
self$learner_model$internal_tuned_values
},
internal_valid_scores = function(rhs) {
assert_ro_binding(rhs)
if ("internal_tuning" %nin% self$properties) return(NULL)
self$learner$internal_valid_scores
if ("validation" %nin% self$properties) return(NULL)
self$learner_model$internal_valid_scores
},
validate = function(rhs) {
if ("validation" %nin% self$properties) {
Expand All @@ -128,8 +134,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
return(NULL)
}
if (!missing(rhs)) {
private$.validate = assert_po_validate(rhs)
self$learner$validate = rhs
private$.learner$validate = assert_po_validate(rhs)
}
private$.learner$validate
},
Expand All @@ -147,6 +152,14 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
if (!identical(val, private$.learner)) {
stop("$learner is read-only.")
}
validate = get0("validate", private$.learner)
if (!test_po_validate(validate)) {
warningf(paste(sep = "\n",
"PipeOpLearner '%s' has its validate field set to a value that is neither NULL nor 'predefined'.",
"This will likely lead to unexpected behaviour.",
"Configure the $validate field of the GraphLearner to define how the validation data is created."
), self$id)
}
}
private$.learner
},
Expand All @@ -172,7 +185,6 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
),
private = list(
.learner = NULL,
.validate = NULL,

.train = function(inputs) {
on.exit({private$.learner$state = NULL})
Expand All @@ -192,16 +204,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
)
)

mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(properties = character(0), id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new())) # nolint

#' @export
set_validate.PipeOpLearner = function(learner, validate, ...) {
  # Sets the `$validate` field of a PipeOpLearner and forwards the configuration (including `...`) to the
  # wrapped Learner. `validate` must be NULL or "predefined". Returns `learner` invisibly.
  assert_po_validate(validate)
  # capture the previous value before registering the handler that restores it
  prev_validate = learner$validate
  on.exit({
    # also does not work in general (e.g. for AutoTuner) and is even less transparent
    learner$validate = prev_validate
  })
  learner$validate = validate
  withCallingHandlers({
    # dispatch on the wrapped Learner; calling set_validate() on `learner` itself would re-enter this method
    set_validate(learner$learner, validate = validate, ...)
  }, error = function(e) {
    e$message = sprintf(paste0(
      "Failed to set validate for Learner '%s':\n%s\n",
      "Trying to heuristically reset validation to its previous state, please check the results"), learner$id, e$message)
    stop(e)
  }, warning = function(w) {
    w$message = sprintf(paste0(
      "Failed to set validate for PipeOp '%s':\n%s\n",
      "Trying to heuristically reset validation to its previous state, please check the results"), learner$id, w$message)
    warning(w)
    invokeRestart("muffleWarning")
  })
  # success: drop the reset handler so that the new configuration is kept
  on.exit()
  learner$validate = validate
  invisible(learner)
}
28 changes: 4 additions & 24 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -147,30 +147,6 @@ learner_wrapping_pipeops = function(x) {
keep(x, function(po) inherits(po, "PipeOpLearner") || inherits(po, "PipeOpLearnerCV"))
}


# Get the last PipeOp of a GraphLearner's graph model that exposes a non-NULL `$learner_model`.
# Starts at the graph's unique output PipeOp and walks backwards along the edges.
# Errors if the graph has no unique output, if the walk reaches a PipeOp with more than one predecessor,
# or if no such PipeOp is found. Returns the PipeOp itself, not its learner model.
base_pipeop = function(self) {
  gm = self$graph_model
  gm_output = gm$output
  if (nrow(gm_output) != 1) stop("Graph has no unique output.")
  last_pipeop_id = gm_output$op.id

  # pacify static checks
  src_id = NULL
  dst_id = NULL

  repeat {
    last_pipeop = gm$pipeops[[last_pipeop_id]]
    learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
    if (!is.null(learner_model)) break
    # select the `src_id` column so that we get the vector of predecessor IDs; without it the result is
    # the whole edge table and the length checks below would not count predecessors
    last_pipeop_id = gm$edges[dst_id == last_pipeop_id, src_id]
    if (length(last_pipeop_id) > 1) stop("Graph has no unique PipeOp containing a Learner")
    if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
  }
  # New movie idea: "The Last PipeOp"
  last_pipeop
}

pos_with_property = function(x, property) {
x = if (test_class(x, "GraphLearner")) {
x$graph$pipeops
Expand All @@ -185,3 +161,7 @@ pos_with_property = function(x, property) {
# Asserts that `rhs` is a valid `$validate` value for a PipeOp: either NULL or the string "predefined".
# Delegates to assert_choice() and returns its result.
assert_po_validate = function(rhs) {
  assert_choice(rhs, choices = "predefined", null.ok = TRUE)
}

# Non-throwing counterpart of assert_po_validate(): checks whether `x` is NULL or the string "predefined".
# Delegates to test_choice() and returns its result.
test_po_validate = function(x) {
  test_choice(x, choices = "predefined", null.ok = TRUE)
}
13 changes: 13 additions & 0 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/mlr_learners_graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions man/mlr_pipeops_learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion man/set_validate.GraphLearner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 71003d7

Please sign in to comment.