Skip to content

Commit

Permalink
threshold through base_learner machinery
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Aug 23, 2024
1 parent bb49f16 commit 48305a4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 37 deletions.
59 changes: 22 additions & 37 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -321,9 +321,25 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
},
predict_type = function(rhs) {
if (!missing(rhs)) {
private$set_predict_type(rhs)
assert_subset(rhs, unlist(mlr_reflections$learner_predict_types[[self$task_type]], use.names = FALSE))
}
private$get_predict_type()

# we look for *all* pipeops with a predict_type if we want to set it, but
# we only retrieve the predict_type of the active Learner (from branching) if we
# are getting.
predict_type_pipeops = graph_base_learner(
self$graph, resolve_branching = missing(rhs), lookup_field = "predict_type")
if (!missing(rhs)) {
walk(predict_type_pipeops, function(po) po$predict_type = rhs)
return(rhs)
}
pt = unique(unlist(map(predict_type_pipeops, "predict_type"), recursive = FALSE, use.names = FALSE))
if (!length(pt)) return(names(mlr_reflections$learner_predict_types[[self$task_type]])[[1]])
if (length(pt) > 1) {
# if there are multiple predict types, predict the "first" one, according to reflections
return(intersect(names(mlr_reflections$learner_predict_types[[self$task_type]]), pt)[[1]])
}
pt
},
param_set = function(rhs) {
if (!missing(rhs) && !identical(rhs, self$graph$param_set)) {
Expand Down Expand Up @@ -397,37 +413,6 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
assert_list(prediction, types = "Prediction", len = 1,
.var.name = sprintf("Prediction returned by Graph %s", self$id))
prediction[[1]]
},
get_predict_type = function() {
# recursively walk backwards through the graph
get_po_predict_type = function(x) {
if (!is.null(x$predict_type)) return(x$predict_type)
prdcssrs = self$graph$edges[dst_id == x$id, ]$src_id
if (length(prdcssrs)) {
# all non-null elements
predict_types = discard(map(self$graph$pipeops[prdcssrs], get_po_predict_type), is.null)
if (length(unique(predict_types)) == 1L)
return(unlist(unique(predict_types)))
}
return(NULL)
}
predict_type = get_po_predict_type(self$graph$pipeops[[self$graph$rhs]])
if (is.null(predict_type))
names(mlr_reflections$learner_predict_types[[self$task_type]])[[1]]
else
predict_type
},
set_predict_type = function(predict_type) {
# recursively walk backwards through the graph
set_po_predict_type = function(x, predict_type) {
assert_subset(predict_type, unlist(mlr_reflections$learner_predict_types[[self$task_type]]))
if (!is.null(x$predict_type)) x$predict_type = predict_type
prdcssrs = self$graph$edges[dst_id == x$id, ]$src_id
if (length(prdcssrs)) {
map(self$graph$pipeops[prdcssrs], set_po_predict_type, predict_type = predict_type)
}
}
set_po_predict_type(self$graph$pipeops[[self$graph$rhs]], predict_type)
}
)
)
Expand Down Expand Up @@ -720,9 +705,9 @@ andpaste = function(x, sep = ", ", lastsep = ", and ") {
paste0(paste(first(x, -1), collapse = sep), lastsep, last(x))
}

graph_base_learner = function(graph, resolve_branching = TRUE) {
graph_base_learner = function(graph, resolve_branching = TRUE, lookup_field = "learner_model") {
# GraphLearner$base_learner(), where return_all is TRUE, return_po is TRUE, and recursive is 1.
# We are looking for all PipeOps with a `$learner_model` field, possibly resolving branching.
# We are looking for all PipeOps with the non-NULL field named `lookup_field`, typically "learner_model", possibly resolving branching.

gm_output = graph$output
if (nrow(gm_output) != 1) {
Expand All @@ -744,8 +729,8 @@ graph_base_learner = function(graph, resolve_branching = TRUE) {
last_pipeop = graph$pipeops[[current_pipeop]]
if (get0(current_pipeop, pipeops_visited, ifnotfound = FALSE)) return(list())
assign(current_pipeop, TRUE, pipeops_visited)
learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
if (!is.null(learner_model)) return(list(last_pipeop))
field_content = get0(lookup_field, last_pipeop, ifnotfound = NULL)
if (!is.null(field_content)) return(list(last_pipeop))
next_pipeop = graph$edges[dst_id == current_pipeop, src_id]
if (length(next_pipeop) > 1) {
# more than one predecessor
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,29 @@ test_that("graphlearner type inference", {
g$train(tsk("boston_housing"))
expect_equal(as_learner(g)$task_type, "regr") # should not fail because of multiplicity in the graph

g_branch = ppl("branch", list(rpart = lrn("classif.rpart"), debug = lrn("classif.debug")))

l_branch = as_learner(g_branch)

expect_equal(l_branch$predict_type, "response")

l_branch$graph$pipeops$classif.debug$predict_type = "prob"

expect_equal(l_branch$predict_type, "response")

l_branch$param_set$values$branch.selection = "debug"

expect_equal(l_branch$predict_type, "prob")

l_branch$predict_type = "prob"

expect_equal(l_branch$graph$pipeops$classif.debug$learner$predict_type, "prob")

expect_equal(l_branch$predict_type, "prob")

l_branch$param_set$values$branch.selection = "rpart"

expect_equal(l_branch$predict_type, "prob")
})

test_that("graphlearner type inference - branched", {
Expand Down

0 comments on commit 48305a4

Please sign in to comment.