Skip to content

Commit

Permalink
merge bundling
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed May 7, 2024
2 parents 6d851c6 + b031b22 commit ae16775
Show file tree
Hide file tree
Showing 21 changed files with 357 additions and 14 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ Imports:
data.table,
digest,
lgr,
mlr3 (>= 0.6.0),
mlr3 (>= 0.19.0),
mlr3misc (>= 0.9.0),
paradox,
R6,
Expand Down
9 changes: 9 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ S3method(as_pipeop,Learner)
S3method(as_pipeop,PipeOp)
S3method(as_pipeop,default)
S3method(disable_internal_tuning,GraphLearner)
S3method(marshal_model,Multiplicity)
S3method(marshal_model,graph_learner_model)
S3method(marshal_model,pipeop_impute_learner_state)
S3method(marshal_model,pipeop_learner_cv_state)
S3method(po,"NULL")
S3method(po,Filter)
S3method(po,Learner)
Expand All @@ -25,6 +29,10 @@ S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(set_validate,GraphLearner)
S3method(unmarshal_model,Multiplicity_marshaled)
S3method(unmarshal_model,graph_learner_model_marshaled)
S3method(unmarshal_model,pipeop_impute_learner_state_marshaled)
S3method(unmarshal_model,pipeop_learner_cv_state_marshaled)
export("%>>!%")
export("%>>%")
export(Graph)
Expand Down Expand Up @@ -155,5 +163,6 @@ importFrom(data.table,as.data.table)
importFrom(digest,digest)
importFrom(stats,setNames)
importFrom(utils,bibentry)
importFrom(utils,head)
importFrom(utils,tail)
importFrom(withr,with_options)
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
* Minor documentation fixes.
* Test helpers are now available in `inst/`. These are considered experimental and unstable.

* Added marshaling support to `GraphLearner`

# mlr3pipelines 0.5.1

