Skip to content

Commit

Permalink
Merge pull request #805 from mlr-org/handle_nonvarname_classnames
Browse files Browse the repository at this point in the history
PipeOpTuneThreshold handles classes with non-varname levels.
  • Loading branch information
mb706 authored Aug 17, 2024
2 parents 44ec01a + a7a5e23 commit 14589d6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 8 deletions.
21 changes: 13 additions & 8 deletions R/PipeOpTuneThreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -115,23 +115,26 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
pred$set_threshold(self$state$threshold)
return(list(pred))
},
.objfun = function(xs, pred, measure) {
lvls = colnames(pred$prob)
res = pred$set_threshold(unlist(xs))$score(measure)
.objfun = function(xs, pred, measure, paramname_to_column_map) {
thresholds = unlist(xs)
names(thresholds) = paramname_to_column_map[names(thresholds)]
res = pred$set_threshold(thresholds)$score(measure)
if (!measure$minimize) res = -res
return(setNames(list(res), measure$id))
},
.optimize_objfun = function(pred) {
optimizer = self$param_set$values$optimizer
if (inherits(optimizer, "character")) optimizer = bbotk::opt(optimizer)
if (inherits(optimizer, "OptimizerGenSA")) optimizer$param_set$values$trace.mat = TRUE # https://github.com/mlr-org/bbotk/issues/214
ps = private$.make_param_set(pred)
pnames = make.names(colnames(pred$prob), unique = TRUE)
paramname_to_column_map = setNames(colnames(pred$prob), pnames)
ps = private$.make_param_set(pred, pnames)
measure = self$param_set$values$measure
if (is.character(measure)) measure = msr(measure) else measure
codomain = do.call(paradox::ps, structure(list(p_dbl(tags = ifelse(measure$minimize, "minimize", "maximize"))), names = measure$id))

objfun = bbotk::ObjectiveRFun$new(
fun = function(xs) private$.objfun(xs, pred = pred, measure = measure),
fun = function(xs) private$.objfun(xs, pred = pred, measure = measure, paramname_to_column_map = paramname_to_column_map),
domain = ps, codomain = codomain
)
inst = bbotk::OptimInstanceSingleCrit$new(
Expand All @@ -146,10 +149,12 @@ PipeOpTuneThreshold = R6Class("PipeOpTuneThreshold",
on.exit(lgr$set_threshold(old_threshold))
lgr$set_threshold(self$param_set$values$log_level)
optimizer$optimize(inst)
unlist(inst$result_x_domain)
result = unlist(inst$result_x_domain)
names(result) = paramname_to_column_map[names(result)]
result
},
.make_param_set = function(pred) {
pset = setNames(map(colnames(pred$prob), function(x) p_dbl(0,1)), colnames(pred$prob))
.make_param_set = function(pred, pnames) {
pset = setNames(map(pnames, function(x) p_dbl(0,1)), pnames)
mlr3misc::invoke(paradox::ps, .args = pset)
},
.task_to_prediction = function(input) {
Expand Down
17 changes: 17 additions & 0 deletions tests/testthat/test_pipeop_tunethreshold.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,20 @@ test_that("tunethreshold graph works", {


})

test_that("threshold works for classes that are not valid R names", {
skip_if_not_installed("rpart")
ppl = po("learner_cv", lrn("classif.rpart", predict_type = "prob")) %>>% po("tunethreshold")

cols = c("0", "1", "-", "_")
testtask = as_task_classif(
data.frame(x = rep(1:3, each = 24), y = factor(rep(letters[1:3], each = 24)),
target = factor(rep(c(cols, make.names(cols)), each = 9))),
target = "target", id = "testtask"
)

ppl$train(testtask)

ppl$predict(testtask)

})

0 comments on commit 14589d6

Please sign in to comment.