Skip to content

Commit

Permalink
handle unbranch with vararg input
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Aug 23, 2024
1 parent a01cddc commit 511f846
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 7 additions & 4 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,10 @@ infer_task_type = function(graph) {
c(task_type, "classif")[[1]] # "classif" as final fallback
}


# for a PipeOpUnbranch, search for its predecessor PipeOp that is currently "active",
# i.e. that gets non-NOP-input in the current hyperparameter configuration of PipeOpBranch ops.
# Returns a list, named by PipeOpUnbranch IDs, containing the incoming PipeOp IDs.
# PipeOpBranch ops that are connected to overall Graph input get an empty string as predecessor ID.
get_po_unbranch_active_input = function(graph) {
# query a given PipeOpBranch what its selected output is
# Currently, PipeOpBranch 'selection' can be either integer-valued or a string.
Expand Down Expand Up @@ -677,7 +680,7 @@ get_po_unbranch_active_input = function(graph) {
# we have already checked that this is unique.
state_current = TRUE
reason_current = inedges$reason[inedges$state]
po_unbranch_active_input[[pipeop_id]] = inedges$dst_channel[inedges$state]
po_unbranch_active_input[[pipeop_id]] = inedges$src_id[inedges$state]
} else {
# all inputs are in agreement
state_current = any(inedges$state)
Expand Down Expand Up @@ -737,8 +740,8 @@ graph_base_learner = function(graph, resolve_branching = TRUE, lookup_field = "l
if (!inherits(last_pipeop, "PipeOpUnbranch") || !resolve_branching) {
return(unique(unlist(lapply(next_pipeop, search_base_learner_pipeops), recursive = FALSE, use.names = FALSE)))
}
current_active_input = po_unbranch_active_input[[current_pipeop]]
next_pipeop = graph$edges[dst_id == current_pipeop & dst_channel == current_active_input, src_id]
next_pipeop = po_unbranch_active_input[[current_pipeop]]
if (next_pipeop == "") next_pipeop = character(0)
}
if (length(next_pipeop) == 0) return(list())
current_pipeop = next_pipeop
Expand Down
20 changes: 19 additions & 1 deletion tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -572,16 +572,34 @@ test_that("base_learner() works", {
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model)
branching_learner$param_set$values$branch.selection = "classif.debug"
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model)

#
branching_learner = as_learner(ppl("branch", pos(c("pca", "ica")), prefix_branchops = "brunch") %>>% ppl("branch", lrns(c("classif.rpart", "classif.debug"))))
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model)
branching_learner$param_set$values$branch.selection = "classif.debug"
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model)

# with '...' inputs in unbranch
branching_learner = as_learner(po("branch", c("classif.rpart", "classif.debug")) %>>% lrns(c("classif.rpart", "classif.debug")) %>>% po("unbranch"))
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model)
branching_learner$param_set$values$branch.selection = "classif.debug"
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model)
#
branching_learner = as_learner(po("branch_1", c("pca", "ica")) %>>% pos(c("pca", "ica")) %>>% po("unbranch_1") %>>%
po("branch", c("classif.rpart", "classif.debug")) %>>% lrns(c("classif.rpart", "classif.debug")) %>>% po("unbranch"))
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model)
branching_learner$param_set$values$branch.selection = "classif.debug"
expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model)


# unbranch with single input, without corresponding PipeOpBranch, is legal
x = as_learner(po("pca") %>>% lrn("classif.rpart") %>>% po("unbranch", 1))
expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model)

# unbranch with '...' input
x = as_learner(po("pca") %>>% lrn("classif.rpart") %>>% po("unbranch"))
expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model)


# ParamInt selection parameter
x = as_learner(ppl("branch", list(lrn("classif.rpart") %>>% po("unbranch", 1, id = "poub1"), lrn("classif.debug") %>>% po("unbranch", 1, id = "poub2"))))
expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model)
Expand Down

0 comments on commit 511f846

Please sign in to comment.