* Changed the ID of `PipeOpFeatureUnion` used in `ppl("robustify")` and `ppl("stacking")`.
Expand Down
41 changes: 40 additions & 1 deletion R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,16 @@
#' * `validate` :: `numeric(1)`, `"predefined"`, `"test"` or `NULL`\cr
#' How to construct the validation data. This also has to be configured in the individual learners wrapped by
#' `PipeOpLearner`, see [`set_validate.GraphLearner`] on how to configure this.
#' * `marshaled` :: `logical(1)`\cr
#' Whether the learner is marshaled.
#'
#' @section Methods:
#' * `marshal(...)`\cr
#' (any) -> `self`\cr
#' Marshal the model.
#' * `unmarshal(...)`\cr
#' (any) -> `self`\cr
#' Unmarshal the model.
#'
#' @section Internals:
#' [`as_graph()`] is called on the `graph` argument, so it can technically also be a `list` of things, which is
Expand Down Expand Up @@ -150,6 +159,12 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_learner(recursive - 1)
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
},
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
}
),
active = list(
Expand All @@ -169,7 +184,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
private$.validate = assert_validate(rhs)
}
private$.validate

},
marshaled = function() {
learner_marshaled(self)
},
hash = function() {
digest(list(class(self), self$id, self$graph$hash, private$.predict_type, private$.validate,
Expand Down Expand Up @@ -260,6 +277,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
on.exit({self$graph$state = NULL})
self$graph$train(task)
state = self$graph$state
class(state) = c("graph_learner_model", class(state))
state
},
.predict = function(task) {
Expand Down Expand Up @@ -414,6 +432,27 @@ disable_internal_tuning.GraphLearner = function(learner, ids, ...) {
}


#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
# if none of the states required any marshaling we return the model as-is
if (!some(xm, is_marshaled_model)) return(model)

structure(list(
marshaled = xm,
packages = "mlr3pipelines"
), class = c("graph_learner_model_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.graph_learner_model_marshaled = function(model, inplace = FALSE, ...) {
# need to re-create the class as it gets lost during marshaling
structure(
map(.x = model$marshaled, .f = unmarshal_model, inplace = inplace, ...),
class = gsub(x = head(class(model), n = -1), pattern = "_marshaled$", replacement = "")
)
}

#' @export
as_learner.Graph = function(x, clone = FALSE, ...) {
GraphLearner$new(x, clone_graph = clone)
Expand Down
31 changes: 31 additions & 0 deletions R/PipeOpImputeLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
#' for each column. If a column consists of missing values only during training, the `model` is `0` or the levels of the
#' feature; these are used for sampling during prediction.
#'
#' This state is given the class `"pipeop_impute_learner_state"`.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from [`PipeOpImpute`], in addition to the parameters of the [`Learner`][mlr3::Learner]
#' used for imputation.
Expand Down Expand Up @@ -116,6 +118,13 @@ PipeOpImputeLearner = R6Class("PipeOpImputeLearner",
)
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
whole_task_dependent = TRUE, feature_types = feature_types)
},
train = function(inputs) {
outputs = super$train(inputs)
self$state = multiplicity_recurse(self$state, function(state) {
structure(state, class = c("pipeop_impute_learner_state", class(state)))
})
return(outputs)
}
),
active = list(
Expand Down Expand Up @@ -206,3 +215,25 @@ mlr_pipeops$add("imputelearner", PipeOpImputeLearner, list(R6Class("Learner", pu
convert_to_task = function(id = "imputing", data, target, task_type, ...) {
get(mlr_reflections$task_types[task_type, mult = "first"]$task)$new(id = id, backend = data, target = target, ...)
}

#' @export
marshal_model.pipeop_impute_learner_state = function(model, inplace = FALSE, ...) {
prev_class = class(model)
model$model = map(model$model, marshal_model, inplace = inplace, ...)

if (!some(model$model, is_marshaled_model)) {
return(model)
}

structure(
list(marshaled = model, packages = "mlr3pipelines"),
class = c(paste0(prev_class, "_marshaled"), "marshaled")
)
}

#' @export
unmarshal_model.pipeop_impute_learner_state_marshaled = function(model, inplace = FALSE, ...) {
state = model$marshaled
state$model = map(state$model, unmarshal_model, inplace = inplace, ...)
return(state)
}
3 changes: 1 addition & 2 deletions R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,7 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = "NULL", predict = out_type),
tags = "learner", packages = learner$packages
)
tags = "learner", packages = learner$packages)
}
),
active = list(
Expand Down
35 changes: 34 additions & 1 deletion R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@
#' * `predict_time` :: `NULL` | `numeric(1)`
#' Prediction time, in seconds.
#'
#' This state is given the class `"pipeop_learner_cv_state"`.
#'
#' @section Parameters:
#' The parameters are the parameters inherited from the [`PipeOpTaskPreproc`], as well as the parameters of the [`Learner`][mlr3::Learner] wrapped by this object.
#' Besides that, parameters introduced are:
Expand Down Expand Up @@ -144,8 +146,14 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.

super$initialize(id, alist(resampling = private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
},
train = function(inputs) {
outputs = super$train(inputs)
self$state = multiplicity_recurse(self$state, function(state) {
structure(state, class = c("pipeop_learner_cv_state", class(state)))
})
return(outputs)
}

),
active = list(
learner = function(val) {
Expand Down Expand Up @@ -224,4 +232,29 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
)
)

#' @export
marshal_model.pipeop_learner_cv_state = function(model, inplace = FALSE, ...) {
# Note that a Learner state contains other reference objects, but we don't clone them here, even when inplace
# is FALSE. For our use-case this is just not necessary and would cause unnecessary overhead in the mlr3
# workhorse function
model$model = marshal_model(model$model, inplace = inplace)
# only wrap this in a marshaled class if the model was actually marshaled above
# (the default marshal method does nothing)
if (is_marshaled_model(model$model)) {
model = structure(
list(marshaled = model, packages = "mlr3pipelines"),
class = c(paste0(class(model), "_marshaled"), "marshaled")
)
}
model
}

#' @export
unmarshal_model.pipeop_learner_cv_state_marshaled = function(model, inplace = FALSE, ...) {
state_marshaled = model$marshaled
state_marshaled$model = unmarshal_model(state_marshaled$model, inplace = inplace)
state_marshaled
}


mlr_pipeops$add("learner_cv", PipeOpLearnerCV, list(R6Class("Learner", public = list(id = "learner_cv", task_type = "classif", param_set = ps()))$new()))
3 changes: 1 addition & 2 deletions R/PipeOpTaskPreproc.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,8 +187,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
super$initialize(id = id, param_set = param_set, param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = task_type, predict = task_type),
packages = packages, tags = c(tags, "data transform")
)
packages = packages, tags = c(tags, "data transform"))
}
),
active = list(
Expand Down
13 changes: 13 additions & 0 deletions R/multiplicity.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,16 @@ multiplicity_nests_deeper_than = function(x, cutoff) {
}
ret
}

