From 8946fcda7fa3f8ec4789a7e97156bf07ac2b1de2 Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 16 Jun 2023 01:47:07 +0200 Subject: [PATCH 1/5] add `replace` argument to ppl("bagging") --- R/pipeline_bagging.R | 15 ++++++++++++--- man/mlr_graphs_bagging.Rd | 20 ++++++++++++++++++-- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/R/pipeline_bagging.R b/R/pipeline_bagging.R index afcf0c7f9..2ffd03d7a 100644 --- a/R/pipeline_bagging.R +++ b/R/pipeline_bagging.R @@ -28,6 +28,9 @@ #' predictions respectively. #' If `NULL` (default), no averager is added to the end of the graph. #' Note that setting `collect_multipliciy = TRUE` during construction of the averager is required. +#' @param replace `logical(1)` \cr +#' Whether to sample with replacement. +#' Default `FALSE`. #' @return [`Graph`] #' @export #' @examples @@ -36,9 +39,15 @@ #' lrn_po = po("learner", lrn("regr.rpart")) #' task = mlr_tasks$get("boston_housing") #' gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = TRUE)) -#' resample(task, GraphLearner$new(gr), rsmp("holdout")) +#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() +#' +#' # The original bagging method uses bootstrapping, i.e. sampling with replacement. +#' # This may give better performance but is also slower. 
+#' gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE, +#' averager = po("regravg", collect_multiplicity = TRUE)) +#' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() #' } -pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) { +pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL, replace = FALSE) { g = as_graph(graph) assert_count(iterations) assert_number(frac, lower = 0, upper = 1) @@ -50,7 +59,7 @@ pipeline_bagging = function(graph, iterations = 10, frac = 0.7, averager = NULL) } po("replicate", param_vals = list(reps = iterations)) %>>!% - po("subsample", param_vals = list(frac = frac)) %>>!% + po("subsample", param_vals = list(frac = frac, replace = replace)) %>>!% g %>>!% averager } diff --git a/man/mlr_graphs_bagging.Rd b/man/mlr_graphs_bagging.Rd index 42828ef94..186ac71ff 100644 --- a/man/mlr_graphs_bagging.Rd +++ b/man/mlr_graphs_bagging.Rd @@ -5,7 +5,13 @@ \alias{pipeline_bagging} \title{Create a bagging learner} \usage{ -pipeline_bagging(graph, iterations = 10, frac = 0.7, averager = NULL) +pipeline_bagging( + graph, + iterations = 10, + frac = 0.7, + averager = NULL, + replace = FALSE +) } \arguments{ \item{graph}{\code{\link{PipeOp}} | \code{\link{Graph}} \cr @@ -27,6 +33,10 @@ in order to perform simple averaging of classification and regression predictions respectively. If \code{NULL} (default), no averager is added to the end of the graph. Note that setting \code{collect_multipliciy = TRUE} during construction of the averager is required.} + +\item{replace}{\code{logical(1)} \cr +Whether to sample with replacement. 
+Default \code{FALSE}.} } \value{ \code{\link{Graph}} @@ -49,6 +59,12 @@ library(mlr3) lrn_po = po("learner", lrn("regr.rpart")) task = mlr_tasks$get("boston_housing") gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = TRUE)) -resample(task, GraphLearner$new(gr), rsmp("holdout")) +resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() + +# The original bagging method uses bootstrapping, i.e. sampling with replacement. +# This may give better performance but is also slower. +gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE, + averager = po("regravg", collect_multiplicity = TRUE)) +resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() } } From b1713faa0b9de29722d72a624a9e5d73b932791a Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 16 Jun 2023 01:49:38 +0200 Subject: [PATCH 2/5] news entry --- NEWS.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/NEWS.md b/NEWS.md index a1074398a..5e13c1044 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,7 @@ # mlr3pipelines 0.5.0-9000 +* `pipeline_bagging()` gets the `replace` argument. + # mlr3pipelines 0.5.0-1 * Bugfix: `PipeOpTuneThreshold` was not overloading the correct `.train` and `.predict` functions. From 99ee083a1b340e0e159a80fa2c05f63916200953 Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 26 Mar 2024 15:43:51 +0100 Subject: [PATCH 3/5] doc --- R/pipeline_bagging.R | 1 - man/mlr_graphs_bagging.Rd | 1 - 2 files changed, 2 deletions(-) diff --git a/R/pipeline_bagging.R b/R/pipeline_bagging.R index 2ffd03d7a..31b743d32 100644 --- a/R/pipeline_bagging.R +++ b/R/pipeline_bagging.R @@ -42,7 +42,6 @@ #' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() #' #' # The original bagging method uses bootstrapping, i.e. sampling with replacement. -#' # This may give better performance but is also slower. 
#' gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE, #' averager = po("regravg", collect_multiplicity = TRUE)) #' resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() diff --git a/man/mlr_graphs_bagging.Rd b/man/mlr_graphs_bagging.Rd index 186ac71ff..58d5e1e83 100644 --- a/man/mlr_graphs_bagging.Rd +++ b/man/mlr_graphs_bagging.Rd @@ -62,7 +62,6 @@ gr = pipeline_bagging(lrn_po, 3, averager = po("regravg", collect_multiplicity = resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() # The original bagging method uses bootstrapping, i.e. sampling with replacement. -# This may give better performance but is also slower. gr = ppl("bagging", lrn_po, frac = 1, replace = TRUE, averager = po("regravg", collect_multiplicity = TRUE)) resample(task, GraphLearner$new(gr), rsmp("holdout"))$aggregate() From 0d8a29d31ae5fc5d945eaf12527444835287421d Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 26 Mar 2024 16:15:39 +0100 Subject: [PATCH 4/5] tests --- tests/testthat/test_mlr_graphs_bagging.R | 33 ++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/tests/testthat/test_mlr_graphs_bagging.R b/tests/testthat/test_mlr_graphs_bagging.R index a5dc1067c..b7deb981a 100644 --- a/tests/testthat/test_mlr_graphs_bagging.R +++ b/tests/testthat/test_mlr_graphs_bagging.R @@ -39,3 +39,36 @@ test_that("Bagging Pipeline", { expect_true(all(map_lgl(predict_out, function(x) "PredictionClassif" %in% class(x)))) }) +test_that("Bagging with replacement", { + tsk = tsk("iris") + lrn = lrn("classif.rpart") + p = ppl("bagging", graph = po(lrn), replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE)) + expect_graph(p) + res = resample(tsk, GraphLearner$new(p), rsmp("holdout")) + expect_resample_result(res) + + tsk$filter(1:140) + expect_equal(anyDuplicated(tsk$data()), 0) # make sure no duplicates + + p = ppl("bagging", iterations = 2, + graph = lrn("classif.debug", save_tasks = TRUE), + replace = TRUE, averager = po("classifavg", collect_multiplicity 
= TRUE) + ) + p$train(tsk) + + expect_true(anyDuplicated(p$pipeops$classif.debug$state[[1]]$model$task_train$data()) != 0) + + getOrigId = function(data) { + tsk$data()[, origline := .I][data, on = colnames(tsk$data()), origline] + } + orig_id_1 = getOrigId(p$pipeops$classif.debug$state[[1]]$model$task_train$data()) + orig_id_2 = getOrigId(p$pipeops$classif.debug$state[[2]]$model$task_train$data()) + + expect_equal(length(orig_id_1), 140) + expect_equal(length(orig_id_2), 140) + # if we sampled the same values twice, the all.equal() would just give TRUE + expect_string(all.equal(orig_id_1, orig_id_2)) + + expect_true(length(unique(orig_id_1)) < 140) + expect_true(length(unique(orig_id_2)) < 140) +}) From 22ecbfc008235675a5dd6a6fd4eeb6b08eb134e0 Mon Sep 17 00:00:00 2001 From: mb706 Date: Tue, 26 Mar 2024 17:17:12 +0100 Subject: [PATCH 5/5] fix bagging-with-replacement test: sample with frac = 1 --- tests/testthat/test_mlr_graphs_bagging.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test_mlr_graphs_bagging.R b/tests/testthat/test_mlr_graphs_bagging.R index b7deb981a..15a70fb0c 100644 --- a/tests/testthat/test_mlr_graphs_bagging.R +++ b/tests/testthat/test_mlr_graphs_bagging.R @@ -50,7 +50,7 @@ test_that("Bagging with replacement", { tsk$filter(1:140) expect_equal(anyDuplicated(tsk$data()), 0) # make sure no duplicates - p = ppl("bagging", iterations = 2, + p = ppl("bagging", iterations = 2, frac = 1, graph = lrn("classif.debug", save_tasks = TRUE), replace = TRUE, averager = po("classifavg", collect_multiplicity = TRUE) )