Skip to content

Commit

Permalink
base_learner without GraphLearner
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Aug 23, 2024
1 parent ad2c248 commit c286530
Showing 1 changed file with 56 additions and 55 deletions.
111 changes: 56 additions & 55 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
graph$state = NULL

id = assert_string(id, null.ok = TRUE) %??% paste(graph$ids(sorted = TRUE), collapse = ".")
self$id = id # init early so 'infer_task_type()' can use it in error messages
self$id = id # init early so 'base_learner()' can use it in error messages
private$.graph = graph

output = graph$output
Expand All @@ -179,7 +179,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}

if (is.null(task_type)) {
task_type = infer_task_type(self, graph)
task_type = infer_task_type(graph)
}
assert_subset(task_type, mlr_reflections$task_types$type)

Expand Down Expand Up @@ -219,70 +219,28 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (recursive <= 0) return(if (return_all) list(self) else self)
if (return_all && recursive > 1) stop("recursive must be <= 1 if return_all is TRUE")

# graph_base_learner() corresponds to base_learner(recursive = 1, return_po = TRUE, return_all = TRUE)
result = graph_base_learner(self$graph_model, resolve_branching = resolve_branching)

if (!return_all) {
candidates = self$base_learner(recursive = 1, return_po = TRUE, return_all = TRUE, resolve_branching = resolve_branching)
if (length(candidates) < 1) stopf("No base learner found in Graph %s.", self$id)
if (length(candidates) > 1) stopf("Graph %s has no unique PipeOp containing a Learner.", self$id)
if (length(result) < 1) stopf("No base learner found in Graph %s.", self$id)
if (length(result) > 1) stopf("Graph %s has no unique PipeOp containing a Learner.", self$id)
if (!return_po) {
result = multiplicity_flatten(candidates[[1]]$learner_model)
result = multiplicity_flatten(result[[1]]$learner_model)
if (length(result) != 1) {
# if learner_model is not a Multiplicity, multiplicity_flatten will return a list of length 1
stopf("Graph %s's base learner is a Multiplicity that does not contain exactly one Learner.", self$id)
}
return(result[[1]]$base_learner(recursive - 1))
} else {
return(candidates[[1]])
return(result[[1]])
}
}

# if we are here, return_all is TRUE, and recursive is therefore 1.
if (!return_po) {
result = self$base_learner(recursive = 1, return_po = TRUE, return_all = TRUE, resolve_branching = resolve_branching)
return(lapply(result, function(x) x$learner_model))
result = map(result, "learner_model")
}

# If we are here, return_all is TRUE, return_po is TRUE, recursive is 1.
# We are looking for all PipeOps with a `$learner_model` field, possibly resolving branching.

gm = self$graph_model

gm_output = gm$output
if (nrow(gm_output) != 1) {
# should never happen, since we checked this in initialize(), but theoretically the user could have changed the graph by-reference
stop("Graph has no unique output.")
}
last_pipeop_id = gm_output$op.id

# pacify static checks
src_id = NULL
dst_id = NULL
src_channel = NULL
dst_channel = NULL
delayedAssign("po_unbranch_active_input", get_po_unbranch_active_input(gm)) # only call get_pobranch_active_output() if we encounter a PipeOpUnbranch

pipeops_visited = new.env(parent = emptyenv())
search_base_learner_pipeops = function(current_pipeop) {
repeat {
last_pipeop = gm$pipeops[[current_pipeop]]
if (get0(current_pipeop, pipeops_visited, ifnotfound = FALSE)) return(list())
assign(current_pipeop, TRUE, pipeops_visited)
learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
if (!is.null(learner_model)) return(list(last_pipeop))
next_pipeop = gm$edges[dst_id == current_pipeop, src_id]
if (length(next_pipeop) > 1) {
# more than one predecessor
if (!inherits(last_pipeop, "PipeOpUnbranch") || !resolve_branching) {
return(unique(unlist(lapply(next_pipeop, search_base_learner_pipeops), recursive = FALSE, use.names = FALSE)))
}
current_active_input = po_unbranch_active_input[[current_pipeop]]
next_pipeop = gm$edges[dst_id == current_pipeop & dst_channel == current_active_input, src_id]
}
if (length(next_pipeop) == 0) return(list())
current_pipeop = next_pipeop
}
}

unique(search_base_learner_pipeops(last_pipeop_id))
result
},
marshal = function(...) {
learner_marshal(.learner = self, ...)
Expand Down Expand Up @@ -609,7 +567,7 @@ as_learner.PipeOp = function(x, clone = FALSE, ...) {
}


infer_task_type = function(self, graph) {
infer_task_type = function(graph) {
output = graph$output
# check the high level input and output
class_table = mlr_reflections$task_types
Expand Down Expand Up @@ -649,7 +607,7 @@ infer_task_type = function(self, graph) {
}
if (length(task_type) != 1L) {
# We could not infer type from any PipeOp output channels, so we try to infer it from the base learners
baselearners = self$base_learner(recursive = 1, return_all = TRUE, resolve_branching = FALSE)
baselearners = map(graph_base_learner(graph, resolve_branching = FALSE), "learner_model")
task_type = unique(unlist(map(baselearners, function(x) {
# Currently we should not have Multiplicities here, since Graph gets NULLed explicitly upon construction.
# If we ever allow initializing a Learner with a trained Graph, the following will be necessary.
Expand Down Expand Up @@ -761,3 +719,46 @@ andpaste = function(x, sep = ", ", lastsep = ", and ") {
if (length(x) == 1) return(x[[1]])
paste0(paste(first(x, -1), collapse = sep), lastsep, last(x))
}

graph_base_learner = function(graph, resolve_branching = TRUE) {
# GraphLearner$base_learner(), where return_all is TRUE, return_po is TRUE, and recursive is 1.
# We are looking for all PipeOps with a `$learner_model` field, possibly resolving branching.

gm_output = graph$output
if (nrow(gm_output) != 1) {
# should never happen, since we checked this in initialize(), but theoretically the user could have changed the graph by-reference
stop("Graph has no unique output.")
}
last_pipeop_id = gm_output$op.id

# pacify static checks
src_id = NULL
dst_id = NULL
src_channel = NULL
dst_channel = NULL
delayedAssign("po_unbranch_active_input", get_po_unbranch_active_input(graph)) # only call get_pobranch_active_output() if we encounter a PipeOpUnbranch

pipeops_visited = new.env(parent = emptyenv())
search_base_learner_pipeops = function(current_pipeop) {
repeat {
last_pipeop = graph$pipeops[[current_pipeop]]
if (get0(current_pipeop, pipeops_visited, ifnotfound = FALSE)) return(list())
assign(current_pipeop, TRUE, pipeops_visited)
learner_model = if ("learner_model" %in% names(last_pipeop)) last_pipeop$learner_model
if (!is.null(learner_model)) return(list(last_pipeop))
next_pipeop = graph$edges[dst_id == current_pipeop, src_id]
if (length(next_pipeop) > 1) {
# more than one predecessor
if (!inherits(last_pipeop, "PipeOpUnbranch") || !resolve_branching) {
return(unique(unlist(lapply(next_pipeop, search_base_learner_pipeops), recursive = FALSE, use.names = FALSE)))
}
current_active_input = po_unbranch_active_input[[current_pipeop]]
next_pipeop = graph$edges[dst_id == current_pipeop & dst_channel == current_active_input, src_id]
}
if (length(next_pipeop) == 0) return(list())
current_pipeop = next_pipeop
}
}

unique(search_base_learner_pipeops(last_pipeop_id))
}

0 comments on commit c286530

Please sign in to comment.