diff --git a/NAMESPACE b/NAMESPACE index 104deb9e5..52b995c7a 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -12,6 +12,8 @@ S3method(as_pipeop,Learner) S3method(as_pipeop,PipeOp) S3method(as_pipeop,default) S3method(marshal_model,graph_learner_model) +S3method(marshal_model,pipeop_learner_cv_state) +S3method(marshal_model,pipeop_learner_state) S3method(po,"NULL") S3method(po,Filter) S3method(po,Learner) @@ -25,6 +27,8 @@ S3method(predict,Graph) S3method(print,Multiplicity) S3method(print,Selector) S3method(unmarshal_model,graph_learner_model_marshaled) +S3method(unmarshal_model,pipeop_learner_cv_state_marshaled) +S3method(unmarshal_model,pipeop_learner_state_marshaled) export("%>>!%") export("%>>%") export(Graph) @@ -154,5 +158,6 @@ importFrom(data.table,as.data.table) importFrom(digest,digest) importFrom(stats,setNames) importFrom(utils,bibentry) +importFrom(utils,head) importFrom(utils,tail) importFrom(withr,with_options) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 16c2c2afb..e573c0cbe 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -51,10 +51,10 @@ #' Whether the learner is marshaled. Read-only. #' #' @section Methods: -#' * `marshal_model(...)`\cr +#' * `marshal(...)`\cr #' (any) -> `self`\cr #' Marshal the model. -#' * `unmarshal_model(...)`\cr +#' * `unmarshal(...)`\cr #' (any) -> `self`\cr #' Unmarshal the model. #' @@ -253,41 +253,24 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, ) ) -#' @title (Un-)Marshal GraphLearner Model -#' @name marshal_graph_learner -#' @description -#' (Un-)marshal the model of a [`GraphLearner`]. -#' @param model (model of [`GraphLearner`])\cr -#' The model to be marshaled. -#' @param ... (any)\cr -#' Currently unused. -#' @param inplace (`logical(1)`)\cr -#' Whether to marshal in-place. -#' If `FALSE` (default), all R6-objects are cloned. -#' @keywords internal #' @export marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) { - x = map(model, function(po_state) { - po_state$model = if (!is.null(po_state$model)) marshal_model(po_state$model, inplace = inplace, ...) - po_state - }) - if (!some(map(x, "model"), is_marshaled_model)) { - return(structure(x, class = c("graph_learner_model", "list"))) - } + xm = map(.x = model, .f = marshal_model, inplace = inplace, ...) + # if none of the states required any marshaling we return the model as-is + if (!some(xm, is_marshaled_model)) return(model) + structure(list( - marshaled = x, - packages = "mlr3pipelines" + marshaled = xm, + packages = "mlr3pipelines" ), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled")) } #' @export unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) { structure( - map(model$marshaled, function(po_state) { - po_state$model = if (!is.null(po_state$model)) unmarshal_model(po_state$model, inplace = inplace, ...) - po_state - } - ), class = c("graph_learner_model", "list")) + map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...), + class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "") + ) } #' @export diff --git a/R/PipeOpLearner.R b/R/PipeOpLearner.R index b35949e78..d56096ac9 100644 --- a/R/PipeOpLearner.R +++ b/R/PipeOpLearner.R @@ -139,7 +139,8 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, .train = function(inputs) { on.exit({private$.learner$state = NULL}) task = inputs[[1L]] - self$state = private$.learner$train(task)$state + learner_state = private$.learner$train(task)$state + self$state = structure(learner_state, class = c("pipeop_learner_state", class(learner_state))) list(NULL) }, @@ -154,4 +155,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp, ) ) +#' @export +marshal_model.pipeop_learner_state = function(model, inplace = FALSE, ...) { + # Note that a Learner state contains other objects with reference semantics, but we don't clone them here, even when inplace + # is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3 + # workhorse function + prev_class = class(model) + model$model = marshal_model(model$model, inplace = inplace) + # only wrap this in a marshaled class if the model was actually marshaled above + # (the default marshal method does nothing) + if (!is_marshaled_model(model$model)) return(model) + structure( + list(marshaled = model, packages = "mlr3pipelines"), + class = c(paste0(prev_class, "_marshaled"), "marshaled") + ) +} + +#' @export +unmarshal_model.pipeop_learner_state_marshaled = function(model, inplace = FALSE, ...) { + prev_class = head(class(model), n = -1) + state_marshaled = model$marshaled + state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace) + class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "") + state_marshaled +} + + mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new())) diff --git a/R/PipeOpLearnerCV.R b/R/PipeOpLearnerCV.R index 994e1045a..580911264 100644 --- a/R/PipeOpLearnerCV.R +++ b/R/PipeOpLearnerCV.R @@ -175,6 +175,11 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", } ), private = list( + .train = function(inputs) { + out = super$.train(inputs) + self$state = structure(self$state, class = c("pipeop_learner_cv_state", class(self$state))) + return(out) + }, .train_task = function(task) { on.exit({private$.learner$state = NULL}) @@ -222,4 +227,32 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV", ) ) +#' @export +marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) { + # Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace + # is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3 + # workhorse function + prev_class = class(model) + model$model = marshal_model(model$model, inplace = inplace) + # only wrap this in a marshaled class if the model was actually marshaled above + # (the default marshal method does nothing) + if (is_marshaled_model(model$model)) { + model = structure( + list(marshaled = model, packages = "mlr3pipelines"), + class = c(paste0(prev_class, "_marshaled"), "marshaled") + ) + } + model +} + +#' @export +unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) { + prev_class = head(class(model), n = -1) + state_marshaled = model$marshaled + state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace) + class(state_marshaled) = gsub(x = prev_class, pattern = "_marshaled$", replacement = "") + state_marshaled +} + + mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new())) diff --git a/R/zzz.R b/R/zzz.R index cf50d0dfe..c6054af1c 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -4,7 +4,7 @@ #' @import paradox #' @import mlr3misc #' @importFrom R6 R6Class -#' @importFrom utils tail +#' @importFrom utils tail head #' @importFrom digest digest #' @importFrom withr with_options #' @importFrom stats setNames diff --git a/man/marshal_graph_learner.Rd b/man/marshal_graph_learner.Rd deleted file mode 100644 index 111cc85c6..000000000 --- a/man/marshal_graph_learner.Rd +++ /dev/null @@ -1,24 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/GraphLearner.R -\name{marshal_graph_learner} -\alias{marshal_graph_learner} -\alias{marshal_model.graph_learner_model} -\title{(Un-)Marshal GraphLearner Model} -\usage{ -\method{marshal_model}{graph_learner_model}(model, inplace = FALSE, ...) -} -\arguments{ -\item{model}{(model of \code{\link{GraphLearner}})\cr -The model to be marshaled.} - -\item{inplace}{(\code{logical(1)})\cr -Whether to marshal in-place. -If \code{FALSE} (default), all R6-objects are cloned.} - -\item{...}{(any)\cr -Currently unused.} -} -\description{ -(Un-)marshal the model of a \code{\link{GraphLearner}}. -} -\keyword{internal} diff --git a/man/mlr_learners_graph.Rd b/man/mlr_learners_graph.Rd index 9d58bd56d..19a368e08 100644 --- a/man/mlr_learners_graph.Rd +++ b/man/mlr_learners_graph.Rd @@ -64,10 +64,10 @@ Whether the learner is marshaled. Read-only. \section{Methods}{ \itemize{ -\item \code{marshal_model(...)}\cr +\item \code{marshal(...)}\cr (any) -> \code{self}\cr Marshal the model. -\item \code{unmarshal_model(...)}\cr +\item \code{unmarshal(...)}\cr (any) -> \code{self}\cr Unmarshal the model. } diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index 18da2a98a..38afa699d 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -575,12 +575,14 @@ test_that("marshal", { glrn$train(task) glrn$marshal() expect_true(glrn$marshaled) - expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model)) + expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug)) glrn$unmarshal() - expect_false(is_marshaled_model(glrn$model)) + expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug)) expect_class(glrn$model, "graph_learner_model") expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model)) + glrn$predict(task) + # checks that it is marshalable glrn$train(task) expect_learner(glrn, task) diff --git a/tests/testthat/test_mlr_graphs_bagging.R b/tests/testthat/test_mlr_graphs_bagging.R index 15a70fb0c..3cd5bee95 100644 --- a/tests/testthat/test_mlr_graphs_bagging.R +++ b/tests/testthat/test_mlr_graphs_bagging.R @@ -1,6 +1,5 @@ context("ppl - pipeline_bagging") - test_that("Bagging Pipeline", { skip_on_cran() # takes too long diff --git a/tests/testthat/test_pipeop_learner.R b/tests/testthat/test_pipeop_learner.R index 223a5292d..f9cf28422 100644 --- a/tests/testthat/test_pipeop_learner.R +++ b/tests/testthat/test_pipeop_learner.R @@ -84,10 +84,31 @@ test_that("PipeOpLearner - model active binding to state", { }) test_that("packages", { - expect_set_equal( c("mlr3pipelines", lrn("classif.rpart")$packages), po("learner", learner = lrn("classif.rpart"))$packages ) +}) + +test_that("marshal", { + task = tsk("iris") + po_lrn = as_pipeop(lrn("classif.debug")) + po_lrn$train(list(task)) + po_state = po_lrn$state + expect_class(po_state, "pipeop_learner_state") + po_state_marshaled = marshal_model(po_state, inplace = FALSE) + expect_class(po_state_marshaled, "pipeop_learner_state_marshaled") + expect_true(is_marshaled_model(po_state_marshaled)) + expect_equal(po_state, unmarshal_model(po_state_marshaled)) +}) + +test_that("multiple marshal round-trips", { + task = tsk("iris") + glrn = as_learner(as_graph(lrn("classif.debug"))) + glrn$train(task) + glrn$marshal()$unmarshal()$marshal()$unmarshal() + expect_class(glrn$model, "graph_learner_model") + expect_class(glrn$model$classif.debug$model, "classif.debug_model") + expect_learner(glrn, task = task) }) diff --git a/tests/testthat/test_pipeop_learnercv.R b/tests/testthat/test_pipeop_learnercv.R index bd369987a..9ad62501c 100644 --- a/tests/testthat/test_pipeop_learnercv.R +++ b/tests/testthat/test_pipeop_learnercv.R @@ -120,3 +120,15 @@ test_that("predict_type", { lcv$train(list(tsk("iris")))[[1]]$feature_names) }) + +test_that("marshal", { + task = tsk("iris") + po_lrn = as_pipeop(po("learner_cv", learner = lrn("classif.debug"))) + po_lrn$train(list(task)) + po_state = po_lrn$state + expect_class(po_state, "pipeop_learner_cv_state") + po_state_marshaled = marshal_model(po_state, inplace = FALSE) + expect_class(po_state_marshaled, "pipeop_learner_cv_state_marshaled") + expect_true(is_marshaled_model(po_state_marshaled)) + expect_equal(po_state, unmarshal_model(po_state_marshaled)) +})