diff --git a/DESCRIPTION b/DESCRIPTION index f70ca2cfb..585e76826 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -52,7 +52,7 @@ Imports: data.table, digest, lgr, - mlr3 (>= 0.6.0), + mlr3 (>= 0.19.0), mlr3misc (>= 0.9.0), paradox, R6, diff --git a/NAMESPACE b/NAMESPACE index 37876d740..74e7cb4cc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,10 @@ S3method(as_pipeop,Learner) S3method(as_pipeop,PipeOp) S3method(as_pipeop,default) S3method(disable_internal_tuning,GraphLearner) +S3method(marshal_model,Multiplicity) +S3method(marshal_model,graph_learner_model) +S3method(marshal_model,pipeop_impute_learner_state) +S3method(marshal_model,pipeop_learner_cv_state) S3method(po,"NULL") S3method(po,Filter) S3method(po,Learner) @@ -25,6 +29,10 @@ S3method(predict,Graph) S3method(print,Multiplicity) S3method(print,Selector) S3method(set_validate,GraphLearner) +S3method(unmarshal_model,Multiplicity_marshaled) +S3method(unmarshal_model,graph_learner_model_marshaled) +S3method(unmarshal_model,pipeop_impute_learner_state_marshaled) +S3method(unmarshal_model,pipeop_learner_cv_state_marshaled) export("%>>!%") export("%>>%") export(Graph) @@ -155,5 +163,6 @@ importFrom(data.table,as.data.table) importFrom(digest,digest) importFrom(stats,setNames) importFrom(utils,bibentry) +importFrom(utils,head) importFrom(utils,tail) importFrom(withr,with_options) diff --git a/NEWS.md b/NEWS.md index c68a2e96e..0e40259de 100644 --- a/NEWS.md +++ b/NEWS.md @@ -6,6 +6,8 @@ * Minor documentation fixes. * Test helpers are now available in `inst/`. These are considered experimental and unstable. +* Added marshaling support to `GraphLearner` + # mlr3pipelines 0.5.1 * Changed the ID of `PipeOpFeatureUnion` used in `ppl("robustify")` and `ppl("stacking")`. 
diff --git a/R/GraphLearner.R b/R/GraphLearner.R index f0c092628..db6fc895d 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -56,7 +56,16 @@ #' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr #' How to construct the validation data. This also has to be configured in the individual learners wrapped by #' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this. +#' * `marshaled` :: `logical(1)`\cr +#' Whether the learner is marshaled. #' +#' @section Methods: +#' * `marshal(...)`\cr +#' (any) -> `self`\cr +#' Marshal the model. +#' * `unmarshal(...)`\cr +#' (any) -> `self`\cr +#' Unmarshal the model. #' #' @section Internals: #' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is @@ -150,6 +159,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.") } learner_model$base_learner(recursive - 1) + }, + marshal = function(...) { + learner_marshal(.learner = self, ...) + }, + unmarshal = function(...) { + learner_unmarshal(.learner = self, ...) } ), active = list( @@ -169,7 +184,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, private$.validate = assert_validate(rhs) } private$.validate - + }, + marshaled = function() { + learner_marshaled(self) }, hash = function() { digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate, @@ -260,6 +277,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, on.exit({self$graph$state = NULL}) self$graph$train(task) state = self$graph$state + class(state) = c("graph_learner_model", class(state)) state }, .predict = function(task) { @@ -414,6 +432,27 @@ disable_internal_tuning.GraphLearner = function(learner, ids, ...) { } +#' @export +marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) { + xm = map(.x = model, .f = marshal_model, inplace = inplace, ...) 
+ # if none of the states required any marshaling we return the model as-is + if (!some(xm, is_marshaled_model)) return(model) + + structure(list( + marshaled = xm, + packages = "mlr3pipelines" + ), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled")) +} + +#' @export +unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) { + # need to re-create the class as it gets lost during marshaling + structure( + map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...), + class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "") + ) +} + #' @export as_learner.Graph = function(x, clone = FALSE, ...) { GraphLearner$new(x, clone_graph = clone) diff --git a/R/PipeOpImputeLearner.R b/R/PipeOpImputeLearner.R index ca8b730dc..e2e4e048c 100644 --- a/R/PipeOpImputeLearner.R +++ b/R/PipeOpImputeLearner.R @@ -44,6 +44,8 @@ #' for each column. If a column consists of missing values only during training, the `model` is `0` or the levels of the #' feature; these are used for sampling during prediction. #' +#' This state is given the class `"pipeop_impute_learner_state"`. +#' #' @section Parameters: #' The parameters are the parameters inherited from [`PipeOpImpute`], in addition to the parameters of the [`Learner`][mlr3::Learner] #' used for imputation. 
@@ -116,6 +118,13 @@ PipeOpImputeLearner = R6Class("PipeOpImputeLearner", ) super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals, whole_task_dependent = TRUE, feature_types = feature_types) + }, + train = function(inputs) { + outputs = super$train(inputs) + self$state = multiplicity_recurse(self$state, function(state) { + structure(state, class = c("pipeop_impute_learner_state", class(state))) + }) + return(outputs) } ), active = list( @@ -206,3 +215,25 @@ mlr_pipeops$add("imputelearner", PipeOpImputeLearner, list(R6Class("Learner", pu convert_to_task = function(id = "imputing", data, target, task_type, ...) { get(mlr_reflections$task_types[task_type, mult = "first"]$task)$new(id = id, backend = data, target = target, ...) } + +#' @export +marshal_model.pipeop_impute_learner_state = function(model, inplace = FALSE, ...) { + prev_class = class(model) + model$model = map(model$model, marshal_model, inplace = inplace, ...) + + if (!some(model$model, is_marshaled_model)) { + return(model) + } + + structure( + list(marshaled = model, packages = "mlr3pipelines"), + class = c(paste0(prev_class, "_marshaled"), "marshaled") + ) +} + +#' @export +unmarshal_model.pipeop_impute_learner_state_marshaled = function(model, inplace = FALSE, ...) { + state = model$marshaled + state$model = map(state$model, unmarshal_model, inplace = inplace, ...) 
+ return(state) +} diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index 4649363f8..5894dff94 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -94,8 +94,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals, input = data.table(name = "input", train = task_type, predict = task_type), output = data.table(name = "output", train = "NULL", predict = out_type), - tags = "learner", packages = learner$packages - ) + tags = "learner", packages = learner$packages) } ), active = list( diff --git a/R/PipeOpLearnerCV.R b/R/PipeOpLearnerCV.R index 8da4b1c12..ed7166c70 100644 --- a/R/PipeOpLearnerCV.R +++ b/R/PipeOpLearnerCV.R @@ -61,6 +61,8 @@ #' * `predict_time` :: `NULL` | `numeric(1)` #' Prediction time, in seconds. #' +#' This state is given the class `"pipeop_learner_cv_state"`. +#' #' @section Parameters: #' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object. #' Besides that, parameters introduced are: @@ -144,8 +146,14 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", # private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this. super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble")) + }, + train = function(inputs) { + outputs = super$train(inputs) + self$state = multiplicity_recurse(self$state, function(state) { + structure(state, class = c("pipeop_learner_cv_state", class(state))) + }) + return(outputs) } - ), active = list( learner = function(val) { @@ -224,4 +232,29 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", ) ) +#' @export +marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) 
{ + # Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace + # is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3 + # workhorse function + model$model = marshal_model(model$model, inplace = inplace) + # only wrap this in a marshaled class if the model was actually marshaled above + # (the default marshal method does nothing) + if (is_marshaled_model(model$model)) { + model = structure( + list(marshaled = model, packages = "mlr3pipelines"), + class = c(paste0(class(model), "_marshaled"), "marshaled") + ) + } + model +} + +#' @export +unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) { + state_marshaled = model$marshaled + state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace) + state_marshaled +} + + mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new())) diff --git a/R/PipeOpTaskPreproc.R b/R/PipeOpTaskPreproc.R index be829e85d..fac571710 100644 --- a/R/PipeOpTaskPreproc.R +++ b/R/PipeOpTaskPreproc.R @@ -187,8 +187,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc", super$initialize(id = id, param_set = param_set, param_vals = param_vals, input = data.table(name = "input", train = task_type, predict = task_type), output = data.table(name = "output", train = task_type, predict = task_type), - packages = packages, tags = c(tags, "data transform") - ) + packages = packages, tags = c(tags, "data transform")) } ), active = list( diff --git a/R/multiplicity.R b/R/multiplicity.R index 3ad25bb63..ec90f70a1 100644 --- a/R/multiplicity.R +++ b/R/multiplicity.R @@ -115,3 +115,16 @@ multiplicity_nests_deeper_than = function(x, cutoff) { } ret } + +#' @export +marshal_model.Multiplicity = function(model, inplace = FALSE, ...) 
{ + structure(list( + marshaled = multiplicity_recurse(model, marshal_model, inplace = inplace, ...), + packages = "mlr3pipelines" + ), class = c("Multiplicity_marshaled", "marshaled")) +} + +#' @export +unmarshal_model.Multiplicity_marshaled = function(model, inplace = FALSE, ...) { + multiplicity_recurse(model$marshaled, unmarshal_model, inplace = inplace, ...) +} diff --git a/R/zzz.R b/R/zzz.R index cf50d0dfe..c6054af1c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -4,7 +4,7 @@ #' @import paradox #' @import mlr3misc #' @importFrom R6 R6Class -#' @importFrom utils tail +#' @importFrom utils tail head #' @importFrom digest digest #' @importFrom withr with_options #' @importFrom stats setNames diff --git a/man/mlr_learners_avg.Rd b/man/mlr_learners_avg.Rd index afba4e11c..d139fa18e 100644 --- a/man/mlr_learners_avg.Rd +++ b/man/mlr_learners_avg.Rd @@ -44,7 +44,7 @@ and \code{"regr.mse"}, i.e. mean squared error for regression. \item \code{optimizer} :: \code{\link[bbotk:Optimizer]{Optimizer}} | \code{character(1)}\cr \code{\link[bbotk:Optimizer]{Optimizer}} used to find optimal thresholds. If \code{character}, converts to \code{\link[bbotk:Optimizer]{Optimizer}} -via \code{\link[bbotk:opt]{opt}}. Initialized to \code{\link[bbotk:OptimizerNLoptr]{OptimizerNLoptr}}. +via \code{\link[bbotk:opt]{opt}}. Initialized to \code{\link[bbotk:mlr_optimizers_nloptr]{OptimizerNLoptr}}. Nloptr hyperparameters are initialized to \code{xtol_rel = 1e-8}, \code{algorithm = "NLOPT_LN_COBYLA"} and equal initial weights for each learner. For more fine-grained control, it is recommended to supply a instantiated \code{\link[bbotk:Optimizer]{Optimizer}}. diff --git a/man/mlr_learners_graph.Rd b/man/mlr_learners_graph.Rd index 33ba86c5b..cc5676487 100644 --- a/man/mlr_learners_graph.Rd +++ b/man/mlr_learners_graph.Rd @@ -65,6 +65,20 @@ The internal tuned parameter values. 
\item \code{validate} :: \code{numeric(1)}, \code{"predefined"}, \code{"test"} or \code{NULL}\cr How to construct the validation data. This also has to be configured in the individual learners wrapped by \code{PipeOpLearner}, see \code{\link{set_validate.GraphLearner}} on how to configure this. +\item \code{marshaled} :: \code{logical(1)}\cr +Whether the learner is marshaled. +} +} + +\section{Methods}{ + +\itemize{ +\item \code{marshal(...)}\cr +(any) -> \code{self}\cr +Marshal the model. +\item \code{unmarshal(...)}\cr +(any) -> \code{self}\cr +Unmarshal the model. } } diff --git a/man/mlr_pipeops_imputelearner.Rd b/man/mlr_pipeops_imputelearner.Rd index 4a9e19f47..bfd74292d 100644 --- a/man/mlr_pipeops_imputelearner.Rd +++ b/man/mlr_pipeops_imputelearner.Rd @@ -52,6 +52,8 @@ The \verb{$state} is a named \code{list} with the \verb{$state} elements inherit The \verb{$state$models} is a named \code{list} of \code{models} created by the \code{\link[mlr3:Learner]{Learner}}'s \verb{$.train()} function for each column. If a column consists of missing values only during training, the \code{model} is \code{0} or the levels of the feature; these are used for sampling during prediction. + +This state is given the class \code{"pipeop_impute_learner_state"}. } \section{Parameters}{ diff --git a/man/mlr_pipeops_learner_cv.Rd b/man/mlr_pipeops_learner_cv.Rd index b3b7e203e..b66dbd246 100644 --- a/man/mlr_pipeops_learner_cv.Rd +++ b/man/mlr_pipeops_learner_cv.Rd @@ -71,6 +71,8 @@ Errors logged during prediction. \item \code{predict_time} :: \code{NULL} | \code{numeric(1)} Prediction time, in seconds. } + +This state is given the class \code{"pipeop_learner_cv_state"}. } \section{Parameters}{ diff --git a/man/mlr_pipeops_tunethreshold.Rd b/man/mlr_pipeops_tunethreshold.Rd index 3845c8f65..683c87f61 100644 --- a/man/mlr_pipeops_tunethreshold.Rd +++ b/man/mlr_pipeops_tunethreshold.Rd @@ -60,7 +60,7 @@ Initialized to \code{"classif.ce"}, i.e. misclassification error. 
\item \code{optimizer} :: \code{\link[bbotk:Optimizer]{Optimizer}}|\code{character(1)}\cr \code{\link[bbotk:Optimizer]{Optimizer}} used to find optimal thresholds. If \code{character}, converts to \code{\link[bbotk:Optimizer]{Optimizer}} -via \code{\link[bbotk:opt]{opt}}. Initialized to \code{\link[bbotk:OptimizerGenSA]{OptimizerGenSA}}. +via \code{\link[bbotk:opt]{opt}}. Initialized to \code{\link[bbotk:mlr_optimizers_gensa]{OptimizerGenSA}}. \item \code{log_level} :: \code{character(1)} | \code{integer(1)}\cr Set a temporary log-level for \code{lgr::get_logger("bbotk")}. Initialized to: "warn". } diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index bd82f360c..c8bce6cdb 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -612,7 +612,8 @@ test_that("internal_tuned_values", { expect_false("internal_tuning" %in% glrn1$properties) expect_equal(glrn1$internal_tuned_values, NULL) - # learner with internal tuning + # learner with + # internal tuning glrn2 = as_learner(as_graph(lrn("classif.debug"))) expect_true("internal_tuning" %in% glrn2$properties) expect_equal(glrn2$internal_tuned_values, NULL) @@ -661,3 +662,34 @@ test_that("set_validate", { expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$final$learner$validate, "predefined") expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$classif.debug$learner$validate, NULL) }) + +test_that("marshal", { + task = tsk("iris") + glrn = as_learner(as_graph(lrn("classif.debug"))) + glrn$train(task) + p1 = glrn$predict(task) + glrn$marshal() + expect_true(glrn$marshaled) + expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug)) + glrn$unmarshal() + expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug)) + expect_class(glrn$model, "graph_learner_model") + expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model)) + + p2 = glrn$predict(task) + expect_equal(p1$response, 
p2$response) + + # checks that it is marshalable + glrn$train(task) + expect_learner(glrn, task) +}) + +test_that("marshal has no effect when nothing needed marshaling", { + task = tsk("iris") + glrn = as_learner(as_graph(lrn("classif.rpart"))) + glrn$train(task) + glrn$marshal() + expect_class(glrn$marshal()$model, "graph_learner_model") + expect_class(glrn$unmarshal()$model, "graph_learner_model") + expect_learner(glrn, task = task) +}) diff --git a/tests/testthat/test_mlr_graphs_bagging.R b/tests/testthat/test_mlr_graphs_bagging.R index eda5df218..14836b705 100644 --- a/tests/testthat/test_mlr_graphs_bagging.R +++ b/tests/testthat/test_mlr_graphs_bagging.R @@ -1,6 +1,5 @@ context("ppl - pipeline_bagging") - test_that("Bagging Pipeline", { skip_if_not_installed("rpart") skip_on_cran() # takes too long diff --git a/tests/testthat/test_pipeop_impute.R b/tests/testthat/test_pipeop_impute.R index 5168346ab..9271b3697 100644 --- a/tests/testthat/test_pipeop_impute.R +++ b/tests/testthat/test_pipeop_impute.R @@ -401,7 +401,6 @@ test_that("More tests for Integers", { expect_false(any(is.na(result$data()$x)), info = po$id) expect_equal(result$missings(), c(t = 0, x = 0), info = po$id) } - }) test_that("impute, test rows and affect_columns", { diff --git a/tests/testthat/test_pipeop_imputelearner.R b/tests/testthat/test_pipeop_imputelearner.R index e0dd00e4c..cf8a37705 100644 --- a/tests/testthat/test_pipeop_imputelearner.R +++ b/tests/testthat/test_pipeop_imputelearner.R @@ -157,3 +157,17 @@ test_that("PipeOpImputeLearner - model active binding to state", { expect_equal(names(models), names(po$learner_models)) expect_true(all(pmap_lgl(list(map(models, .f = "model"), map(po$learner_models, .f = "model")), .f = all.equal))) }) + +test_that("marshal", { + task = tsk("penguins") + po_im = po("imputelearner", learner = lrn("classif.debug")) + po_im$train(list(task)) + + s = po_im$state + expect_class(s, "pipeop_impute_learner_state") + sm = marshal_model(s) + expect_class(sm, 
"marshaled") + su = unmarshal_model(sm) + expect_equal(s, su) +}) + diff --git a/tests/testthat/test_pipeop_learner.R b/tests/testthat/test_pipeop_learner.R index 6575081b8..9b45b707d 100644 --- a/tests/testthat/test_pipeop_learner.R +++ b/tests/testthat/test_pipeop_learner.R @@ -92,5 +92,82 @@ test_that("packages", { c("mlr3pipelines", lrn("classif.rpart")$packages), po("learner", learner = lrn("classif.rpart"))$packages ) +}) + +test_that("marshal", { + task = tsk("iris") + po_lrn = as_pipeop(lrn("classif.debug")) + po_lrn$train(list(task)) + po_state = po_lrn$state + expect_class(po_state, "learner_state") + po_state_marshaled = marshal_model(po_state, inplace = FALSE) + expect_class(po_state_marshaled, "learner_state_marshaled") + expect_true(is_marshaled_model(po_state_marshaled)) + expect_equal(po_state, unmarshal_model(po_state_marshaled)) +}) + +test_that("multiple marshal round-trips", { + task = tsk("iris") + glrn = as_learner(as_graph(lrn("classif.debug"))) + glrn$train(task) + glrn$marshal()$unmarshal()$marshal()$unmarshal() + expect_class(glrn$model, "graph_learner_model") + expect_class(glrn$model$classif.debug$model, "classif.debug_model") + + expect_learner(glrn, task = task) +}) + +test_that("marshal multiplicity", { + po = as_pipeop(lrn("classif.debug")) + po$train(list(Multiplicity(tsk("iris"), tsk("sonar")))) + s = po$state + sm = marshal_model(po$state) + expect_class(po$state, "Multiplicity") + expect_true(is_marshaled_model(sm$marshaled[[1L]])) + expect_true(is_marshaled_model(sm$marshaled[[2L]])) + + su = unmarshal_model(sm) + expect_equal(su, s) + + # recursive + po = as_pipeop(lrn("classif.debug")) + po$train(list(Multiplicity(Multiplicity(tsk("iris"))))) + p1 = po$predict(list(Multiplicity(Multiplicity(tsk("iris"))))) + + s = po$state + sm = marshal_model(po$state) + expect_class(po$state, "Multiplicity") + expect_true(is_marshaled_model(sm$marshaled[[1L]][[1L]])) + + su = unmarshal_model(sm) + expect_equal(su, s) + + po$state = su + p2 = 
po$predict(list(Multiplicity(Multiplicity(tsk("iris"))))) + expect_equal(p1, p2) + + task = tsk("iris") + glrn = as_learner(as_pipeop(lrn("classif.debug"))) + expect_learner(glrn, task) + p1 = glrn$train(task)$predict(task) + s1 = glrn$state + glrn$marshal()$unmarshal() + s2 = glrn$state + p2 = glrn$predict(task) + expect_equal(p1, p2) + expect_equal(s1, s2) +}) +test_that("state class and multiplicity", { + po = as_pipeop(lrn("classif.debug")) + po$train(list(Multiplicity(tsk("iris")))) + expect_class(po$state, "Multiplicity") + expect_class(po$state[[1L]], "learner_state") + + # recursive + po1 = as_pipeop(lrn("classif.debug")) + po1$train(list(Multiplicity(Multiplicity(tsk("iris"))))) + expect_class(po1$state, "Multiplicity") + expect_class(po1$state[[1L]], "Multiplicity") + expect_class(po1$state[[1L]][[1L]], "learner_state") }) diff --git a/tests/testthat/test_pipeop_learnercv.R b/tests/testthat/test_pipeop_learnercv.R index 956e3f308..d09756ae0 100644 --- a/tests/testthat/test_pipeop_learnercv.R +++ b/tests/testthat/test_pipeop_learnercv.R @@ -54,7 +54,8 @@ test_that("PipeOpLearnerCV - within resampling", { skip_if_not_installed("rpart") lrn = mlr_learners$get("classif.rpart") gr = GraphLearner$new(PipeOpLearnerCV$new(lrn) %>>% po(id = "l2", lrn)) - resample(tsk("iris"), gr, rsmp("holdout")) + rr = resample(tsk("iris"), gr, rsmp("holdout")) + expect_class(rr, "ResampleResult") }) test_that("PipeOpLearnerCV - insample resampling", { @@ -125,3 +126,81 @@ test_that("predict_type", { lcv$train(list(tsk("iris")))[[1]]$feature_names) }) + +test_that("marshal", { + task = tsk("iris") + po_lrn = as_pipeop(po("learner_cv", learner = lrn("classif.debug"))) + po_lrn$train(list(task)) + po_state = po_lrn$state + expect_class(po_state, "pipeop_learner_cv_state") + po_state_marshaled = marshal_model(po_state, inplace = FALSE) + expect_class(po_state_marshaled, "pipeop_learner_cv_state_marshaled") + expect_true(is_marshaled_model(po_state_marshaled)) + 
expect_equal(po_state, unmarshal_model(po_state_marshaled)) +}) + +test_that("marshal multiplicity", { + po = po("learner_cv", learner = lrn("classif.debug")) + po$train(list(Multiplicity(tsk("iris"), tsk("sonar")))) + s = po$state + sm = marshal_model(po$state) + expect_class(po$state, "Multiplicity") + expect_true(is_marshaled_model(sm$marshaled[[1L]])) + expect_true(is_marshaled_model(sm$marshaled[[2L]])) + + su = unmarshal_model(sm) + expect_equal(su, s) + + # recursive + po = po("learner_cv", learner = lrn("classif.debug")) + po$train(list(Multiplicity(Multiplicity(tsk("iris"))))) + p1 = po$predict(list(Multiplicity(Multiplicity(tsk("iris"))))) + + s = po$state + sm = marshal_model(po$state) + expect_class(po$state, "Multiplicity") + expect_true(is_marshaled_model(sm$marshaled[[1L]][[1L]])) + + su = unmarshal_model(sm) + expect_equal(su, s) + + po$state = su + p2 = po$predict(list(Multiplicity(Multiplicity(tsk("iris"))))) + expect_equal(p1, p2) + + + task = tsk("iris") + learner = lrn("classif.debug") + + lrncv_po = po("learner_cv", learner) + lrncv_po$learner$predict_type = "response" + + nop = mlr_pipeops$get("nop") + + graph = gunion(list( + lrncv_po, + nop + )) %>>% po("featureunion") %>>% lrn("classif.rpart") + + glrn = as_learner(graph) + expect_learner(glrn, task) + + p1 = glrn$train(task)$predict(task) + p2 = glrn$marshal()$unmarshal()$predict(task) + expect_equal(p1, p2) + +}) + +test_that("state class and multiplicity", { + po = po("learner_cv", learner = lrn("classif.debug")) + po$train(list(Multiplicity(tsk("iris")))) + expect_class(po$state, "Multiplicity") + expect_class(po$state[[1L]], "pipeop_learner_cv_state") + + # recursive + po1 = po("learner_cv", learner = lrn("classif.debug")) + po1$train(list(Multiplicity(Multiplicity(tsk("iris"))))) + expect_class(po1$state, "Multiplicity") + expect_class(po1$state[[1L]], "Multiplicity") + expect_class(po1$state[[1L]][[1L]], "pipeop_learner_cv_state") +})