Skip to content

Commit

Permalink
feat: implement marshaling
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 10, 2024
1 parent 273d44b commit 11aa5ce
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 60 deletions.
5 changes: 5 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ S3method(as_pipeop,Learner)
S3method(as_pipeop,PipeOp)
S3method(as_pipeop,default)
S3method(marshal_model,graph_learner_model)
S3method(marshal_model,pipeop_learner_cv_state)
S3method(marshal_model,pipeop_learner_state)
S3method(po,"NULL")
S3method(po,Filter)
S3method(po,Learner)
Expand All @@ -25,6 +27,8 @@ S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
S3method(unmarshal_model,pipeop_learner_state_marshaled)
export("%>>!%")
export("%>>%")
export(Graph)
Expand Down Expand Up @@ -154,5 +158,6 @@ importFrom(data.table,as.data.table)
importFrom(digest,digest)
importFrom(stats,setNames)
importFrom(utils,bibentry)
importFrom(utils,head)
importFrom(utils,tail)
importFrom(withr,with_options)
39 changes: 11 additions & 28 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,10 @@
#' Whether the learner is marshaled. Read-only.
#'
#' @section Methods:
#' * `marshal_model(...)`\cr
#' * `marshal(...)`\cr
#' (any) -> `self`\cr
#' Marshal the model.
#' * `unmarshal_model(...)`\cr
#' * `unmarshal(...)`\cr
#' (any) -> `self`\cr
#' Unmarshal the model.
#'
Expand Down Expand Up @@ -253,41 +253,24 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
)
)

#' @title (Un-)Marshal GraphLearner Model
#' @name marshal_graph_learner
#' @description
#' (Un-)marshal the model of a [`GraphLearner`].
#' @param model (model of [`GraphLearner`])\cr
#' The model to be marshaled.
#' @param ... (any)\cr
#' Currently unused.
#' @param inplace (`logical(1)`)\cr
#' Whether to marshal in-place.
#' If `FALSE` (default), all R6-objects are cloned.
#' @keywords internal
#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
x = map(model, function(po_state) {
po_state$model = if (!is.null(po_state$model)) marshal_model(po_state$model, inplace = inplace, ...)
po_state
})
if (!some(map(x, "model"), is_marshaled_model)) {
return(structure(x, class = c("graph_learner_model", "list")))
}
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
# if none of the states required any marshaling we return the model as-is
if (!some(xm, is_marshaled_model)) return(model)

structure(list(
marshaled = x,
packages = "mlr3pipelines"
marshaled = xm,
packages = "mlr3pipelines"
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
structure(
map(model$marshaled, function(po_state) {
po_state$model = if (!is.null(po_state$model)) unmarshal_model(po_state$model, inplace = inplace, ...)
po_state
}
), class = c("graph_learner_model", "list"))
map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...),
class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "")
)
}

#' @export
Expand Down
29 changes: 28 additions & 1 deletion R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,8 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
.train = function(inputs) {
on.exit({private$.learner$state = NULL})
task = inputs[[1L]]
self$state = private$.learner$train(task)$state
learner_state = private$.learner$train(task)$state
self$state = structure(learner_state, class = c("pipeop_learner_state", class(learner_state)))

list(NULL)
},
Expand All @@ -154,4 +155,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
)
)

# Marshal the state of a trained PipeOpLearner.
#
# Only the `$model` slot of the state is marshaled; all other slots are kept as they are.
#
# Arguments:
#   model:   state object carrying the class "pipeop_learner_state".
#   inplace: passed on to the `marshal_model()` call for `model$model`.
#   ...:     passed on to the `marshal_model()` call for `model$model`.
#
# Returns `model` itself when nothing had to be marshaled. Otherwise returns a list with the
# elements `marshaled` (the state with its marshaled `$model`) and `packages`, whose class vector
# consists of every original class suffixed with "_marshaled", followed by "marshaled".
#' @export
marshal_model.pipeop_learner_state = function(model, inplace = FALSE, ...) {
  # Note that a Learner state contains other objects with reference semantics, but we don't clone
  # them here, even when inplace is FALSE. For our use-case this is just not necessary and would
  # cause unnecessary overhead in the mlr3 workhorse function

  # Nothing to marshal without a model. Returning early also avoids assigning NULL to
  # `model$model` below, which would delete the "model" element from the list.
  if (is.null(model$model)) return(model)

  prev_class = class(model)
  model$model = marshal_model(model$model, inplace = inplace, ...)
  # only wrap this in a marshaled class if the model was actually marshaled above
  # (the default marshal method does nothing)
  if (!is_marshaled_model(model$model)) return(model)
  structure(
    list(marshaled = model, packages = "mlr3pipelines"),
    class = c(paste0(prev_class, "_marshaled"), "marshaled")
  )
}

