Skip to content

Commit

Permalink
Merge pull request #551 from tidymodels/check-prop
Browse files: browse the repository at this point in the history
Check `prop` argument
  • Loading branch information
hfrick authored Sep 26, 2024
2 parents fe24aaa + ca5c220 commit 580a12f
Show file tree
Hide file tree
Showing 11 changed files with 118 additions and 23 deletions.
7 changes: 4 additions & 3 deletions R/initial_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
initial_split <- function(data, prop = 3 / 4,
strata = NULL, breaks = 4, pool = 0.1, ...) {
check_dots_empty()
check_prop(prop)

res <-
mc_cv(
data = data,
Expand Down Expand Up @@ -74,9 +76,7 @@ initial_split <- function(data, prop = 3 / 4,
#' @export
initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {
check_dots_empty()
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
cli_abort("{.arg prop} must be a number on (0, 1).")
}
check_prop(prop)

if (!is.numeric(lag) | !(lag %% 1 == 0)) {
cli_abort("{.arg lag} must be a whole number.")
Expand Down Expand Up @@ -156,6 +156,7 @@ testing.rsplit <- function(x, ...) {
#' @export
group_initial_split <- function(data, group, prop = 3 / 4, ..., strata = NULL, pool = 0.1) {
check_dots_empty()
check_prop(prop)

if (missing(strata)) {
res <- group_mc_cv(
Expand Down
15 changes: 0 additions & 15 deletions R/make_groups.R
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,6 @@ balance_observations_helper <- function(data_split, v, target_per_fold) {

balance_prop <- function(prop, data_ind, v, replace = FALSE, strata = NULL, ...) {
rlang::check_dots_empty()
check_prop(prop, replace)

# This is the core difference between stratification and not:
#
Expand Down Expand Up @@ -290,20 +289,6 @@ balance_prop_helper <- function(prop, data_ind, v, replace) {
list_rbind()
}

# Abort unless `prop` is a usable proportion.
#
# `prop` must be numeric and greater than 0. The upper bound depends on
# `replace`: a value of exactly 1 is accepted only when `replace` is TRUE,
# otherwise `prop` must be strictly below 1.
#
# The error is reported against the caller of this function.
check_prop <- function(prop, replace) {
  prop_is_valid <- FALSE

  if (is.numeric(prop)) {
    within_upper_bound <- (prop <= 1 && replace) || (prop < 1 && !replace)
    prop_is_valid <- within_upper_bound && prop > 0
  }

  if (!prop_is_valid) {
    cli_abort(
      "{.arg prop} must be a number between 0 and 1.",
      call = rlang::caller_env()
    )
  }
}


collapse_groups <- function(freq_table, data_ind, v) {
data_ind <- dplyr::left_join(
data_ind,
Expand Down
6 changes: 2 additions & 4 deletions R/mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
mc_cv <- function(data, prop = 3 / 4, times = 25,
strata = NULL, breaks = 4, pool = 0.1, ...) {
check_dots_empty()
check_prop(prop)

if (!missing(strata)) {
strata <- tidyselect::vars_select(names(data), !!enquo(strata))
Expand Down Expand Up @@ -103,10 +104,6 @@ mc_complement <- function(ind, n) {

mc_splits <- function(data, prop = 3 / 4, times = 25,
strata = NULL, breaks = 4, pool = 0.1) {
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
cli_abort("{.arg prop} must be a number on (0, 1).")
}

n <- nrow(data)
if (is.null(strata)) {
indices <- purrr::map(rep(n, times), sample, size = floor(n * prop))
Expand Down Expand Up @@ -170,6 +167,7 @@ group_mc_cv <- function(data, group, prop = 3 / 4, times = 25, ...,
strata = NULL, pool = 0.1) {

check_dots_empty()
check_prop(prop)

group <- validate_group({{ group }}, data)

Expand Down
11 changes: 11 additions & 0 deletions R/misc.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ add_class <- function(x, cls) {
x
}

# Validate the `prop` argument of the resampling functions.
#
# `prop` is first passed to `check_number_decimal()`, then required to lie
# strictly between 0 and 1; each bound has its own error message.
#
# `call` is the environment the error is reported against; it defaults to
# the caller of this function.
#
# Returns `NULL` invisibly when none of the checks abort.
check_prop <- function(prop, call = caller_env()) {
  check_number_decimal(prop, call = call)

  not_above_zero <- !(prop > 0)
  if (not_above_zero) {
    cli_abort("{.arg prop} must be greater than 0.", call = call)
  }

  not_below_one <- !(prop < 1)
  if (not_below_one) {
    cli_abort("{.arg prop} must be less than 1.", call = call)
  }

  invisible(NULL)
}

check_strata <- function(strata, data, call = caller_env()) {
check_string(strata, allow_null = TRUE, call = call)

Expand Down
3 changes: 3 additions & 0 deletions R/validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ validation_split <- function(data, prop = 3 / 4,
)

check_dots_empty()
check_prop(prop)

if (!missing(strata)) {
strata <- tidyselect::vars_select(names(data), !!enquo(strata))
Expand Down Expand Up @@ -114,6 +115,7 @@ validation_time_split <- function(data, prop = 3 / 4, lag = 0, ...) {

check_dots_empty()

check_prop(prop)
if (!is.numeric(prop) | prop >= 1 | prop <= 0) {
rlang::abort("`prop` must be a number on (0, 1).")
}
Expand Down Expand Up @@ -155,6 +157,7 @@ group_validation_split <- function(data, group, prop = 3 / 4, ..., strata = NULL
check_dots_empty()

group <- validate_group({{ group }}, data)
check_prop(prop)

if (!missing(strata)) {
strata <- check_grouped_strata({{ group }}, {{ strata }}, pool, data)
Expand Down
26 changes: 25 additions & 1 deletion tests/testthat/_snaps/initial_split.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
initial_time_split(drinks, prop = 2)
Condition
Error in `initial_time_split()`:
! `prop` must be a number on (0, 1).
! `prop` must be less than 1.

---

Expand Down Expand Up @@ -38,3 +38,27 @@
<Training/Testing/Total>
<24/8/32>

# prop is checked

Code
initial_split(mtcars, prop = 1)
Condition
Error in `initial_split()`:
! `prop` must be less than 1.

---

Code
initial_time_split(mtcars, prop = 1)
Condition
Error in `initial_time_split()`:
! `prop` must be less than 1.

---

Code
group_initial_split(mtcars, group = "cyl", prop = 1)
Condition
Error in `group_initial_split()`:
! `prop` must be less than 1.

24 changes: 24 additions & 0 deletions tests/testthat/_snaps/mc.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,21 @@
# bad args

Code
mc_cv(mtcars, prop = -1)
Condition
Error in `mc_cv()`:
! `prop` must be greater than 0.

---

Code
mc_cv(mtcars, prop = 1)
Condition
Error in `mc_cv()`:
! `prop` must be less than 1.

---

Code
mc_cv(warpbreaks, strata = warpbreaks$tension)
Condition
Expand Down Expand Up @@ -70,6 +86,14 @@
Error in `group_mc_cv()`:
! `group` must be a single string, not `NULL`.

---

Code
group_mc_cv(mtcars, group = "cyl", prop = 1)
Condition
Error in `group_mc_cv()`:
! `prop` must be less than 1.

---

Code
Expand Down
24 changes: 24 additions & 0 deletions tests/testthat/_snaps/validation_split.md
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,30 @@

# bad args

Code
validation_split(mtcars, prop = 1)
Condition
Error in `validation_split()`:
! `prop` must be less than 1.

---

Code
validation_time_split(mtcars, prop = 1)
Condition
Error in `validation_time_split()`:
! `prop` must be less than 1.

---

Code
group_validation_split(mtcars, group = "cyl", prop = 1)
Condition
Error in `group_validation_split()`:
! `prop` must be less than 1.

---

Code
validation_split(warpbreaks, strata = warpbreaks$tension)
Condition
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/test-initial_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -118,3 +118,9 @@ test_that("printing initial split objects", {
expect_snapshot(initial_split(mtcars))
expect_snapshot(initial_time_split(mtcars))
})

# `prop = 1` sits on the excluded upper bound, so each initial-split
# function is expected to error; the messages are pinned by snapshot.
test_that("prop is checked", {
  expect_snapshot(error = TRUE, {
    initial_split(mtcars, prop = 1)
  })
  expect_snapshot(error = TRUE, {
    initial_time_split(mtcars, prop = 1)
  })
  expect_snapshot(error = TRUE, {
    group_initial_split(mtcars, group = "cyl", prop = 1)
  })
})
9 changes: 9 additions & 0 deletions tests/testthat/test-mc.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,12 @@ test_that("strata", {


test_that("bad args", {
expect_snapshot(error = TRUE, {
mc_cv(mtcars, prop = -1)
})
expect_snapshot(error = TRUE, {
mc_cv(mtcars, prop = 1)
})
expect_snapshot(error = TRUE, {
mc_cv(warpbreaks, strata = warpbreaks$tension)
})
Expand Down Expand Up @@ -107,6 +113,9 @@ test_that("grouping - bad args", {
expect_snapshot(error = TRUE, {
group_mc_cv(warpbreaks)
})
expect_snapshot(error = TRUE, {
group_mc_cv(mtcars, group = "cyl", prop = 1)
})
expect_snapshot(error = TRUE, {
group_mc_cv(warpbreaks, group = "tension", balance = "groups")
})
Expand Down
10 changes: 10 additions & 0 deletions tests/testthat/test-validation_split.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,6 +259,16 @@ test_that("strata", {
test_that("bad args", {
withr::local_options(lifecycle_verbosity = "quiet")

expect_snapshot(error = TRUE, {
validation_split(mtcars, prop = 1)
})
expect_snapshot(error = TRUE, {
validation_time_split(mtcars, prop = 1)
})
expect_snapshot(error = TRUE, {
group_validation_split(mtcars, group = "cyl", prop = 1)
})

expect_snapshot(error = TRUE, {
validation_split(warpbreaks, strata = warpbreaks$tension)
})
Expand Down

0 comments on commit 580a12f

Please sign in to comment.