Skip to content

Commit

Permalink
Merge branch 'feat/inner_valid' of github.com:mlr-org/mlr3pipelines i…
Browse files Browse the repository at this point in the history
…nto feat/inner_valid
  • Loading branch information
be-marc committed Jun 25, 2024
2 parents 5cda404 + 2e12c95 commit 0170eac
Show file tree
Hide file tree
Showing 8 changed files with 34 additions and 51 deletions.
1 change: 0 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ S3method(as_pipeop,Filter)
S3method(as_pipeop,Learner)
S3method(as_pipeop,PipeOp)
S3method(as_pipeop,default)
S3method(disable_internal_tuning,GraphLearner)
S3method(marshal_model,Multiplicity)
S3method(marshal_model,graph_learner_model)
S3method(marshal_model,pipeop_impute_learner_state)
Expand Down
22 changes: 7 additions & 15 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (!some_pipeops_validate) {
lg$warn("GraphLearner '%s' specifies a validation set, but none of its Learners use it.", self$id)
}
} else {
# otherwise the pipeops will preprocess this unnecessarily
if (!is.null(task$internal_valid_task)) {
prev_itv = task$internal_valid_task
on.exit({task$internal_valid_task = prev_itv}, add = TRUE)
task$internal_valid_task = NULL
}
}

on.exit({self$graph$state = NULL})
Expand Down Expand Up @@ -417,21 +424,6 @@ set_validate.GraphLearner = function(learner, validate, ids = NULL, args = list(
invisible(learner)
}


#' @export
disable_internal_tuning.GraphLearner = function(learner, ids, ...) {
pvs = learner$param_set$values
on.exit({learner$param_set$values = pvs}, add = TRUE)
if (length(ids)) {
walk(learner_wrapping_pipeops(learner), function(po) {
disable_internal_tuning(po$learner, ids = po$param_set$ids()[sprintf("%s.%s", po$id, po$param_set$ids()) %in% ids])
})
}
on.exit()
invisible(learner)
}


#' @export
marshal_model.graph_learner_model = function(model, inplace = FALSE, ...) {
xm = map(.x = model, .f = marshal_model, inplace = inplace, ...)
Expand Down
3 changes: 1 addition & 2 deletions R/pipeline_branch.R
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,9 @@ pipeline_branch = function(graphs, prefix_branchops = "", prefix_paths = FALSE)
pmap(list(
src_id = branch_id, dst_id = gin$op.id,
src_channel = branch_chan, dst_channel = gin$channel.name),
graph$add_edge)
graph$add_edge)
})
graph
}

mlr_graphs$add("branch", pipeline_branch)

