Skip to content

Commit

Permalink
Merge pull request #808 from mlr-org/baselearner
Browse files Browse the repository at this point in the history
more sophisticated base_learner
  • Loading branch information
mb706 authored Aug 23, 2024
2 parents ee831de + 4ff1902 commit 55cce01
Show file tree
Hide file tree
Showing 13 changed files with 1,188 additions and 85 deletions.
4 changes: 2 additions & 2 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ Graph = R6Class("Graph",
}
if (is.null(dst_channel)) {
if (length(self$pipeops[[dst_id]]$input$name) > 1) {
stopf("dst_channel must not be NULL if src_id pipeop has more than one input channel.")
stopf("dst_channel must not be NULL if dst_id pipeop has more than one input channel.")
}
dst_channel = 1L
}
Expand Down Expand Up @@ -435,7 +435,7 @@ Graph = R6Class("Graph",
set_names = function(old, new) {
ids = names2(self$pipeops)
assert_subset(old, ids)
assert_character(new, any.missing = FALSE)
assert_character(new, any.missing = FALSE, min.chars = 1)
new_ids = map_values(ids, old, new)
names(self$pipeops) = new_ids
imap(self$pipeops, function(x, nn) x$id = nn)
Expand Down
387 changes: 323 additions & 64 deletions R/GraphLearner.R

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ PipeOp = R6Class("PipeOp",
lapply(param_set, function(x) assert_param_set(eval(x)))
private$.param_set_source = param_set
}
self$id = assert_string(id)
self$id = assert_string(id, min.chars = 1)

self$properties = assert_subset(properties, mlr_reflections$pipeops$properties)
self$param_set$values = insert_named(self$param_set$values, param_vals)
Expand Down
1 change: 1 addition & 0 deletions R/PipeOpNMF.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@
#'
#' pop$state
#' }
#' \dontshow{ try(rm("format.list", envir = .BaseNamespaceEnv$.__S3MethodsTable__.), silent = TRUE) # BiocGenerics overwrites printer for our tables mlr-org/mlr3#1112 }
#' \dontshow{ \} }
#' \dontshow{ \} }
#' @family PipeOps
Expand Down
31 changes: 20 additions & 11 deletions R/PipeOpThreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,11 @@
#' at level `0.5`.
#'
#' @section Fields:
#' Only fields inherited from [`PipeOp`].
#' Fields inherited from [`PipeOp`], as well as:
#' * `predict_type` :: `character(1)`\cr
#' Type of prediction to return. Either `"prob"` (default) or `"response"`.
#' Setting to `"response"` should rarely be used; it may potentially save some memory but has
#' no other benefits.
#'
#' @section Methods:
#' Only methods inherited from [`PipeOp`].
Expand Down Expand Up @@ -67,7 +71,17 @@ PipeOpThreshold = R6Class("PipeOpThreshold",
tags = "target transform")
}
),
active = list(
predict_type = function(rhs) {
if (!missing(rhs)) {
assert_choice(rhs, c("prob", "response"))
private$.predict_type = rhs
}
private$.predict_type
}
),
private = list(
.predict_type = "prob",
.train = function(inputs) {
self$state = list()
list(NULL)
Expand All @@ -86,17 +100,12 @@ PipeOpThreshold = R6Class("PipeOpThreshold",
}
}

list(prd$set_threshold(thr))
}
),
active = list(
predict_type = function(val) {
if (!missing(val)) {
if (!identical(val, private$.learner)) {
stop("$predict_type for PipeOpThreshold is read-only.")
}
prd$set_threshold(thr)
if (self$predict_type == "response") {
prd$predict_types = "response"
prd$data$prob = NULL
}
return("response")
list(prd)
}
)
)
Expand Down
9 changes: 9 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,15 @@ multiplicity_recurse = function(.multip, .fun, ...) {
}
}

multiplicity_flatten = function(.multip) {
  # Collects the leaves of a (possibly nested) Multiplicity into one flat, unnamed list.
  # A value that is not a Multiplicity is treated as a single leaf and returned as list(.multip).
  if (is.Multiplicity(.multip)) {
    # flatten every child first, then splice the resulting per-child lists together by one level
    flattened_children = map(.multip, multiplicity_flatten)
    unlist(flattened_children, recursive = FALSE, use.names = FALSE)
  } else {
    list(.multip)
  }
}

