Skip to content

Commit

Permalink
Merge pull request #792 from mlr-org/add_crate
Browse files Browse the repository at this point in the history
Add crate to custom_checks
  • Loading branch information
mb706 authored Aug 14, 2024
2 parents 525c1b0 + e9d0702 commit 22a408c
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 53 deletions.
6 changes: 4 additions & 2 deletions R/LearnerAvg.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,10 @@ LearnerClassifAvg = R6Class("LearnerClassifAvg", inherit = LearnerClassif,
# Parameter set; every parameter here carries only the "train" tag.
ps = ps(
  measure = p_uty(custom_check = check_class_or_character("MeasureClassif", mlr_measures), tags = "train"),
  optimizer = p_uty(custom_check = check_optimizer, tags = "train"),
  # log_level is declared exactly once, with the check passed explicitly as
  # `custom_check` (not positionally). The check combines check_string() and
  # check_integerish() via %check||% and is wrapped in crate() with
  # .parent = topenv() so that it does not enclose the calling frame
  # (presumably mlr3misc::crate -- confirm).
  log_level = p_uty(
    custom_check = crate(function(x) check_string(x) %check||% check_integerish(x), .parent = topenv()),
    tags = "train"
  )
)
ps$values = list(measure = "classif.ce", optimizer = "nloptr", log_level = "warn")
super$initialize(
Expand Down
25 changes: 14 additions & 11 deletions R/PipeOpColRoles.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,20 @@ PipeOpColRoles = R6Class("PipeOpColRoles",
# Constructor of PipeOpColRoles.
#
# @param id identifier passed on to the parent constructor, default "colroles".
# @param param_vals list of initial parameter values, passed on to the parent
#   constructor.
initialize = function(id = "colroles", param_vals = list()) {
  ps = ps(
    # named list, each entry with a vector of roles
    # (declared exactly once; the check is wrapped in crate() with
    # .parent = topenv() so it does not enclose this constructor's frame)
    new_role = p_uty(
      tags = c("train", "predict"),
      custom_check = crate(function(x) {
        first_check = check_list(x, types = "character", any.missing = FALSE, min.len = 1L, names = "named")
        # return the error directly if this failed
        # (check_list() yields a character message on failure)
        if (is.character(first_check)) {
          return(first_check)
        }
        # changing anything target related is not supported
        # a value of "character()" will lead to the column being dropped
        all_col_roles = unique(unlist(mlr3::mlr_reflections$task_col_roles))
        check_subset(unlist(x), all_col_roles[all_col_roles != "target"])
      }, .parent = topenv())
    )
  )
  super$initialize(id, param_set = ps, param_vals = param_vals, can_subset_cols = FALSE)
}
Expand Down
4 changes: 2 additions & 2 deletions R/PipeOpMutate.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ PipeOpMutate = R6Class("PipeOpMutate",
# checks that `mutation` is
# * a named list of `formula`
# * that each element has only a lhs
check_mutation_formulae = function(x) {
check_mutation_formulae = crate(function(x) {
check_list(x, types = "formula", names = "unique") %check&&%
Reduce(`%check&&%`, lapply(x, function(xel) {
if (length(xel) != 2) {
Expand All @@ -132,6 +132,6 @@ check_mutation_formulae = function(x) {
}
TRUE
}), TRUE)
}
}, .parent = topenv())

mlr_pipeops$add("mutate", PipeOpMutate)
37 changes: 20 additions & 17 deletions R/PipeOpProxy.R
Original file line number Diff line number Diff line change
Expand Up @@ -88,23 +88,26 @@ PipeOpProxy = R6Class("PipeOpProxy",
# input can be a vararg input channel
inname = if (innum) rep_suffix("input", innum) else "..."
# Parameter set with the single parameter `content`.
ps = ps(
  content = p_uty(
    # `innum` and `outnum` are handed to crate() explicitly because the check
    # body reads them; .parent = topenv() keeps the rest of the calling frame
    # out of the check's enclosure.
    custom_check = crate(function(x) {
      # content must be an object that can be coerced to a Graph and the output number must match
      tryCatch({
        graph = as_graph(x)
        # graph$output access may be slow, so we cache it here
        graph_outnum = nrow(graph$output)
        graph_input = nrow(graph$input)
        if (graph_outnum != 1 && graph_outnum != outnum) {
          "Graph's output number must either be 1 or match `outnum`"
        } else if (innum > 1 && graph_input != innum && (graph_input > innum || "..." %nin% graph$input$name)) {
          "Graph's input number when `innum` > 1 must either match `innum` or the Graph must contain a '...' (vararg) channel."
        } else {
          TRUE
        }
      },
      # any error raised while converting / inspecting `x` becomes a check message
      error = function(error_condition) "`content` must be an object that can be converted to a Graph")
    }, innum, outnum, .parent = topenv()),
    # tag spelled "predict" (was the typo "predidct"), matching the
    # "train"/"predict" tags used by the other parameter sets
    tags = c("train", "predict", "required")
  )
)
ps$values = list(content = PipeOpFeatureUnion$new(innum = innum))
super$initialize(id, param_set = ps, param_vals = param_vals,
Expand Down
9 changes: 5 additions & 4 deletions R/PipeOpRenameColumns.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,11 @@ PipeOpRenameColumns = R6Class("PipeOpRenameColumns",
public = list(
initialize = function(id = "renamecolumns", param_vals = list()) {
# Parameter set: `renaming` (checked character vector) and `ignore_missing`
# (logical); both are tagged "train", "predict" and "required".
ps = ps(
  # `renaming` is declared exactly once. Its check combines check_character()
  # and check_names() via %check&&% and is wrapped in crate() with
  # .parent = topenv() so it does not enclose the calling frame.
  renaming = p_uty(
    custom_check = crate(
      function(x) check_character(x, any.missing = FALSE, names = "strict") %check&&% check_names(x, type = "strict"),
      .parent = topenv()
    ),
    tags = c("train", "predict", "required")
  ),
  ignore_missing = p_lgl(tags = c("train", "predict", "required"))
)
ps$values = list(renaming = character(0), ignore_missing = FALSE)
Expand Down
8 changes: 4 additions & 4 deletions R/PipeOpTrafo.R
Original file line number Diff line number Diff line change
Expand Up @@ -351,8 +351,8 @@ PipeOpTargetMutate = R6Class("PipeOpTargetMutate",
initialize = function(id = "targetmutate", param_vals = list(), new_task_type = NULL) {
private$.new_task_type = assert_choice(new_task_type, mlr_reflections$task_types$type, null.ok = TRUE)
# Parameter set: `trafo` (train + predict) and `inverter` (predict only).
# Each is declared exactly once and must be a function with one argument
# (check_function(x, nargs = 1L)); the checks are wrapped in crate() with
# .parent = topenv() so they do not enclose the constructor's frame.
ps = ps(
  trafo = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) check_function(x, nargs = 1L), .parent = topenv())),
  inverter = p_uty(tags = "predict", custom_check = crate(function(x) check_function(x, nargs = 1L), .parent = topenv()))
)
# We could add a condition here for new_task_type on trafo and inverter when mlr-org/paradox#278 has an answer.
# HOWEVER conditions are broken in paradox, it is a terrible idea to use them in PipeOps,
Expand Down Expand Up @@ -573,8 +573,8 @@ PipeOpUpdateTarget = R6Class("PipeOpUpdateTarget",
initialize = function(id = "update_target", param_vals = list()) {
# Parameter set of PipeOpUpdateTarget; all parameters are tagged "train" and
# "predict". Every custom check is wrapped in crate() with .parent = topenv()
# so that none of them encloses the constructor's frame.
ps = ps(
  # one-argument function; crated like the identical `trafo` check of
  # PipeOpTargetMutate for consistency
  trafo = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) check_function(x, nargs = 1L), .parent = topenv())),
  # a single non-missing string
  new_target_name = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) check_character(x, any.missing = FALSE, len = 1L), .parent = topenv())),
  # must be one of mlr_reflections$task_types$type
  new_task_type = p_uty(tags = c("train", "predict"), custom_check = crate(function(x) check_choice(x, choices = mlr_reflections$task_types$type), .parent = topenv())),
  drop_original_target = p_lgl(tags = c("train", "predict"))
)
ps$values = list(trafo = identity, drop_original_target = TRUE)
Expand Down
6 changes: 4 additions & 2 deletions R/PipeOpTuneThreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,10 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
# Parameter set; every parameter here carries only the "train" tag.
ps = ps(
  measure = p_uty(custom_check = check_class_or_character("Measure", mlr_measures), tags = "train"),
  optimizer = p_uty(custom_check = check_optimizer, tags = "train"),
  # log_level is declared exactly once, with the check passed explicitly as
  # `custom_check` (not positionally). The check combines check_string() and
  # check_integerish() via %check||% and is wrapped in crate() with
  # .parent = topenv() so that it does not enclose the calling frame.
  log_level = p_uty(
    custom_check = crate(function(x) check_string(x) %check||% check_integerish(x), .parent = topenv()),
    tags = "train"
  )
)
ps$values = list(measure = "classif.ce", optimizer = "gensa", log_level = "warn")
super$initialize(id, param_set = ps, param_vals = param_vals, packages = "bbotk",
Expand Down
51 changes: 40 additions & 11 deletions R/PipeOpVtreat.R
Original file line number Diff line number Diff line change
Expand Up @@ -140,26 +140,55 @@ PipeOpVtreat = R6Class("PipeOpVtreat",
rareSig = p_dbl(lower = 0, upper = 1, special_vals = list(NULL), tags = c("train", "regression", "classification", "multinomial")), # default NULL for regression, classification, 1 for multinomial
collarProb = p_dbl(lower = 0, upper = 1, default = 0, tags = c("train", "regression", "classification", "multinomial"), depends = quote(doCollar == TRUE)),
doCollar = p_lgl(default = FALSE, tags = c("train", "regression", "classification", "multinomial")),
# codeRestriction / customCoders / splitFunction: each declared exactly once,
# default NULL, with the checkmate check wrapped in crate(.parent = topenv()).
# codeRestriction: character vector without missing values, or NULL
codeRestriction = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_character(x, any.missing = FALSE, null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification", "multinomial")
),
# customCoders: list or NULL
customCoders = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_list(x, null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification", "multinomial")
),
# splitFunction: function with arguments nSplits, nRows, dframe, y, or NULL
splitFunction = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_function(x, args = c("nSplits", "nRows", "dframe", "y"), null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification", "multinomial")
),
ncross = p_int(lower = 2L, upper = Inf, default = 3L, tags = c("train", "regression", "classification", "multinomial")),
forceSplit = p_lgl(default = FALSE, tags = c("train", "regression", "classification", "multinomial")),
catScaling = p_lgl(tags = c("train", "regression", "classification", "multinomial")), # default TRUE for regression, classification, FALSE for multinomial
verbose = p_lgl(default = FALSE, tags = c("train", "regression", "classification", "multinomial")),
use_paralell = p_lgl(default = TRUE, tags = c("train", "regression", "classification", "multinomial")),
# missingness_imputation: declared exactly once; must be a function with
# arguments `values` and `weights`, or NULL (the default). The check is
# wrapped in crate(.parent = topenv()).
missingness_imputation = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_function(x, args = c("values", "weights"), null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification", "multinomial")
),
pruneSig = p_dbl(lower = 0, upper = 1, special_vals = list(NULL), default = NULL, tags = c("train", "regression", "classification")),
scale = p_lgl(default = FALSE, tags = c("train", "regression", "classification", "multinomial")),
# varRestriction / trackedValues: each declared exactly once; both must be a
# list or NULL (the default). Checks are wrapped in crate(.parent = topenv()).
varRestriction = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_list(x, null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification")
),
trackedValues = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_list(x, null.ok = TRUE), .parent = topenv()),
  tags = c("train", "regression", "classification")
),
# NOTE: check_for_duplicate_frames not needed
# y_dependent_treatments: declared exactly once; character vector without
# missing values, default "catB". The check is wrapped in
# crate(.parent = topenv()).
y_dependent_treatments = p_uty(
  default = "catB",
  custom_check = crate(function(x) checkmate::check_character(x, any.missing = FALSE), .parent = topenv()),
  tags = c("train", "multinomial")
),
# NOTE: imputation_map is also in multinomial_parameters(); this is redundant so only include it here
# imputation_map: declared exactly once; list or NULL (the default), tagged
# for both "train" and "predict". The check is wrapped in
# crate(.parent = topenv()).
imputation_map = p_uty(
  default = NULL,
  custom_check = crate(function(x) checkmate::check_list(x, null.ok = TRUE), .parent = topenv()),
  tags = c("train", "predict")
)
# NOTE: parallelCluster missing intentionally and will be set to NULL
)
ps$values = list(recommended = TRUE, cols_to_copy = selector_none())
Expand Down

0 comments on commit 22a408c

Please sign in to comment.