# Unmarshal a marshaled PipeOpLearner state.
#
# Arguments:
#   model:   object carrying the class "pipeop_learner_state_marshaled"; its `$marshaled` element
#            holds the state whose `$model` slot is marshaled.
#   inplace: passed on to the `unmarshal_model()` call for the contained model.
#   ...:     passed on to the `unmarshal_model()` call for the contained model.
#
# Returns the state with its `$model` slot unmarshaled and the "_marshaled" suffix removed from
# its classes.
#' @export
unmarshal_model.pipeop_learner_state_marshaled = function(model, inplace = FALSE, ...) {
  # The last class is the generic "marshaled" marker; the remaining ones are the classes of the
  # original state, each with a "_marshaled" suffix.
  prev_class = head(class(model), n = -1)
  state = model$marshaled
  state$model = unmarshal_model(state$model, inplace = inplace, ...)
  # The pattern is anchored at the end, so at most one match per class name is possible.
  class(state) = sub(pattern = "_marshaled$", replacement = "", x = prev_class)
  state
}


mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
33 changes: 33 additions & 0 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
}
),
private = list(
.train = function(inputs) {
  # Let the parent class do the training, then prepend "pipeop_learner_cv_state" to the class
  # of the state so that the S3 methods registered for that class are dispatched on it.
  result = super$.train(inputs)
  state = self$state
  self$state = structure(state, class = c("pipeop_learner_cv_state", class(state)))
  result
},
.train_task = function(task) {
on.exit({private$.learner$state = NULL})

Expand Down Expand Up @@ -222,4 +227,32 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
)
)

# Marshal the state of a trained PipeOpLearnerCV.
#
# Only the `$model` slot of the state is marshaled; all other slots are kept as they are.
#
# Arguments:
#   model:   state object carrying the class "pipeop_learner_cv_state".
#   inplace: passed on to the `marshal_model()` call for `model$model`.
#   ...:     passed on to the `marshal_model()` call for `model$model`.
#
# Returns `model` itself when nothing had to be marshaled. Otherwise returns a list with the
# elements `marshaled` (the state with its marshaled `$model`) and `packages`, whose class vector
# consists of every original class suffixed with "_marshaled", followed by "marshaled".
#' @export
marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
  # Note that a Learner state contains other reference objects, but we don't clone them here,
  # even when inplace is FALSE. For our use-case this is just not necessary and would cause
  # unnecessary overhead in the mlr3 workhorse function

  # Nothing to marshal without a model. Returning early also avoids assigning NULL to
  # `model$model` below, which would delete the "model" element from the list.
  if (is.null(model$model)) return(model)

  prev_class = class(model)
  model$model = marshal_model(model$model, inplace = inplace, ...)
  # only wrap this in a marshaled class if the model was actually marshaled above
  # (the default marshal method does nothing)
  if (is_marshaled_model(model$model)) {
    model = structure(
      list(marshaled = model, packages = "mlr3pipelines"),
      class = c(paste0(prev_class, "_marshaled"), "marshaled")
    )
  }
  model
}