# replace when new mlr3misc version is released https://github.com/mlr-org/mlr3misc/pull/80
dictionary_sugar_inc_get = function(dict, .key, ...) {
newkey = gsub("_\\d+$", "", .key)
Expand Down
141 changes: 141 additions & 0 deletions attic/branchfuns.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@


discover_active_branch = function(unbranch_pipeop) {
# take a part of the graph$edges and find the corresponding PipeOpBranch (if any) and its output channel that corresponds to it.
# I.e. if we have a PipeOp 'A' with input channel 'x', there may be edges that connect here, the 'edges' could then look like
# src_id dst_id src_channel dst_channel
# B A output1 x
# C A output1 x
# C A output2 x
# Suppose that B is connected to the 'alpha' output of a PipeOpBranch with ID 'greek',
# and C is connected to the 'aleph' output of a PipeOpBranch with ID 'hebrew'.
# Then the function would return a data.table(id = c('greek', 'hebrew'), channel = c('alpha', 'aleph')).
# This function iterates through non-branch PipeOps with a single input and output, and uses recursion if it encounters
# PipeOps with multiple inputs.
# Note that the function
# (1) does not visit PipeOps that are exclusively ancestors of PipeOpBranch PipeOps, and
# (2) returns as many rows as there are possible paths to any PipeOpBranch, so the return value is not one-to-one with input rows.
# If a channel is connected to the Graph input without an intermediate PipeOpBranch, the return data.table will contain a row with NAs.
detect_corresponding_branch_output = function(edges) {
  # Walks backwards from the sources of `edges` and collects, for every path, the first
  # PipeOpBranch output that is encountered.
  #
  # edges: rows of the edge table; the columns `src_id` and `src_channel` are read here.
  # Returns the unique rows of a data.table with columns `id` and `channel`. A path that ends
  # without any incoming edge (i.e. at the Graph input) contributes a row of NAs.
  #
  # NOTE(review): `gm` is not defined in the visible code; presumably it is the Graph that is
  # being inspected -- confirm where it is supposed to come from.
  graphinput = data.table(id = NA_character_, channel = NA_character_)
  # IDs of the PipeOps that were already walked through; shared between all recursion levels
  # so that no PipeOp is processed more than once.
  pipeops_visited = new.env(parent = emptyenv())
  detect_corresponding_branch_output_inner = function(edges) {
    if (nrow(edges) == 0) return(list(graphinput))
    result = list()
    for (edge_i in seq_len(nrow(edges))) {
      # `current_edge` is always the single edge we are currently following backwards
      current_edge = edges[edge_i, ]
      repeat {
        inpipeop_id = current_edge$src_id[[1]]
        if (get0(inpipeop_id, pipeops_visited, ifnotfound = FALSE)) break
        assign(inpipeop_id, TRUE, pipeops_visited)
        inpipeop = gm$pipeops[[inpipeop_id]]
        if (inherits(inpipeop, "PipeOpBranch")) {
          # we found a PipeOpBranch: record the branch ID and the output channel through which we reached it.
          # This must be taken from the edge currently being followed, not from the edge we started with,
          # since we may have walked back through intermediate PipeOps.
          result[[length(result) + 1]] = list(current_edge[, .(id = src_id, channel = src_channel)])
          break
        }
        # not a PipeOpBranch: look at all edges that go into the PipeOp currently being processed
        edges_prior = gm$edges[dst_id == inpipeop_id, ]
        if (nrow(edges_prior) != 1) {
          # unless we have exactly one predecessor, we recurse:
          # this handles the empty case (connected to graph input) as well as the pipeop-with-multiple-inputs case
          result[[length(result) + 1]] = detect_corresponding_branch_output_inner(edges_prior)
          break
        }
        # exactly one predecessor: keep following that edge
        current_edge = edges_prior
      }
    }
    # every entry of `result` is itself a list of data.tables; splice them into one flat list
    unlist(result, recursive = FALSE, use.names = FALSE)
  }
  unique(rbindlist(detect_corresponding_branch_output_inner(edges)))
}
inbranches = gm$pipeops[[unbranch_pipeop]]$input$name
inroutes = map_dtr(inbranches, function(inchannel) {
edges = gm$edges[dst_id == unbranch_pipeop & dst_channel == inchannel, ]
detect_corresponding_branch_output(edges)[, unbranchchannel := inchannel]
})
# inroutes:
# data.table with columns 'id', 'channel', and 'unbranchchannel'.
# For each input-channel 'unbranchchannel' of the pipeop under investigation, it lists the
# id(s) and output channel(s) of PipeOpBranches that connect there.

# add column 'live': is the current route selected?
inroutes[, live := ifelse(is.na(channel), TRUE, channel == get_pobranch_active_output(id)), by = "id"]
inroutes_livestats = inroutes[, .(any_live = any(live), all_live = all(live)), by = "unbranchchannel"]
if (any(inroutes_livestats$any_live != inroutes_livestats$all_live)) {
first_inconsistency = inroutes_livestats[which(any_live != all_live)[1L], unbranchchannel]
stopf("Inconsistent selection of PipeOpBranch outputs:\nPipeOp outputs %s are not selected, but conflict with %s",
paste(inroutes[!live & unbranchchannel == first_inconsistency, unique(sprintf("'%s.%s'", id, channel))], collapse = ", "),
inroutes[live & unbranchchannel == first_inconsistency, if (any(is.na(id))) "direct Graph input, which is always selected." else sprintf("selected output '%s.%s'.", id[[1]], channel[[1]])]
)
}
inroutes = inroutes[inroutes[any_live == TRUE, unbranchchannel], on = "unbranchchannel"]
if (length(inroutes) == 1) {
return(inroutes$unbranchchannel)
}










# Draft of a variant of the branch detection that additionally carries a `live` flag and the
# `last_pobranch` label per row (see the columns of `graphinput` below).
# NOTE(review): this function looks unfinished -- see the individual notes below before using it.
# NOTE(review): `gm` and `get_pobranch_active_output()` are not defined in the visible code; confirm where they come from.
detect_corresponding_branch_state = function(edges) {
# row that is returned when there are no incoming edges
graphinput = data.table(id = NA_character_, channel = NA_character_, live = TRUE, last_pobranch = NA_character_)
pipeops_visited = new.env(parent = emptyenv())
# cache: PipeOp ID -> value returned by get_pobranch_active_output(), filled lazily below
pobranch_active = new.env(parent = emptyenv())
# Returns `edgeinfo` invisibly if its `live` entries are either all TRUE or all FALSE; otherwise raises an error.
assert_consistent_selection = function(edgeinfo) {
if (all(edgeinfo$live) == any(edgeinfo$live)) return(invisible(edgeinfo))
# NOTE(review): `inroutes` is not defined inside this function; presumably `edgeinfo` was intended -- confirm.
stopf("Inconsistent selection of PipeOpBranch outputs:\nPipeOp outputs %s are not selected, but conflict with %s",
edgeinfo[!live, unique(last_pobranch)],
inroutes[live, if (any(is.na(last_pobranch))) "direct Graph input, which is always selected." else sprintf("selected output '%s'.", last_pobranch[[1]])]
)
}
detect_corresponding_branch_state_inner = function(edges) {
if (nrow(edges) == 0) return(list(graphinput))
result = list()
# NOTE(review): `edge` (here and below) is not defined in this function; the parameter is called `edges` -- confirm.
for (edge_i in seq_along(edge$src_id)) {
inpipeop_id = edge$src_id[[edge_i]]
repeat {
# NOTE(review): the assignment below is commented out, so nothing is ever written to `pipeops_visited` in this function.
if (get0(inpipeop_id, pipeops_visited, ifnotfound = FALSE)) break
# assign(inpipeop_id, TRUE, pipeops_visited)
# look up the cached active output; NA means "not cached yet"
inpipeop_active_output = get0(inpipeop_id, pobranch_active, ifnotfound = NA_character_)
if (is.na(inpipeop_active_output)) {
inpipeop = gm$pipeops[[inpipeop_id]]
if (inherits(inpipeop, "PipeOpBranch")) {
inpipeop_active_output = get_pobranch_active_output(inpipeop)
assign(inpipeop_id, inpipeop_active_output, pobranch_active)
}
}
if (!is.na(inpipeop_active_output)) {
# NOTE(review): `current_result` is assigned in both branches but never read in the visible code.
if (inpipeop_active_output == edge$src_channel[[edge_i]]) {
prior_state = detect_corresponding_branch_state_inner(gm$edges[dst_id == inpipeop_id, ])
assert_consistent_selection(prior_state)
current_result = list(edge[edge_i, .(id = src_id, channel = src_channel, live = any(prior_state$live), last_pobranch = sprintf("%s.%s", inpipeop_id, inpipeop_active_output))])
} else {
current_result = list(edge[edge_i, .(id = src_id, channel = src_channel, live = FALSE, last_pobranch = inpipeop_active_output)])
}

}
# NOTE(review): the following bare block has no condition and ends in `break`, so it runs on every
# iteration and the code after it inside the `repeat` is never reached.
{
# we found a PipeOpBranch: return the output channel that corresponds to the edge currently being processed
result[[length(result) + 1]] = list(edge[, .(id = src_id, channel = src_channel)])
break
}
# not a PipeOpBranch: look at all edges the go into the PipeOp currently being processed
edges_prior = gm$edges[dst_id == inpipeop_id, ]
if (nrow(edges_prior) != 1) {
# unless we have exactly one predecessor, we recurse:
# this handles the empty case (connected to graph input) as well as the pipeop-with-multiple-inputs case
# NOTE(review): `detect_corresponding_branch_output_inner` is not defined in this function's scope;
# presumably `detect_corresponding_branch_state_inner` was intended -- confirm.
result[[length(result) + 1]] = detect_corresponding_branch_output_inner(edges_prior)
break
}
inpipeop_id = edges_prior$src_id[[1]]
}
}
unlist(result, recursive = FALSE, use.names = FALSE)
}
# NOTE(review): same as above -- this calls `detect_corresponding_branch_output_inner`, which is not defined here.
unique(rbindlist(detect_corresponding_branch_output_inner(edges)))
}
65 changes: 62 additions & 3 deletions man/mlr_learners_graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/mlr_pipeops_nmf.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/mlr_pipeops_threshold.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ lg$set_threshold("warn")
options(warnPartialMatchArgs = TRUE)
options(warnPartialMatchAttr = TRUE)
options(warnPartialMatchDollar = TRUE)
options(mlr3.warn_deprecated = FALSE) # avoid triggers when expect_identical() accesses deprecated fields


# simulate packages that extend existing task type
Expand Down
Loading

0 comments on commit 55cce01

Please sign in to comment.