Skip to content

Commit

Permalink
Fix passing named list to graph$train with varargs
Browse files Browse the repository at this point in the history
Closes #626
  • Loading branch information
mb706 committed Aug 17, 2024
1 parent ac825b3 commit a7f480e
Show file tree
Hide file tree
Showing 3 changed files with 211 additions and 22 deletions.
68 changes: 47 additions & 21 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -581,20 +581,35 @@ graph_reduce = function(self, input, fun, single_input) {

assert_flag(single_input)

graph_input = self$input
graph_input_unique = graph_input = self$input
graph_output = self$output

edges = copy(self$edges)

if (!single_input) assert_list(input, .var.name = "input when single_input is FALSE")

if (!is.null(names(input)) && !single_input) {
assert_names(names(input), subset.of = graph_input$name, .var.name = "input when it has names and single_input is FALSE")
}

# create virtual "__initial__" and "__terminal__" nodes with edges to inputs / outputs of graph.
# if we have `single_input == FALSE` and one(!) vararg channel, we widen the vararg input
# if we have `single_input == FALSE` and vararg channels, we widen the vararg input
# appropriately.
if (!single_input && length(assert_list(input, .var.name = "input when single_input is FALSE")) > nrow(graph_input) && "..." %in% graph_input$channel.name) {
if (sum("..." == graph_input$channel.name) != 1) {
stop("Ambiguous distribution of inputs to vararg channels.\nAssigning more than one input to vararg channels when there are multiple vararg inputs does not work.")
# At this point we are agnostic about whether zero inputs to vararg are possible. In case this
# ever makes sense, the following should still work. We therefore don't check whether the number of
# inputs differs from the number of channels -- theoretically, there could be two varargs, one
# getting two inputs, the other none.
if (!single_input && "..." %in% graph_input$channel.name) {
if (sum("..." == graph_input$channel.name) != 1 && is.null(names(input))) {
stop("Ambiguous distribution of inputs to vararg channels.\nAssigning more than one input to vararg channels when there are multiple vararg inputs does not work.\nYou can try using a named input list. Vararg elements must be named '<pipeopname>....' (with four dots).")
}
# repeat the "..." as often as necessary
repeats = ifelse(graph_input$channel.name == "...", length(input) - nrow(graph_input) + 1, 1)
if (is.null(names(input))) {
repeats = ifelse(graph_input$channel.name == "...", length(input) - nrow(graph_input) + 1, 1)
} else {
repeats = nafill(as.numeric(table(names(input))[graph_input$name]), fill = 0)
}

graph_input = graph_input[rep(graph_input$name, repeats), , on = "name"]
}

Expand All @@ -607,25 +622,36 @@ graph_reduce = function(self, input, fun, single_input) {
# add new column to store content that is sent along an edge
edges$payload = list()

if (!single_input) {
# we need the input list length to be equal to the number of channels. This number was
# already increased appropriately if there is a single vararg channel.
assert_list(input, len = nrow(graph_input), .var.name = sprintf("input when single_input is FALSE and there are %s input channels", nrow(graph_input)))
if (single_input) {
edges[get("src_id") == "__initial__", "payload" := list(list(input))]
} else if (!is.null(names(input))) {
# input can be a named list (will be distributed to respective edges) or unnamed.
# if it is named, we check that names are unambiguous.
if (!is.null(names(input))) {
if (anyDuplicated(graph_input$name)) {
# FIXME this will unfortunately trigger if there is more than one named input for a vararg channel.
stopf("'input' must not be a named list because Graph %s input channels have duplicated names.", self$id)
}
assert_names(names(input), subset.of = graph_input$name, .var.name = sprintf("input when it has names and single_input is FALSE"))
edges[list("__initial__", names(input)), "payload" := list(input), on = c("src_id", "src_channel")]
} else {
# don't rely on unique graph_input$name!
edges[get("src_id") == "__initial__", "payload" := list(input)]

# don't use graph_input in the following, since rows with varargs are potentially duplicated.
# Also don't use innames_novararg, since vararg channels could be duplicately named regardless.
if (anyDuplicated(graph_input_unique$name)) {
stopf("'input' must not be a named list because Graph input channels have duplicated names: %s",
paste0(unique(innames_novararg[duplicated(innames_novararg)]), collapse = ", "))
}

innames_novararg = graph_input$name[graph_input$channel.name != "..."]
input_novararg = input[names(input) %in% innames_novararg]
assert_names(names(input_novararg), type = "unique", .var.name = "input that does not refer to vararg input channels (when input has names and single_input is FALSE)")

edges[list("__initial__", names(input_novararg)), "payload" := list(input_novararg), on = c("src_id", "src_channel")]

input_vararg = input[!names(input) %in% innames_novararg]
input_vararg = split(input_vararg, names(input_vararg))
for (vararg_channel in names(input_vararg)) {
edges[list("__initial__", vararg_channel), "payload" := list(input_vararg[[vararg_channel]]), on = c("src_id", "src_channel")]
}
} else {
edges[get("src_id") == "__initial__", "payload" := list(list(input))]
# we need the input list length to be equal to the number of channels. This number was
# already increased appropriately if there is a single vararg channel.
assert_list(input, len = nrow(graph_input), .var.name = sprintf("input when single_input is FALSE and there are %s input channels", nrow(graph_input)))
# don't rely on unique graph_input$name!
edges[get("src_id") == "__initial__", "payload" := list(input)]
}

# get the topo-sorted pipeop ids
Expand Down
5 changes: 4 additions & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@
#' as well as a `$internal_valid_scores` field, which allows to access the internal validation scores after training.
#' * `"internal_tuning"`: the `PipeOp` is able to internally optimize hyperparameters.
#' This works analogously to the internal tuning implementation for [`mlr3::Learner`].
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values` and have at least one
#' `PipeOp`s with that property also implement the standardized accessor `$internal_tuned_values` and have at least one
#' parameter tagged with `"internal_tuning"`.
#' An example for such a `PipeOp` is a `PipeOpLearner` that wraps a `Learner` with the `"internal_tuning"` property.
#'
Expand Down Expand Up @@ -562,6 +562,9 @@ multiplicity_type_nesting_level = function(str, varname) {
# @return `list`
unpack_multiplicities = function(input, expected_nesting_level, inputnames, poid) {
assert_list(input)
# in case of varargs, there could be more (or fewer) 'input's than 'expected_nesting_level's,
# so we have to make sure the positions match.
expected_nesting_level = expected_nesting_level[match(names(input), inputnames)]
unpacking = mapply(multiplicity_nests_deeper_than, input, expected_nesting_level)
if (!any(unpacking)) {
return(NULL) # no unpacking
Expand Down
160 changes: 160 additions & 0 deletions tests/testthat/test_Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,8 @@ test_that("edges that introduce loops cannot be added", {
expect_error(g$add_edge("p2", "p1", 1, 1), "Cycle detected")

expect_deep_clone(g, gclone) # check that edges did not change


})


Expand Down Expand Up @@ -506,3 +508,161 @@ test_that("Same output into multiple channels does not cause a bug", {
expect_true(res$po3.output1 == 2)
expect_true(res$po4.output1 == 2)
})

test_that("Graph with ambiguously named input", {
PipeOpDebugInname = R6Class(
inherit = PipeOp,
public = list(
initialize = function(id = "debug", inname = "input", param_vals = list()) {
output = data.table(
name = "output",
train = "numeric",
predict = "numeric"
)
input = data.table(
name = inname,
train = "numeric",
predict = "numeric"
)
super$initialize(
id = id,
param_set = ps(),
param_vals = param_vals,
input = input,
output = output
)
}
),
private = list(
.train = function(inputs) {
list(inputs[[1]] * 10)
}
)
)

gr = Graph$new()$
add_pipeop(PipeOpDebugInname$new(id = "a", inname = "b.c"))$
add_pipeop(PipeOpDebugInname$new(id = "a.b", inname = "c"))

expect_equal(
gr$train(list(1, 2), single_input = FALSE),
list(a.output = 10, a.b.output = 20)
)

expect_error(gr$train(list(a.b.c = 1, a.b.c = 2), single_input = FALSE),
"duplicated names: a.b.c")

})

test_that("Graph with vararg input", {
PipeOpDebugVararg = R6Class(
inherit = PipeOp,
public = list(
initialize = function(id = "debugvararg", param_vals = list()) {
output = data.table(
name = "output",
train = "numeric",
predict = "numeric"
)
input = data.table(
name = c("input", "..."),
train = c("numeric", "numeric"),
predict = c("numeric", "numeric")
)
super$initialize(
id = id,
param_set = ps(),
param_vals = param_vals,
input = input,
output = output
)
}
),
private = list(
.train = function(inputs) {
x = 1000 * inputs$input + sum(unlist(inputs[names(inputs) == "..."]))
list(output = x)
}
)
)

g1 <- Graph$new()$
add_pipeop(PipeOpDebugVararg$new())

expect_equal(g1$train(list(1, 2, 3), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 2 + 3))

expect_equal(g1$train(list(debugvararg.input = 1, debugvararg.... = 2), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 2))

expect_equal(g1$train(list(debugvararg.... = 2, debugvararg.input = 1), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 2))

expect_equal(g1$train(list(debugvararg.... = 2, debugvararg.input = 1, debugvararg.... = 3), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 2 + 3))

expect_error(g1$train(list(debugvararg.input = 1, debugvararg.... = 2, debugvararg.input = 3), single_input = FALSE),
"input that does not refer to vararg.*unique names")

g2 <- Graph$new()$
add_pipeop(PipeOpDebugVararg$new())$
add_pipeop(PipeOpDebugVararg$new(id = "debugvararg2"))

expect_error(g2$train(list(1, 2, 3, 4, 5), single_input = FALSE),
"more than one input")

# throw the error even when number of inputs == number of args given
expect_error(g2$train(list(1, 2, 3, 4), single_input = FALSE),
"more than one input")

expect_equal(g2$train(1), list(debugvararg.output = 1001, debugvararg2.output = 1001))

expect_equal(
g2$train(list(debugvararg.input = 1, debugvararg.... = 2, debugvararg2.input = 3, debugvararg2.... = 4), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 2, debugvararg2.output = 1000 * 3 + 4)
)

#multi-assignment to multi-vararg
expect_equal(
g2$train(list(debugvararg.... = 100, debugvararg2.... = 10000, debugvararg.input = 1, debugvararg2.... = 200, debugvararg.... = 2, debugvararg2.input = 3, debugvararg2.... = 4), single_input = FALSE),
list(debugvararg.output = 1000 * 1 + 100 + 2, debugvararg2.output = 10000 +1000 * 3 + 200 + 4)
)

expect_equal(
g2$train(list(debugvararg2.... = 10000, debugvararg.input = 1, debugvararg2.... = 200, debugvararg2.input = 3, debugvararg2.... = 4), single_input = FALSE),
list(debugvararg.output = 1000 * 1, debugvararg2.output = 10000 +1000 * 3 + 200 + 4)
)

g3 <- Graph$new()$
add_pipeop(PipeOpDebugVararg$new())$
add_pipeop(PipeOpDebugVararg$new(id = "debugvararg2"))$
add_pipeop(PipeOpCopy$new(3))$
add_edge("copy", "debugvararg", "output1", "...")$
add_edge("copy", "debugvararg2", "output1", "...")

expect_equal(g3$train(list(1, 2, 3), single_input = FALSE),
list(debugvararg.output = 1003, debugvararg2.output = 2003, copy.output2 = 3, copy.output3 = 3))

g3 <- Graph$new()$
add_pipeop(PipeOpDebugVararg$new())$
add_pipeop(PipeOpDebugVararg$new(id = "debugvararg2"))$
add_pipeop(PipeOpCopy$new(3))$
add_edge("copy", "debugvararg", "output1", "...")$
add_edge("copy", "debugvararg2", "output2", "...")

expect_equal(g3$train(list(1, 2, 3), single_input = FALSE),
list(debugvararg.output = 1003, debugvararg2.output = 2003, copy.output3 = 3))

g3 <- Graph$new()$
add_pipeop(PipeOpDebugVararg$new())$
add_pipeop(PipeOpDebugVararg$new(id = "debugvararg2"))$
add_pipeop(PipeOpCopy$new(3))$
add_edge("copy", "debugvararg", "output1", "...")$
add_edge("copy", "debugvararg", "output2", "...")$
add_edge("copy", "debugvararg2", "output1", "...")$
add_edge("copy", "debugvararg2", "output3", "...")

expect_equal(g3$train(list(1, 2, 3), single_input = FALSE),
list(debugvararg.output = 1006, debugvararg2.output = 2006))

})

0 comments on commit a7f480e

Please sign in to comment.