#' @export
marshal_model.Multiplicity = function(model, inplace = FALSE, ...) {
structure(list(
marshaled = multiplicity_recurse(model, marshal_model, inplace = inplace, ...),
packages = "mlr3pipelines"
), class = c("Multiplicity_marshaled", "marshaled"))
}

#' @export
unmarshal_model.Multiplicity_marshaled = function(model, inplace = FALSE, ...) {
multiplicity_recurse(model$marshaled, unmarshal_model, inplace = inplace, ...)
}
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' @import paradox
#' @import mlr3misc
#' @importFrom R6 R6Class
#' @importFrom utils tail
#' @importFrom utils tail head
#' @importFrom digest digest
#' @importFrom withr with_options
#' @importFrom stats setNames
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_learners_avg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions man/mlr_learners_graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/mlr_pipeops_imputelearner.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/mlr_pipeops_learner_cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_pipeops_tunethreshold.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

34 changes: 33 additions & 1 deletion tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -612,7 +612,8 @@ test_that("internal_tuned_values", {
expect_false("internal_tuning" %in% glrn1$properties)
expect_equal(glrn1$internal_tuned_values, NULL)

# learner with internal tuning
# learner wQ
# ith internal tuning
glrn2 = as_learner(as_graph(lrn("classif.debug")))
expect_true("internal_tuning" %in% glrn2$properties)
expect_equal(glrn2$internal_tuned_values, NULL)
Expand Down Expand Up @@ -661,3 +662,34 @@ test_that("set_validate", {
expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$final$learner$validate, "predefined")
expect_equal(glrn2$graph$pipeops$polearner$learner$graph$pipeops$classif.debug$learner$validate, NULL)
})

test_that("marshal", {
task = tsk("iris")
glrn = as_learner(as_graph(lrn("classif.debug")))
glrn$train(task)
p1 = glrn$predict(task)
glrn$marshal()
expect_true(glrn$marshaled)
expect_true(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
glrn$unmarshal()
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug))
expect_class(glrn$model, "graph_learner_model")
expect_false(is_marshaled_model(glrn$state$model$marshaled$classif.debug$model))

p2 = glrn$predict(task)
expect_equal(p1$response, p2$response)

# checks that it is marshalable
glrn$train(task)
expect_learner(glrn, task)
})

test_that("marshal has no effect when nothing needed marshaling", {
task = tsk("iris")
glrn = as_learner(as_graph(lrn("classif.rpart")))
glrn$train(task)
glrn$marshal()
expect_class(glrn$marshal()$model, "graph_learner_model")
expect_class(glrn$unmarshal()$model, "graph_learner_model")
expect_learner(glrn, task = task)
})
1 change: 0 additions & 1 deletion tests/testthat/test_mlr_graphs_bagging.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
context("ppl - pipeline_bagging")


test_that("Bagging Pipeline", {
skip_if_not_installed("rpart")
skip_on_cran() # takes too long
Expand Down
1 change: 0 additions & 1 deletion tests/testthat/test_pipeop_impute.R
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,6 @@ test_that("More tests for Integers", {
expect_false(any(is.na(result$data()$x)), info = po$id)
expect_equal(result$missings(), c(t = 0, x = 0), info = po$id)
}

})

test_that("impute, test rows and affect_columns", {
Expand Down
Loading

0 comments on commit ae16775

Please sign in to comment.