Skip to content

Commit

Permalink
add marshaling
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Apr 9, 2024
1 parent 5ea54e1 commit 030b2b3
Show file tree
Hide file tree
Showing 11 changed files with 46 additions and 60 deletions.
6 changes: 0 additions & 6 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@
#' Whether to store intermediate results in the [`PipeOp`]'s `$.result` slot, mostly for debugging purposes. Default `FALSE`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties of the `Graph` is the union of all the properties of its [`PipeOp`]s.
#'
#' @section Methods:
#' * `ids(sorted = FALSE)` \cr
Expand Down Expand Up @@ -506,10 +504,6 @@ Graph = R6Class("Graph",
} else {
map(self$pipeops, "state")
}
},
properties = function(rhs) {
assert_ro_binding(rhs)
sort(unique(unlist(map(self$pipeops, "properties"))))
}
),

Expand Down
36 changes: 17 additions & 19 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,11 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
assert_subset(task_type, mlr_reflections$task_types$type)

properties = mlr_reflections$learner_properties[[task_type]]

super$initialize(id = id, task_type = task_type,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = properties,
properties = mlr_reflections$learner_properties[[task_type]],
man = "mlr3pipelines::GraphLearner"
)

Expand Down Expand Up @@ -258,7 +256,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
#' @title (Un-)Marshal GraphLearner Model
#' @name marshal_graph_learner
#' @description
#' (Un-) marshal the model of a [`GraphLearner`].
#' (Un-)marshal the model of a [`GraphLearner`].
#' @param model (model of [`GraphLearner`])\cr
#' The model to be marshaled.
#' @param ... (any)\cr
Expand All @@ -268,27 +266,27 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
#' If `FALSE` (default), all R6-objects are cloned.
#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
if (inplace) {
model_marhaled = structure(list(
marshaled = map(model, marshal_model, inplace = TRUE),
packages = "mlr3pipelines"
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))

return(model_marshaled)
x = map(model, function(po_state) {
po_state$model = if (!is.null(po_state$model)) marshal_model(po_state$model, inplace = inplace, ...)
po_state
})
if (!some(map(x, "model"), is_marshaled_model)) {
return(structure(x, class = c("graph_learner_model", "list")))
}

model_clone = map(model, if (is.R6(model)) model$clone(deep = TRUE) else model)
model_marhaled = structure(list(
marshaled = map(model_clone, marshal_model, inplace = FALSE),
packages = "mlr3pipelines"
structure(list(
marshaled = x,
packages = "mlr3pipelines"
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
model = map(model$marshaled, unmarshal_model)
class(model) = c("graph_learner_model", "list")
model
structure(
map(model$marshaled, function(po_state) {
po_state$model = if (!is.null(po_state$model)) unmarshal_model(po_state$model, inplace = inplace, ...)
po_state
}
), class = c("graph_learner_model", "list"))
}

#' @export
Expand Down
10 changes: 1 addition & 9 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,6 @@
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties that this PipeOp has. See `mlr_reflections$pipeops$properties` for available values.
#'
#' @section Methods:
#' * `train(input)`\cr
Expand Down Expand Up @@ -238,7 +236,7 @@ PipeOp = R6Class("PipeOp",
.result = NULL,
tags = NULL,

initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
Expand All @@ -248,7 +246,6 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

private$.properties = sort(assert_subset(properties, mlr_reflections$pipeops$properties))
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
Expand Down Expand Up @@ -414,10 +411,6 @@ PipeOp = R6Class("PipeOp",
}
}
private$.label
},
properties = function(rhs) {
assert_ro_binding(rhs)
private$.properties
}
),

Expand All @@ -436,7 +429,6 @@ PipeOp = R6Class("PipeOp",
}
value
},
.properties = NULL,
.train = function(input) stop("abstract"),
.predict = function(input) stop("abstract"),
.additional_phash_input = function() {
Expand Down
6 changes: 2 additions & 4 deletions R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,10 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
properties = if ("marshal" %in% private$.learner$properties) "marshal" else character(0)
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = "NULL", predict = out_type),
tags = "learner", packages = learner$packages, properties = properties
)
tags = "learner", packages = learner$packages)
}
),
active = list(
Expand Down Expand Up @@ -154,4 +152,4 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
)
)

mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ParamSet$new(), packages = "mlr3pipelines", properties = character()))$new()))
mlr_pipeops$add("learner", PipeOpLearner, list(R6Class("Learner", public = list(id = "learner", task_type = "classif", param_set = ParamSet$new(), packages = "mlr3pipelines"))$new()))
6 changes: 2 additions & 4 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,6 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task

properties =if ("marshal" %in% learner$properties) "marshal" else character(0)

private$.crossval_param_set = ParamSet$new(params = list(
ParamFct$new("method", levels = c("cv", "insample"), tags = c("train", "required")),
ParamInt$new("folds", lower = 2L, upper = Inf, tags = c("train", "required")),
Expand All @@ -139,7 +137,7 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
# in PipeOp ParamSets.
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.

super$initialize(id, alist(private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"), properties = properties)
super$initialize(id, alist(private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
}

),
Expand Down Expand Up @@ -220,4 +218,4 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
)
)

mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ParamSet$new(), properties = character()))$new()))
mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ParamSet$new()))$new()))
6 changes: 2 additions & 4 deletions R/PipeOpTaskPreproc.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",

public = list(
initialize = function(id, param_set = ParamSet$new(), param_vals = list(), can_subset_cols = TRUE,
packages = character(0), task_type = "Task", tags = NULL, feature_types = mlr_reflections$task_feature_types, properties = character(0)) {
packages = character(0), task_type = "Task", tags = NULL, feature_types = mlr_reflections$task_feature_types) {
if (can_subset_cols) {
acp = ParamUty$new("affect_columns", custom_check = check_function_or_null, default = selector_all(), tags = "train")
if (inherits(param_set, "ParamSet")) {
Expand All @@ -183,9 +183,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
super$initialize(id = id, param_set = param_set, param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = task_type, predict = task_type),
packages = packages, tags = c(tags, "data transform"),
properties = properties
)
packages = packages, tags = c(tags, "data transform"))
}
),
active = list(
Expand Down
2 changes: 0 additions & 2 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,6 @@ register_mlr3 = function() {
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
"data transform", "target transform", "ensemble", "robustify", "learner", "encode",
"multiplicity")))

x$pipeops$properties = "marshal"
}

.onLoad = function(libname, pkgname) { # nocov start
Expand Down
2 changes: 0 additions & 2 deletions man/Graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 0 additions & 2 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/marshal_graph_learner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

28 changes: 21 additions & 7 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -571,13 +571,27 @@ test_that("GraphLearner hashes", {

test_that("marshal", {
task = tsk("iris")
po_lily = as_pipeop(lrn("classif.lily"))
graph = as_graph(po_lily)
glrn = as_learner(graph)
expect_true("marshal" %in% glrn$properties)

# als checks that it is marshalable
glrn = as_learner(as_graph(lrn("classif.debug")))
glrn$train(task)
glrn$marshal()
expect_true(glrn$marshaled)
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))
glrn$unmarshal()
expect_false(is_marshaled_model(glrn$model))
expect_class(glrn$model, "graph_learner_model")
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))

# checks that it is marshalable
glrn$train(task)
expect_learner(glrn, task)
})

expect_false("marshal" %in% as_graph(lrn("regr.featureless"))$properties)
test_that("marshal has no effect when nothing needed marshaling", {
task = tsk("iris")
glrn = as_learner(as_graph(lrn("classif.rpart")))
glrn$train(task)
glrn$marshal()
expect_class(glrn$marshal()$model, "graph_learner_model")
expect_class(glrn$unmarshal()$model, "graph_learner_model")
expect_learner(glrn, task = task)
})

0 comments on commit 030b2b3

Please sign in to comment.