From d9fcbfe81bf3c35c100f6c9fb5f57cd8868338fc Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 25 Sep 2024 16:55:43 +0100 Subject: [PATCH 1/4] add tests on `prop` argument --- tests/testthat/_snaps/initial_split.md | 24 +++++++++++++++++++++++ tests/testthat/_snaps/mc.md | 16 +++++++++++++++ tests/testthat/_snaps/validation_split.md | 24 +++++++++++++++++++++++ tests/testthat/test-initial_split.R | 6 ++++++ tests/testthat/test-mc.R | 6 ++++++ tests/testthat/test-validation_split.R | 10 ++++++++++ 6 files changed, 86 insertions(+) diff --git a/tests/testthat/_snaps/initial_split.md b/tests/testthat/_snaps/initial_split.md index 69344098..633d69bd 100644 --- a/tests/testthat/_snaps/initial_split.md +++ b/tests/testthat/_snaps/initial_split.md @@ -38,3 +38,27 @@ <24/8/32> +# prop is checked + + Code + initial_split(mtcars, prop = 1) + Condition + Error in `mc_splits()`: + ! `prop` must be a number on (0, 1). + +--- + + Code + initial_time_split(mtcars, prop = 1) + Condition + Error in `initial_time_split()`: + ! `prop` must be a number on (0, 1). + +--- + + Code + group_initial_split(mtcars, group = "cyl", prop = 1) + Condition + Error in `balance_prop()`: + ! `prop` must be a number between 0 and 1. + diff --git a/tests/testthat/_snaps/mc.md b/tests/testthat/_snaps/mc.md index 1fbc3899..a4e0a077 100644 --- a/tests/testthat/_snaps/mc.md +++ b/tests/testthat/_snaps/mc.md @@ -1,5 +1,13 @@ # bad args + Code + mc_cv(mtcars, prop = 1) + Condition + Error in `mc_splits()`: + ! `prop` must be a number on (0, 1). + +--- + Code mc_cv(warpbreaks, strata = warpbreaks$tension) Condition @@ -70,6 +78,14 @@ Error in `group_mc_cv()`: ! `group` must be a single string, not `NULL`. +--- + + Code + group_mc_cv(mtcars, group = "cyl", prop = 1) + Condition + Error in `balance_prop()`: + ! `prop` must be a number between 0 and 1. + --- Code diff --git a/tests/testthat/_snaps/validation_split.md b/tests/testthat/_snaps/validation_split.md index 692f1440..e925925b 100644 --- a/tests/testthat/_snaps/validation_split.md +++ b/tests/testthat/_snaps/validation_split.md @@ -79,6 +79,30 @@ # bad args + Code + validation_split(mtcars, prop = 1) + Condition + Error in `mc_splits()`: + ! `prop` must be a number on (0, 1). + +--- + + Code + validation_time_split(mtcars, prop = 1) + Condition + Error in `validation_time_split()`: + ! `prop` must be a number on (0, 1). + +--- + + Code + group_validation_split(mtcars, group = "cyl", prop = 1) + Condition + Error in `balance_prop()`: + ! `prop` must be a number between 0 and 1. + +--- + Code validation_split(warpbreaks, strata = warpbreaks$tension) Condition diff --git a/tests/testthat/test-initial_split.R b/tests/testthat/test-initial_split.R index 775b8a47..3ca0bd80 100644 --- a/tests/testthat/test-initial_split.R +++ b/tests/testthat/test-initial_split.R @@ -118,3 +118,9 @@ test_that("printing initial split objects", { expect_snapshot(initial_split(mtcars)) expect_snapshot(initial_time_split(mtcars)) }) + +test_that("prop is checked", { + expect_snapshot(error = TRUE, {initial_split(mtcars, prop = 1)}) + expect_snapshot(error = TRUE, {initial_time_split(mtcars, prop = 1)}) + expect_snapshot(error = TRUE, {group_initial_split(mtcars, group = "cyl", prop = 1)}) +}) diff --git a/tests/testthat/test-mc.R b/tests/testthat/test-mc.R index 742fd13f..a9b6a691 100644 --- a/tests/testthat/test-mc.R +++ b/tests/testthat/test-mc.R @@ -72,6 +72,9 @@ test_that("strata", { test_that("bad args", { + expect_snapshot(error = TRUE, { + mc_cv(mtcars, prop = 1) + }) expect_snapshot(error = TRUE, { mc_cv(warpbreaks, strata = warpbreaks$tension) }) @@ -107,6 +110,9 @@ test_that("grouping - bad args", { expect_snapshot(error = TRUE, { group_mc_cv(warpbreaks) }) + expect_snapshot(error = TRUE, { + group_mc_cv(mtcars, group = "cyl", prop = 1) + }) expect_snapshot(error = TRUE, { group_mc_cv(warpbreaks, group = "tension", balance = "groups") }) diff --git a/tests/testthat/test-validation_split.R b/tests/testthat/test-validation_split.R index 10ddeebb..4c0c7b14 100644 --- a/tests/testthat/test-validation_split.R +++ b/tests/testthat/test-validation_split.R @@ -259,6 +259,16 @@ test_that("strata", { test_that("bad args", { withr::local_options(lifecycle_verbosity = "quiet") + expect_snapshot(error = TRUE, { + validation_split(mtcars, prop = 1) + }) + expect_snapshot(error = TRUE, { + validation_time_split(mtcars, prop = 1) + }) + expect_snapshot(error = TRUE, { + group_validation_split(mtcars, group = "cyl", prop = 1) + }) + expect_snapshot(error = TRUE, { validation_split(warpbreaks, strata = warpbreaks$tension) }) From 92e653a5b444812ce77c917649acbc78b5809047 Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Wed, 25 Sep 2024 17:14:50 +0100 Subject: [PATCH 2/4] rework `check_prop()` and use it --- R/initial_split.R | 7 ++++--- R/make_groups.R | 15 --------------- R/mc.R | 6 ++---- R/misc.R | 11 +++++++++++ R/validation_split.R | 3 +++ tests/testthat/_snaps/initial_split.md | 12 ++++++------ tests/testthat/_snaps/mc.md | 8 ++++---- tests/testthat/_snaps/validation_split.md | 10 +++++----- 8 files changed, 35 insertions(+), 37 deletions(-) diff --git a/R/initial_split.R b/R/initial_split.R index c754e424..a11a9f8e 100644 --- a/R/initial_split.R +++ b/R/initial_split.R @@ -44,6 +44,8 @@ initial_split <- function(data, prop = 3 / 4, strata = NULL, breaks = 4, pool = 0.1, ...) { check_dots_empty() + check_prop(prop) + res <- mc_cv( data = data, @@ -74,9 +76,7 @@ initial_split <- function(data, prop = 3 / 4, #' @export initial_time_split <- function(data, prop = 3 / 4, lag = 0, ...) { check_dots_empty() - if (!is.numeric(prop) | prop >= 1 | prop <= 0) { - cli_abort("{.arg prop} must be a number on (0, 1).") - } + check_prop(prop) if (!is.numeric(lag) | !(lag %% 1 == 0)) { cli_abort("{.arg lag} must be a whole number.") @@ -156,6 +156,7 @@ testing.rsplit <- function(x, ...) { #' @export group_initial_split <- function(data, group, prop = 3 / 4, ..., strata = NULL, pool = 0.1) { check_dots_empty() + check_prop(prop) if (missing(strata)) { res <- group_mc_cv( diff --git a/R/make_groups.R b/R/make_groups.R index 12635172..59fe1823 100644 --- a/R/make_groups.R +++ b/R/make_groups.R @@ -231,7 +231,6 @@ balance_observations_helper <- function(data_split, v, target_per_fold) { balance_prop <- function(prop, data_ind, v, replace = FALSE, strata = NULL, ...) { rlang::check_dots_empty() - check_prop(prop, replace) # This is the core difference between stratification and not: # @@ -290,20 +289,6 @@ balance_prop_helper <- function(prop, data_ind, v, replace) { list_rbind() } -check_prop <- function(prop, replace) { - acceptable_prop <- is.numeric(prop) - acceptable_prop <- acceptable_prop && - ((prop <= 1 && replace) || (prop < 1 && !replace)) - acceptable_prop <- acceptable_prop && prop > 0 - if (!acceptable_prop) { - cli_abort( - "{.arg prop} must be a number between 0 and 1.", - call = rlang::caller_env() - ) - } -} - - collapse_groups <- function(freq_table, data_ind, v) { data_ind <- dplyr::left_join( data_ind, diff --git a/R/mc.R b/R/mc.R index 35ebfd49..e97f030f 100644 --- a/R/mc.R +++ b/R/mc.R @@ -52,6 +52,7 @@ mc_cv <- function(data, prop = 3 / 4, times = 25, strata = NULL, breaks = 4, pool = 0.1, ...) { check_dots_empty() + check_prop(prop) if (!missing(strata)) { strata <- tidyselect::vars_select(names(data), !!enquo(strata)) @@ -103,10 +104,6 @@ mc_complement <- function(ind, n) { mc_splits <- function(data, prop = 3 / 4, times = 25, strata = NULL, breaks = 4, pool = 0.1) { - if (!is.numeric(prop) | prop >= 1 | prop <= 0) { - cli_abort("{.arg prop} must be a number on (0, 1).") - } - n <- nrow(data) if (is.null(strata)) { indices <- purrr::map(rep(n, times), sample, size = floor(n * prop)) @@ -170,6 +167,7 @@ group_mc_cv <- function(data, group, prop = 3 / 4, times = 25, ..., strata = NULL, pool = 0.1) { check_dots_empty() + check_prop(prop) group <- validate_group({{ group }}, data) diff --git a/R/misc.R b/R/misc.R index ab174bae..6e0f1969 100644 --- a/R/misc.R +++ b/R/misc.R @@ -98,6 +98,17 @@ add_class <- function(x, cls) { x } +check_prop <- function(prop, call = caller_env()) { + check_number_decimal(prop, call = call) + if (!(prop > 0)) { + cli_abort("{.arg prop} must be greater than 0.", call = call) + } + if (!(prop < 1)) { + cli_abort("{.arg prop} must be less than 1.", call = call) + } + invisible(NULL) +} + check_strata <- function(strata, data, call = caller_env()) { check_string(strata, allow_null = TRUE, call = call) diff --git a/R/validation_split.R b/R/validation_split.R index bbe509b1..da97c2b1 100644 --- a/R/validation_split.R +++ b/R/validation_split.R @@ -59,6 +59,7 @@ validation_split <- function(data, prop = 3 / 4, ) check_dots_empty() + check_prop(prop) if (!missing(strata)) { strata <- tidyselect::vars_select(names(data), !!enquo(strata)) @@ -114,6 +115,7 @@ validation_time_split <- function(data, prop = 3 / 4, lag = 0, ...) { check_dots_empty() + check_prop(prop) if (!is.numeric(prop) | prop >= 1 | prop <= 0) { rlang::abort("`prop` must be a number on (0, 1).") } @@ -155,6 +157,7 @@ group_validation_split <- function(data, group, prop = 3 / 4, ..., strata = NULL check_dots_empty() group <- validate_group({{ group }}, data) + check_prop(prop) if (!missing(strata)) { strata <- check_grouped_strata({{ group }}, {{ strata }}, pool, data) diff --git a/tests/testthat/_snaps/initial_split.md b/tests/testthat/_snaps/initial_split.md index 633d69bd..37bb892a 100644 --- a/tests/testthat/_snaps/initial_split.md +++ b/tests/testthat/_snaps/initial_split.md @@ -4,7 +4,7 @@ initial_time_split(drinks, prop = 2) Condition Error in `initial_time_split()`: - ! `prop` must be a number on (0, 1). + ! `prop` must be less than 1. --- @@ -43,8 +43,8 @@ Code initial_split(mtcars, prop = 1) Condition - Error in `mc_splits()`: - ! `prop` must be a number on (0, 1). + Error in `initial_split()`: + ! `prop` must be less than 1. --- @@ -52,13 +52,13 @@ initial_time_split(mtcars, prop = 1) Condition Error in `initial_time_split()`: - ! `prop` must be a number on (0, 1). + ! `prop` must be less than 1. --- Code group_initial_split(mtcars, group = "cyl", prop = 1) Condition - Error in `balance_prop()`: - ! `prop` must be a number between 0 and 1. + Error in `group_initial_split()`: + ! `prop` must be less than 1. diff --git a/tests/testthat/_snaps/mc.md b/tests/testthat/_snaps/mc.md index a4e0a077..4de31cac 100644 --- a/tests/testthat/_snaps/mc.md +++ b/tests/testthat/_snaps/mc.md @@ -3,8 +3,8 @@ Code mc_cv(mtcars, prop = 1) Condition - Error in `mc_splits()`: - ! `prop` must be a number on (0, 1). + Error in `mc_cv()`: + ! `prop` must be less than 1. --- @@ -83,8 +83,8 @@ Code group_mc_cv(mtcars, group = "cyl", prop = 1) Condition - Error in `balance_prop()`: - ! `prop` must be a number between 0 and 1. + Error in `group_mc_cv()`: + ! `prop` must be less than 1. --- diff --git a/tests/testthat/_snaps/validation_split.md b/tests/testthat/_snaps/validation_split.md index e925925b..fbce3bc2 100644 --- a/tests/testthat/_snaps/validation_split.md +++ b/tests/testthat/_snaps/validation_split.md @@ -82,8 +82,8 @@ Code validation_split(mtcars, prop = 1) Condition - Error in `mc_splits()`: - ! `prop` must be a number on (0, 1). + Error in `validation_split()`: + ! `prop` must be less than 1. --- @@ -91,15 +91,15 @@ validation_time_split(mtcars, prop = 1) Condition Error in `validation_time_split()`: - ! `prop` must be a number on (0, 1). + ! `prop` must be less than 1. --- Code group_validation_split(mtcars, group = "cyl", prop = 1) Condition - Error in `balance_prop()`: - ! `prop` must be a number between 0 and 1. + Error in `group_validation_split()`: + ! `prop` must be less than 1. --- From bc6d8dea3cef286b803aed3966d4c4779d9a418e Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 26 Sep 2024 10:46:11 +0100 Subject: [PATCH 3/4] Update tests/testthat/test-mc.R Co-authored-by: Simon P. Couch --- tests/testthat/test-mc.R | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/testthat/test-mc.R b/tests/testthat/test-mc.R index a9b6a691..63744aef 100644 --- a/tests/testthat/test-mc.R +++ b/tests/testthat/test-mc.R @@ -72,6 +72,9 @@ test_that("strata", { test_that("bad args", { + expect_snapshot(error = TRUE, { + mc_cv(mtcars, prop = -1) + }) expect_snapshot(error = TRUE, { mc_cv(mtcars, prop = 1) }) From ca5c220bbb09aa6c9bfe82ffe0014e36a51201af Mon Sep 17 00:00:00 2001 From: Hannah Frick Date: Thu, 26 Sep 2024 11:40:29 +0100 Subject: [PATCH 4/4] add snapshot --- tests/testthat/_snaps/mc.md | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/testthat/_snaps/mc.md b/tests/testthat/_snaps/mc.md index 4de31cac..55d1d0e1 100644 --- a/tests/testthat/_snaps/mc.md +++ b/tests/testthat/_snaps/mc.md @@ -1,5 +1,13 @@ # bad args + Code + mc_cv(mtcars, prop = -1) + Condition + Error in `mc_cv()`: + ! `prop` must be greater than 0. + +--- + Code mc_cv(mtcars, prop = 1) Condition