From 511f8467a9732a38ec03648db113d03bd7b19cc6 Mon Sep 17 00:00:00 2001 From: mb706 Date: Fri, 23 Aug 2024 14:54:54 +0200 Subject: [PATCH] handle unbranch with vararg input --- R/GraphLearner.R | 11 +++++++---- tests/testthat/test_GraphLearner.R | 20 +++++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/R/GraphLearner.R b/R/GraphLearner.R index 64f69702b..b09c0ed0b 100644 --- a/R/GraphLearner.R +++ b/R/GraphLearner.R @@ -607,7 +607,10 @@ infer_task_type = function(graph) { c(task_type, "classif")[[1]] # "classif" as final fallback } - +# for a PipeOpUnbranch, search for its predecessor PipeOp that is currently "active", +# i.e. that gets non-NOP-input in the current hyperparameter configuration of PipeOpBranch ops. +# Returns a list, named by PipeOpUnbranch IDs, containing the incoming PipeOp IDs. +# PipeOpBranch ops that are connected to overall Graph input get an empty string as predecessor ID. get_po_unbranch_active_input = function(graph) { # query a given PipeOpBranch what its selected output is # Currently, PipeOpBranch 'selection' can be either integer-valued or a string. @@ -677,7 +680,7 @@ get_po_unbranch_active_input = function(graph) { # we have already checked that this is unique. state_current = TRUE reason_current = inedges$reason[inedges$state] - po_unbranch_active_input[[pipeop_id]] = inedges$dst_channel[inedges$state] + po_unbranch_active_input[[pipeop_id]] = inedges$src_id[inedges$state] } else { # all inputs are in agreement state_current = any(inedges$state) @@ -737,8 +740,8 @@ graph_base_learner = function(graph, resolve_branching = TRUE, lookup_field = "l if (!inherits(last_pipeop, "PipeOpUnbranch") || !resolve_branching) { return(unique(unlist(lapply(next_pipeop, search_base_learner_pipeops), recursive = FALSE, use.names = FALSE))) } - current_active_input = po_unbranch_active_input[[current_pipeop]] - next_pipeop = graph$edges[dst_id == current_pipeop & dst_channel == current_active_input, src_id] + next_pipeop = po_unbranch_active_input[[current_pipeop]] + if (next_pipeop == "") next_pipeop = character(0) } if (length(next_pipeop) == 0) return(list()) current_pipeop = next_pipeop diff --git a/tests/testthat/test_GraphLearner.R b/tests/testthat/test_GraphLearner.R index 60b41d7b5..f992aefbd 100644 --- a/tests/testthat/test_GraphLearner.R +++ b/tests/testthat/test_GraphLearner.R @@ -572,16 +572,34 @@ test_that("base_learner() works", { expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model) branching_learner$param_set$values$branch.selection = "classif.debug" expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model) - + # branching_learner = as_learner(ppl("branch", pos(c("pca", "ica")), prefix_branchops = "brunch") %>>% ppl("branch", lrns(c("classif.rpart", "classif.debug")))) expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model) branching_learner$param_set$values$branch.selection = "classif.debug" expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model) + # with '...' inputs in unbranch + branching_learner = as_learner(po("branch", c("classif.rpart", "classif.debug")) %>>% lrns(c("classif.rpart", "classif.debug")) %>>% po("unbranch")) + expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model) + branching_learner$param_set$values$branch.selection = "classif.debug" + expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model) + # + branching_learner = as_learner(po("branch_1", c("pca", "ica")) %>>% pos(c("pca", "ica")) %>>% po("unbranch_1") %>>% + po("branch", c("classif.rpart", "classif.debug")) %>>% lrns(c("classif.rpart", "classif.debug")) %>>% po("unbranch")) + expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.rpart$learner_model) + branching_learner$param_set$values$branch.selection = "classif.debug" + expect_identical(branching_learner$base_learner(), branching_learner$graph_model$pipeops$classif.debug$learner_model) + + # unbranch with single input, without corresponding PipeOpBranch, is legal x = as_learner(po("pca") %>>% lrn("classif.rpart") %>>% po("unbranch", 1)) expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model) + # unbranch with '...' input + x = as_learner(po("pca") %>>% lrn("classif.rpart") %>>% po("unbranch")) + expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model) + + # ParamInt selection parameter x = as_learner(ppl("branch", list(lrn("classif.rpart") %>>% po("unbranch", 1, id = "poub1"), lrn("classif.debug") %>>% po("unbranch", 1, id = "poub2")))) expect_identical(x$base_learner(), x$graph_model$pipeops$classif.rpart$learner_model)