diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 3c9b8e57f..64f69702b 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -321,9 +321,25 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, }, predict_type = function(rhs) { if (!missing(rhs)) { - private$set_predict_type(rhs) + assert_subset(rhs, unlist(mlr_reflections$learner_predict_types[[self$task_type]], use.names = FALSE)) } - private$get_predict_type() + + # we look for *all* pipeops with a predict_type if we want to set it, but + # we only retrieve the predict_type of the active Learner (from branching) if we + # are getting. + predict_type_pipeops = graph_base_learner( + self$graph, resolve_branching = missing(rhs), lookup_field = "predict_type") + if (!missing(rhs)) { + walk(predict_type_pipeops, function(po) po$predict_type = rhs) + return(rhs) + } + pt = unique(unlist(map(predict_type_pipeops, "predict_type"), recursive = FALSE, use.names = FALSE)) + if (!length(pt)) return(names(mlr_reflections$learner_predict_types[[self$task_type]])[[1]]) + if (length(pt) > 1) { + # if there are multiple predict types, predict the "first" one, according to reflections + return(intersect(names(mlr_reflections$learner_predict_types[[self$task_type]]), pt)[[1]]) + } + pt }, param_set = function(rhs) { if (!missing(rhs) && !identical(rhs, self$graph$param_set)) { @@ -397,37 +413,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner, assert_list(prediction, types = "Prediction", len = 1, .var.name = sprintf("Prediction returned by Graph %s", self$id)) prediction[[1]] - }, - get_predict_type = function() { - # recursively walk backwards through the graph - get_po_predict_type = function(x) { - if (!is.null(x$predict_type)) return(x$predict_type) - prdcssrs = self$graph$edges[dst_id == x$id, ]$src_id - if (length(prdcssrs)) { - # all non-null elements - predict_types = discard(map(self$graph$pipeops[prdcssrs], get_po_predict_type), is.null) - if (length(unique(predict_types)) == 1L) - return(unlist(unique(predict_types))) - } - return(NULL) - } - predict_type = get_po_predict_type(self$graph$pipeops[[self$graph$rhs]]) - if (is.null(predict_type)) - names(mlr_reflections$learner_predict_types[[self$task_type]])[[1]] - else - predict_type - }, - set_predict_type = function(predict_type) { - # recursively walk backwards through the graph - set_po_predict_type = function(x, predict_type) { - assert_subset(predict_type, unlist(mlr_reflections$learner_predict_types[[self$task_type]])) - if (!is.null(x$predict_type)) x$predict_type = predict_type - prdcssrs = self$graph$edges[dst_id == x$id, ]$src_id - if (length(prdcssrs)) { - map(self$graph$pipeops[prdcssrs], set_po_predict_type, predict_type = predict_type) - } - } - set_po_predict_type(self$graph$pipeops[[self$graph$rhs]], predict_type) } ) ) @@ -720,9 +705,9 @@ andpaste = function(x, sep = ", ", lastsep = ", and ") { paste0(paste(first(x, -1), collapse = sep), lastsep, last(x)) } -graph_base_learner = function(graph, resolve_branching = TRUE) { +graph_base_learner = function(graph, resolve_branching = TRUE, lookup_field = "learner_model") { # GraphLearner$base_learner(), where return_all is TRUE, return_po is TRUE, and recursive is 1. - # We are looking for all PipeOps with a `$learner_model` field, possibly resolving branching. + # We are looking for all PipeOps with the non-NULL field named `lookup_field`, typically "learner_model", possibly resolving branching. gm_output = graph$output if (nrow(gm_output) != 1) { @@ -744,8 +729,8 @@ graph_base_learner = function(graph, resolve_branching = TRUE) { last_pipeop = graph$pipeops[[current_pipeop]] if (get0(current_pipeop, pipeops_visited, ifnotfound = FALSE)) return(list()) assign(current_pipeop, TRUE, pipeops_visited) - learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model - if (!is.null(learner_model)) return(list(last_pipeop)) + field_content = get0(lookup_field, last_pipeop, ifnotfound = NULL) + if (!is.null(field_content)) return(list(last_pipeop)) next_pipeop = graph$edges[dst_id == current_pipeop, src_id] if (length(next_pipeop) > 1) { # more than one predecessor diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index 642f3e73a..60b41d7b5 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -272,6 +272,29 @@ test_that("graphlearner type inference", { g$train(tsk("boston_housing")) expect_equal(as_learner(g)$task_type, "regr") # should not fail because of multiplicity in the graph + g_branch = ppl("branch", list(rpart = lrn("classif.rpart"), debug = lrn("classif.debug"))) + + l_branch = as_learner(g_branch) + + expect_equal(l_branch$predict_type, "response") + + l_branch$graph$pipeops$classif.debug$predict_type = "prob" + + expect_equal(l_branch$predict_type, "response") + + l_branch$param_set$values$branch.selection = "debug" + + expect_equal(l_branch$predict_type, "prob") + + l_branch$predict_type = "prob" + + expect_equal(l_branch$graph$pipeops$classif.debug$learner$predict_type, "prob") + + expect_equal(l_branch$predict_type, "prob") + + l_branch$param_set$values$branch.selection = "rpart" + + expect_equal(l_branch$predict_type, "prob") }) test_that("graphlearner type inference - branched", {