Skip to content

Commit

Permalink
Merge pull request #725 from mlr-org/ppl_bagging_with_replacement
Browse files Browse the repository at this point in the history
add `replace` argument to ppl("bagging")
  • Loading branch information
mb706 authored Mar 26, 2024
2 parents 11a22ce + faab712 commit 5dc84f6
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 8 deletions.
7 changes: 4 additions & 3 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# mlr3pipelines 0.5.0-9000

* Feature: The `$add_pipeop()` method got an argument `clone` (old behaviour `TRUE` by default)
* Bugfix: `PipeOpFeatureUnion` in some rare cases dropped variables called `"x"`
* Compatibility with upcoming paradox release
* `pipeline_bagging()` gets the `replace` argument (old behaviour `FALSE` by default).
* Feature: The `$add_pipeop()` method got an argument `clone` (old behaviour `TRUE` by default).
* Bugfix: `PipeOpFeatureUnion` in some rare cases dropped variables called `"x"`.
* Compatibility with upcoming paradox release.

# mlr3pipelines 0.5.0-2

Expand Down
14 changes: 11 additions & 3 deletions R/pipeline_bagging.R
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#' predictions respectively.
#' If `NULL` (default), no averager is added to the end of the graph.
#' Note that setting `collect_multipliciy = TRUE` during construction of the averager is required.
#' @param replace `logical(1)` \cr
#' Whether to sample with replacement.
#' Default `FALSE`.
#' @return [`Graph`]
#' @export
#' @examples
Expand All @@ -36,9 +39,14 @@
#' lrn_po = po("learner", lrn("regr.rpart"))
#' task = mlr_tasks$get("boston_housing")
#' gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = TRUE))
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
#'
#' # The original bagging method uses boosting by sampling with replacement.
#' gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE,
#' averager = po("regravg", collect_multiplicity = TRUE))
#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate()
#' }
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) {
pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL, replace = FALSE) {
g = as_graph(graph)
assert_count(iterations)
assert_number(frac, lower = 0, upper = 1)
Expand All @@ -50,7 +58,7 @@ pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL)
}

po("replicate", param_vals = list(reps = iterations)) %>>!%
po("subsample", param_vals = list(frac = frac)) %>>!%
po("subsample", param_vals = list(frac = frac, replace = replace)) %>>!%
g %>>!%
averager
}
Expand Down
19 changes: 17 additions & 2 deletions man/mlr_graphs_bagging.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions tests/testthat/test_mlr_graphs_bagging.R
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,36 @@ test_that("Bagging Pipeline", {
expect_true(all(map_lgl(predict_out, function(x) "PredictionClassif" %in% class(x))))
})

test_that("Bagging with replacement", {
tsk = tsk("iris")
lrn = lrn("classif.rpart")
p = ppl("bagging", graph = po(lrn), replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE))
expect_graph(p)
res = resample(tsk, GraphLearner$new(p), rsmp("holdout"))
expect_resample_result(res)

tsk$filter(1:140)
expect_equal(anyDuplicated(tsk$data()), 0) # make sure no duplicates

p = ppl("bagging", iterations = 2, frac = 1,
graph = lrn("classif.debug", save_tasks = TRUE),
replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE)
)
p$train(tsk)

expect_true(anyDuplicated(p$pipeops$classif.debug$state[[1]]$model$task_train$data()) != 0)

getOrigId = function(data) {
tsk$data()[, origline := .I][data, on = colnames(tsk$data()), origline]
}
orig_id_1 = getOrigId(p$pipeops$classif.debug$state[[1]]$model$task_train$data())
orig_id_2 = getOrigId(p$pipeops$classif.debug$state[[2]]$model$task_train$data())

expect_equal(length(orig_id_1), 140)
expect_equal(length(orig_id_2), 140)
# if we sampled the same values twice, the all.equal() would just give TRUE
expect_string(all.equal(orig_id_1, orig_id_2))

expect_true(length(unique(orig_id_1)) < 140)
expect_true(length(unique(orig_id_2)) < 140)
})

0 comments on commit 5dc84f6

Please sign in to comment.