2 changes: 1 addition & 1 deletion inst/testthat/helper_functions.R
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ expect_datapreproc_pipeop_class = function(poclass, constargs = list(), task,
expect_true(task$nrow >= 5)

# overlap between use and test rows
tasktrain$divide(tasktrain$row_roles$use[seq(n_use - 2, n_use)], remove = FALSE)
tasktrain$divide(ids = tasktrain$row_roles$use[seq(n_use - 2, n_use)], remove = FALSE)
tasktrain$row_roles$use = tasktrain$row_roles$use[seq(1, n_use - 2)]

taskpredict = tasktrain$clone(deep = TRUE)
Expand Down
2 changes: 1 addition & 1 deletion man/mlr_learners_avg.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/mlr_pipeops_tunethreshold.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 22 additions & 29 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ context("GraphLearner")

test_that("basic graphlearner tests", {
skip_if_not_installed("rpart")
skip_on_cran() # takes too long
skip_on_cran() # takes too long
task = mlr_tasks$get("iris")

lrn = mlr_learners$get("classif.rpart")
Expand Down Expand Up @@ -40,8 +40,8 @@ test_that("basic graphlearner tests", {
expect_true(run_experiment(task, glrn)$ok)
glrn2$train(task)
glrn2_clone$state = glrn2$state
# glrn2_clone$state$log = glrn2_clone$state$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
# glrn2_clone$state$model$classif.rpart$log = glrn2_clone$state$model$classif.rpart$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
# glrn2_clone$state$log = glrn2_clone$state$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
# glrn2_clone$state$model$classif.rpart$log = glrn2_clone$state$model$classif.rpart$log$clone(deep = TRUE) # FIXME: this can go when mlr-org/mlr3#343 is fixed
expect_deep_clone(glrn2_clone, glrn2$clone(deep = TRUE))
expect_prediction_classif({
graphpred2 = glrn2$predict(task)
Expand Down Expand Up @@ -109,7 +109,7 @@ test_that("GraphLearner clone_graph FALSE", {
# check that the GraphLearner predicts what we expect
expect_true(isTRUE(all.equal(gl$predict(tsk("iris")), expected_prediction)))

expect_false(gr1$is_trained) # predicting with GraphLearner resets Graph state
expect_false(gr1$is_trained) # predicting with GraphLearner resets Graph state

expect_identical(gl$graph, gr1)

Expand Down Expand Up @@ -177,7 +177,7 @@ test_that("graphlearner parameters behave as they should", {

test_that("graphlearner type inference", {
skip_if_not_installed("rpart")
skip_on_cran() # takes too long
skip_on_cran() # takes too long
# default: classif
lrn = GraphLearner$new(mlr_pipeops$get("nop"))
expect_equal(lrn$task_type, "classif")
Expand Down Expand Up @@ -246,15 +246,15 @@ test_that("graphlearner type inference", {

test_that("graphlearner type inference - branched", {
skip_if_not_installed("rpart")
skip_on_cran() # takes too long
skip_on_cran() # takes too long

# default: classif

lrn = GraphLearner$new(gunion(list(
mlr_pipeops$get(id = "l1", "learner", lrn("classif.rpart")),
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("classif.rpart"))
mlr_pipeops$get(id = "l1", "learner", lrn("classif.rpart")),
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("classif.rpart"))

)) %>>%
)) %>>%
po("classifavg") %>>%
po(id = "n2", "nop"))
expect_equal(lrn$task_type, "classif")
Expand All @@ -281,9 +281,9 @@ test_that("graphlearner type inference - branched", {

# inference when multiple input, but one is a Task
lrn = GraphLearner$new(gunion(list(
mlr_pipeops$get(id = "l1", "learner", lrn("regr.rpart")),
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("regr.rpart"))
)) %>>%
mlr_pipeops$get(id = "l1", "learner", lrn("regr.rpart")),
po("nop") %>>% mlr_pipeops$get(id = "l2", "learner", lrn("regr.rpart"))
)) %>>%
po("regravg") %>>%
po(id = "n2", "nop"))
expect_equal(lrn$task_type, "regr")
Expand Down Expand Up @@ -311,7 +311,7 @@ test_that("graphlearner type inference - branched", {

test_that("graphlearner predict type inference", {
skip_if_not_installed("rpart")
skip_on_cran() # takes too long
skip_on_cran() # takes too long
# Getter:

# Classification
Expand Down Expand Up @@ -403,7 +403,9 @@ test_that("graphlearner predict type inference", {
expect_equal(lrn$graph$pipeops[[lrr$id]]$predict_type, "prob")

# Errors:
expect_error({lrrp = po(lrn("classif.featureless", predict_type = "se"))})
expect_error({
lrrp = po(lrn("classif.featureless", predict_type = "se"))
})
})


Expand Down Expand Up @@ -439,7 +441,6 @@ test_that("GraphLearner model", {

expect_equal(lr$graph_model$pipeops$classif.rpart$learner_model$importance(), imp)


})

test_that("predict() function for Graph", {
Expand Down Expand Up @@ -468,7 +469,6 @@ test_that("predict() function for Graph", {
p1$response
)


})

test_that("base_learner() works", {
Expand Down Expand Up @@ -558,20 +558,20 @@ test_that("GraphLearner hashes", {
expect_string(all.equal(po("copy", 2)$hash, po("copy", 3)$hash), "mismatch")


lr1 <- lrn("classif.rpart")
lr2 <- lrn("classif.rpart", fallback = lrn("classif.rpart"))
lr1 = lrn("classif.rpart")
lr2 = lrn("classif.rpart", fallback = lrn("classif.rpart"))

expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")

lr1 <- as_learner(as_pipeop(lr1))
lr2 <- as_learner(as_pipeop(lr2))
lr1 = as_learner(as_pipeop(lr1))
lr2 = as_learner(as_pipeop(lr2))

expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")

lr1 <- as_learner(as_pipeop(lr1))
lr2 <- as_learner(as_pipeop(lr2))
lr1 = as_learner(as_pipeop(lr1))
lr2 = as_learner(as_pipeop(lr2))

expect_string(all.equal(lr1$hash, lr2$hash), "mismatch")
expect_string(all.equal(lr1$phash, lr2$phash), "mismatch")
Expand Down Expand Up @@ -625,13 +625,6 @@ test_that("internal_tuned_values", {
expect_equal(names(glrn2$internal_tuned_values), "classif.debug.iter")
})

test_that("disable_internal_tuning", {
glrn = as_learner(as_pipeop(lrn("classif.debug", iter = 100, early_stopping = TRUE)))
disable_internal_tuning(glrn, "classif.debug.iter")
expect_false(glrn$graph$pipeops$classif.debug$param_set$values$early_stopping)
expect_error(disable_internal_tuning(glrn, "classif.debug.abc"), "subset of")
})

test_that("set_validate", {
glrn = as_learner(as_pipeop(lrn("classif.debug", validate = 0.3)))
set_validate(glrn, "test")
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_pipeop_impute.R
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ test_that("More tests for Integers", {
test_that("impute, test rows and affect_columns", {
po_impute = po("imputeconstant", affect_columns = selector_name("insulin"), constant = 2)
task = tsk("pima")
task$divide(1:30)
task$divide(ids = 1:30)
outtrain = po_impute$train(list(task))[[1L]]
outpredict = po_impute$predict(list(task$internal_valid_task))[[1L]]
expect_true(isTRUE(all.equal(outtrain$internal_valid_task$data(), outpredict$data())))
Expand Down

0 comments on commit 0170eac

Please sign in to comment.