diff --git a/R/boot.R b/R/boot.R index edb40593..13cff66b 100644 --- a/R/boot.R +++ b/R/boot.R @@ -76,7 +76,7 @@ bootstraps <- if (length(strata) == 0) strata <- NULL } - strata_check(strata, data) + check_strata(strata, data) split_objs <- boot_splits( diff --git a/R/initial_validation_split.R b/R/initial_validation_split.R index 4d89ebc6..95f8f884 100644 --- a/R/initial_validation_split.R +++ b/R/initial_validation_split.R @@ -68,7 +68,7 @@ initial_validation_split <- function(data, strata <- NULL } } - strata_check(strata, data) + check_strata(strata, data) split_train <- mc_cv( data = data, @@ -209,7 +209,7 @@ group_initial_validation_split <- function(data, strata <- NULL } } - strata_check(strata, data) + check_strata(strata, data) if (missing(strata)) { split_train <- group_mc_cv( diff --git a/R/mc.R b/R/mc.R index 12d80677..35ebfd49 100644 --- a/R/mc.R +++ b/R/mc.R @@ -58,7 +58,7 @@ mc_cv <- function(data, prop = 3 / 4, times = 25, if (length(strata) == 0) strata <- NULL } - strata_check(strata, data) + check_strata(strata, data) split_objs <- mc_splits( diff --git a/R/misc.R b/R/misc.R index bb202a39..ab174bae 100644 --- a/R/misc.R +++ b/R/misc.R @@ -98,16 +98,17 @@ add_class <- function(x, cls) { x } -strata_check <- function(strata, data) { +check_strata <- function(strata, data, call = caller_env()) { + check_string(strata, allow_null = TRUE, call = call) + if (!is.null(strata)) { - if (!is.character(strata) | length(strata) != 1) { - cli_abort("{.arg strata} should be a single name or character value.") - } if (inherits(data[, strata], "Surv")) { - cli_abort("{.arg strata} cannot be a {.cls Surv} object. Use the time or event variable directly.") - } - if (!(strata %in% names(data))) { - cli_abort("{strata} is not in {.arg data}.") + cli_abort(c( + "{.field strata} cannot be a {.cls Surv} object.", + "i" = "Use the time or event variable directly." + ), + call = call + ) } } invisible(NULL) diff --git a/R/validation_split.R b/R/validation_split.R index 26e5e947..bbe509b1 100644 --- a/R/validation_split.R +++ b/R/validation_split.R @@ -67,7 +67,7 @@ validation_split <- function(data, prop = 3 / 4, } } - strata_check(strata, data) + check_strata(strata, data) split_objs <- mc_splits( diff --git a/R/vfold.R b/R/vfold.R index 2a2ac616..faded51e 100644 --- a/R/vfold.R +++ b/R/vfold.R @@ -71,7 +71,7 @@ vfold_cv <- function(data, v = 10, repeats = 1, if (length(strata) == 0) strata <- NULL } - strata_check(strata, data) + check_strata(strata, data) check_repeats(repeats) if (repeats == 1) { diff --git a/tests/testthat/_snaps/boot.md b/tests/testthat/_snaps/boot.md index 442c72ca..8119cead 100644 --- a/tests/testthat/_snaps/boot.md +++ b/tests/testthat/_snaps/boot.md @@ -26,8 +26,8 @@ Code bootstraps(warpbreaks, strata = c("tension", "wool")) Condition - Error in `strata_check()`: - ! `strata` should be a single name or character value. + Error in `bootstraps()`: + ! `strata` must be a single string or `NULL`, not a character vector. --- diff --git a/tests/testthat/_snaps/make_strata.md b/tests/testthat/_snaps/make_strata.md index 35a70cfb..2c5ff244 100644 --- a/tests/testthat/_snaps/make_strata.md +++ b/tests/testthat/_snaps/make_strata.md @@ -67,8 +67,9 @@ # don't stratify on Surv objects Code - strata_check("surv", df) + check_strata("surv", df) Condition - Error in `strata_check()`: - ! `strata` cannot be a object. Use the time or event variable directly. + Error: + ! strata cannot be a object. + i Use the time or event variable directly. diff --git a/tests/testthat/_snaps/mc.md b/tests/testthat/_snaps/mc.md index 22446560..f7ab64cf 100644 --- a/tests/testthat/_snaps/mc.md +++ b/tests/testthat/_snaps/mc.md @@ -12,8 +12,8 @@ Code mc_cv(warpbreaks, strata = c("tension", "wool")) Condition - Error in `strata_check()`: - ! `strata` should be a single name or character value. + Error in `mc_cv()`: + ! `strata` must be a single string or `NULL`, not a character vector. # printing diff --git a/tests/testthat/_snaps/validation_split.md b/tests/testthat/_snaps/validation_split.md index a9215489..692f1440 100644 --- a/tests/testthat/_snaps/validation_split.md +++ b/tests/testthat/_snaps/validation_split.md @@ -91,8 +91,8 @@ Code validation_split(warpbreaks, strata = c("tension", "wool")) Condition - Error in `strata_check()`: - ! `strata` should be a single name or character value. + Error in `validation_split()`: + ! `strata` must be a single string or `NULL`, not a character vector. # printing diff --git a/tests/testthat/_snaps/vfold.md b/tests/testthat/_snaps/vfold.md index d590cc6b..add53300 100644 --- a/tests/testthat/_snaps/vfold.md +++ b/tests/testthat/_snaps/vfold.md @@ -7,7 +7,7 @@ Stratifying groups that make up 1% of the data may be statistically risky. * Consider increasing `pool` to at least 0.1 -# bad args +# strata arg is checked Code vfold_cv(iris, strata = iris$Species) @@ -21,11 +21,28 @@ Code vfold_cv(iris, strata = c("Species", "Sepal.Width")) Condition - Error in `strata_check()`: - ! `strata` should be a single name or character value. + Error in `vfold_cv()`: + ! `strata` must be a single string or `NULL`, not a character vector. + +--- + + Code + vfold_cv(iris, strata = NA) + Condition + Error in `vfold_cv()`: + ! Selections can't have missing values. --- + Code + vfold_cv(dat, strata = b) + Condition + Error in `vfold_cv()`: + ! strata cannot be a object. + i Use the time or event variable directly. + +# bad args + Code vfold_cv(iris, v = -500) Condition diff --git a/tests/testthat/test-make_strata.R b/tests/testthat/test-make_strata.R index 3704783b..08ace53f 100644 --- a/tests/testthat/test-make_strata.R +++ b/tests/testthat/test-make_strata.R @@ -39,7 +39,7 @@ test_that("bad data", { -# strata_check() ---------------------------------------------------------- +# check_strata() ---------------------------------------------------------- test_that("don't stratify on Surv objects", { df <- data.frame( @@ -58,6 +58,6 @@ test_that("don't stratify on Surv objects", { ) expect_snapshot(error = TRUE, { - strata_check("surv", df) + check_strata("surv", df) }) }) diff --git a/tests/testthat/test-vfold.R b/tests/testthat/test-vfold.R index 9b81bae1..f6e7bd1f 100644 --- a/tests/testthat/test-vfold.R +++ b/tests/testthat/test-vfold.R @@ -74,14 +74,37 @@ test_that("strata", { ) }) - -test_that("bad args", { +test_that("strata arg is checked", { expect_snapshot(error = TRUE, { vfold_cv(iris, strata = iris$Species) }) + + # errors from `check_strata()` expect_snapshot(error = TRUE, { vfold_cv(iris, strata = c("Species", "Sepal.Width")) }) + + expect_snapshot(error = TRUE, { + vfold_cv(iris, strata = NA) + }) + + # make Surv object without a dependeny on the survival package + surv_obj <- structure( + c(306, 455, 1010, 210, 883, 1, 1, 0, 1, 1), + dim = c(5L, 2L), + dimnames = list(NULL, c("time", "status")), + type = "right", + class = "Surv" + ) + dat <- data.frame(a = 1:5) + # add Surv object like this for older R versions (<= 4.2.3) + dat$b <- surv_obj + expect_snapshot(error = TRUE, { + vfold_cv(dat, strata = b) + }) +}) + +test_that("bad args", { expect_snapshot(error = TRUE, { vfold_cv(iris, v = -500) })