Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 10, 2024
1 parent dbc68f7 commit 54247a3
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 20 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ S3method(as_pipeop,Learner)
S3method(as_pipeop,PipeOp)
S3method(as_pipeop,default)
S3method(marshal_model,graph_learner_model)
S3method(marshal_model,pipeop_learner_cv_state)
S3method(marshal_model,pipeop_learner_state)
S3method(po,"NULL")
S3method(po,Filter)
S3method(po,Learner)
Expand All @@ -25,6 +27,8 @@ S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
S3method(unmarshal_model,pipeop_learner_state_marshaled)
export("%>>!%")
export("%>>%")
export(Graph)
Expand Down
20 changes: 6 additions & 14 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -263,31 +263,23 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
#' Currently unused.
#' @param inplace (`logical(1)`)\cr
#' Whether to marshal in-place.
#' If `FALSE` (default), all R6-objects are cloned.
#' @keywords internal
#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
x = map(model, function(po_state) {
po_state$model = if (!is.null(po_state$model)) marshal_model(po_state$model, inplace = inplace, ...)
po_state
})
if (!some(map(x, "model"), is_marshaled_model)) {
x = map(.x = model, .f = marshal_model, inplace = inplace, ...)
if (!some(x, is_marshaled_model)) {
return(structure(x, class = c("graph_learner_model", "list")))
}
structure(list(
marshaled = x,
packages = "mlr3pipelines"
marshaled = x,
packages = "mlr3pipelines"
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
structure(
map(model$marshaled, function(po_state) {
po_state$model = if (!is.null(po_state$model)) unmarshal_model(po_state$model, inplace = inplace, ...)
po_state
}
), class = c("graph_learner_model", "list"))
unmarshaled = map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...)
structure(unmarshaled, class = c("graph_learner_model", "list"))
}

#' @export
Expand Down
28 changes: 27 additions & 1 deletion R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
.train = function(inputs) {
on.exit({private$.learner$state = NULL})
task = inputs[[1L]]
self$state = private$.learner$train(task)$state
self$state = structure(private$.learner$train(task)$state, class = c("pipeop_learner_state", "list"))

list(NULL)
},
Expand All @@ -154,4 +154,30 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
)
)

#' @export
marshal_model.pipeop_learner_state = function(model, inplace = FALSE, ...) {
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
# workhorse function
model$model = marshal_model(model$model, inplace = inplace)
# only wrap this in a marshaled class if the model was actually marshaled above
# (the default marshal method does nothing)
if (is_marshaled_model(model$model)) {
model = structure(
list(marshaled = model, packages = "mlr3pipelines"),
class = c("pipeop_learner_state_marshaled", "list_marshaled", "marshaled")
)
}
model
}

#' @export
unmarshal_model.pipeop_learner_state_marshaled = function(model, inplace = FALSE, ...) {
state_marshaled = model$marshaled
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
class(state_marshaled) = c("pipeop_learner_state", "list")
state_marshaled
}


mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ps(), packages = "mlr3pipelines"))$new()))
31 changes: 31 additions & 0 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
}
),
private = list(
.train = function(inputs) {
out = super$.train(inputs)
self$state = structure(self$state, class = c("pipeop_learner_cv_state", "list"))
return(out)
},
.train_task = function(task) {
on.exit({private$.learner$state = NULL})

Expand Down Expand Up @@ -222,4 +227,30 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
)
)

#' @export
marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
# workhorse function
model$model = marshal_model(model$model, inplace = inplace)
# only wrap this in a marshaled class if the model was actually marshaled above
# (the default marshal method does nothing)
if (is_marshaled_model(model$model)) {
model = structure(
list(marshaled = model, packages = "mlr3pipelines"),
class = c("pipeop_learner_cv_state_marshaled", "list_marshaled", "marshaled")
)
}
model
}

#' @export
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
state_marshaled = model$marshaled
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
class(state_marshaled) = c("pipeop_learner_cv_state", "list")
state_marshaled
}


mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new()))
3 changes: 1 addition & 2 deletions man/marshal_graph_learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -575,12 +575,14 @@ test_that("marshal", {
glrn$train(task)
glrn$marshal()
expect_true(glrn$marshaled)
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
glrn$unmarshal()
expect_false(is_marshaled_model(glrn$model))
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
expect_class(glrn$model, "graph_learner_model")
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))

glrn$predict(task)

# checks that it is marshalable
glrn$train(task)
expect_learner(glrn, task)
Expand Down
23 changes: 22 additions & 1 deletion tests/testthat/test_pipeop_learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,31 @@ test_that("PipeOpLearner - model active binding to state", {
})

test_that("packages", {

expect_set_equal(
c("mlr3pipelines", lrn("classif.rpart")$packages),
po("learner", learner = lrn("classif.rpart"))$packages
)
})

test_that("marshal", {
task = tsk("iris")
po_lrn = as_pipeop(lrn("classif.debug"))
po_lrn$train(list(task))
po_state = po_lrn$state
expect_class(po_state, "pipeop_learner_state")
po_state_marshaled = marshal_model(po_state, inplace = FALSE)
expect_class(po_state_marshaled, "pipeop_learner_state_marshaled")
expect_true(is_marshaled_model(po_state_marshaled))
expect_equal(po_state, unmarshal_model(po_state_marshaled))
})

test_that("multiple marshal round-trips", {
task = tsk("iris")
glrn = as_learner(as_graph(lrn("classif.debug")))
glrn$train(task)
glrn$marshal()$unmarshal()$marshal()$unmarshal()
expect_class(glrn$model, "graph_learner_model")
expect_class(glrn$model$classif.debug$model, "classif.debug_model")

expect_learner(glrn, task = task)
})
12 changes: 12 additions & 0 deletions tests/testthat/test_pipeop_learnercv.R
Original file line number Diff line number Diff line change
Expand Up @@ -120,3 +120,15 @@ test_that("predict_type", {
lcv$train(list(tsk("iris")))[[1]]$feature_names)

})

test_that("marshal", {
task = tsk("iris")
po_lrn = as_pipeop(po("learner_cv", learner = lrn("classif.debug")))
po_lrn$train(list(task))
po_state = po_lrn$state
expect_class(po_state, "pipeop_learner_cv_state")
po_state_marshaled = marshal_model(po_state, inplace = FALSE)
expect_class(po_state_marshaled, "pipeop_learner_cv_state_marshaled")
expect_true(is_marshaled_model(po_state_marshaled))
expect_equal(po_state, unmarshal_model(po_state_marshaled))
})

0 comments on commit 54247a3

Please sign in to comment.