# Unmarshal a marshaled PipeOpLearnerCV state.
#
# Arguments:
#   model:   object carrying the class "pipeop_learner_cv_state_marshaled"; its `$marshaled`
#            element holds the state whose `$model` slot is marshaled.
#   inplace: passed on to the `unmarshal_model()` call for the contained model.
#   ...:     passed on to the `unmarshal_model()` call for the contained model.
#
# Returns the state with its `$model` slot unmarshaled and the "_marshaled" suffix removed from
# its classes.
#' @export
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
  # The last class is the generic "marshaled" marker; the remaining ones are the classes of the
  # original state, each with a "_marshaled" suffix.
  prev_class = head(class(model), n = -1)
  state = model$marshaled
  state$model = unmarshal_model(state$model, inplace = inplace, ...)
  # The pattern is anchored at the end, so at most one match per class name is possible.
  class(state) = sub(pattern = "_marshaled$", replacement = "", x = prev_class)
  state
}


mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new()))
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @import paradox
#' @import mlr3misc
#' @importFrom R6 R6Class
#' @importFrom utils tail
#' @importFrom utils tail head
#' @importFrom digest digest
#' @importFrom withr with_options
#' @importFrom stats setNames
Expand Down
24 changes: 0 additions & 24 deletions man/marshal_graph_learner.Rd

This file was deleted.

4 changes: 2 additions & 2 deletions man/mlr_learners_graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,14 @@ test_that("marshal", {
glrn$train(task)
glrn$marshal()
expect_true(glrn$marshaled)
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
glrn$unmarshal()
expect_false(is_marshaled_model(glrn$model))
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
expect_class(glrn$model, "graph_learner_model")
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))

glrn$predict(task)

# checks that it is marshalable
glrn$train(task)
expect_learner(glrn, task)
Expand Down
1 change: 0 additions & 1 deletion tests/testthat/test_mlr_graphs_bagging.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
context("ppl - pipeline_bagging")


test_that("Bagging Pipeline", {
skip_on_cran() # takes too long

Expand Down
23 changes: 22 additions & 1 deletion tests/testthat/test_pipeop_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,31 @@ test_that("PipeOpLearner - model active binding to state", {
})

test_that("packages", {

expect_set_equal(
c("mlr3pipelines", lrn("classif.rpart")$packages),
po("learner", learner = lrn("classif.rpart"))$packages
)
})

test_that("marshal", {
  # train a PipeOpLearner around the debug learner and take its state
  pipeop = as_pipeop(lrn("classif.debug"))
  pipeop$train(list(tsk("iris")))
  state = pipeop$state
  expect_class(state, "pipeop_learner_state")

  # marshaling yields the "_marshaled" class and is recognized as marshaled
  marshaled = marshal_model(state, inplace = FALSE)
  expect_class(marshaled, "pipeop_learner_state_marshaled")
  expect_true(is_marshaled_model(marshaled))

  # unmarshaling gives back the original state
  expect_equal(state, unmarshal_model(marshaled))
})

test_that("multiple marshal round-trips", {
  iris_task = tsk("iris")
  graph_learner = as_learner(as_graph(lrn("classif.debug")))
  graph_learner$train(iris_task)

  # two full marshal / unmarshal cycles must leave the classes of the models intact
  graph_learner$marshal()$unmarshal()$marshal()$unmarshal()
  expect_class(graph_learner$model, "graph_learner_model")
  expect_class(graph_learner$model$classif.debug$model, "classif.debug_model")

  expect_learner(graph_learner, task = iris_task)
})
12 changes: 12 additions & 0 deletions tests/testthat/test_pipeop_learnercv.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,15 @@ test_that("predict_type", {
lcv$train(list(tsk("iris")))[[1]]$feature_names)

})

test_that("marshal", {
  # train a PipeOpLearnerCV around the debug learner and take its state
  pipeop = as_pipeop(po("learner_cv", learner = lrn("classif.debug")))
  pipeop$train(list(tsk("iris")))
  state = pipeop$state
  expect_class(state, "pipeop_learner_cv_state")

  # marshaling yields the "_marshaled" class and is recognized as marshaled
  marshaled = marshal_model(state, inplace = FALSE)
  expect_class(marshaled, "pipeop_learner_cv_state_marshaled")
  expect_true(is_marshaled_model(marshaled))

  # unmarshaling gives back the original state
  expect_equal(state, unmarshal_model(marshaled))
})

0 comments on commit 11aa5ce

Please sign in to comment.