Skip to content

Commit

Permalink
Merge pull request #283 from cmu-delphi/ds/check_enough_train_data
Browse files Browse the repository at this point in the history
feat: check_enough_train_data
  • Loading branch information
dshemetov authored Jan 22, 2024
2 parents ee11b1e + b869222 commit 6ddffb2
Show file tree
Hide file tree
Showing 14 changed files with 474 additions and 34 deletions.
2 changes: 2 additions & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
^renv$
^renv\.lock$
^epipredict\.Rproj$
^\.Rproj\.user$
^LICENSE\.md$
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ inst/doc
.DS_Store
/doc/
/Meta/
.Rprofile
renv.lock
renv/
10 changes: 5 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: epipredict
Title: Basic epidemiology forecasting methods
Version: 0.0.7
Version: 0.0.8
Authors@R: c(
person("Daniel", "McDonald", , "[email protected]", role = c("aut", "cre")),
person("Ryan", "Tibshirani", , "[email protected]", role = "aut"),
Expand All @@ -22,11 +22,11 @@ License: MIT + file LICENSE
URL: https://github.com/cmu-delphi/epipredict/,
https://cmu-delphi.github.io/epipredict
BugReports: https://github.com/cmu-delphi/epipredict/issues/
Depends:
Depends:
epiprocess (>= 0.6.0),
parsnip (>= 1.0.0),
R (>= 3.5.0)
Imports:
Imports:
cli,
distributional,
dplyr,
Expand All @@ -48,7 +48,7 @@ Imports:
usethis,
vctrs,
workflows (>= 1.0.0)
Suggests:
Suggests:
covidcast,
data.table,
epidatr (>= 1.0.0),
Expand All @@ -61,7 +61,7 @@ Suggests:
rmarkdown,
testthat (>= 3.0.0),
xgboost
VignetteBuilder:
VignetteBuilder:
knitr
Remotes:
cmu-delphi/epidatr,
Expand Down
12 changes: 12 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ S3method(adjust_frosting,frosting)
S3method(apply_frosting,default)
S3method(apply_frosting,epi_workflow)
S3method(augment,epi_workflow)
S3method(bake,check_enough_train_data)
S3method(bake,epi_recipe)
S3method(bake,step_epi_ahead)
S3method(bake,step_epi_lag)
Expand Down Expand Up @@ -48,6 +49,7 @@ S3method(mean,dist_quantiles)
S3method(median,dist_quantiles)
S3method(predict,epi_workflow)
S3method(predict,flatline)
S3method(prep,check_enough_train_data)
S3method(prep,epi_recipe)
S3method(prep,step_epi_ahead)
S3method(prep,step_epi_lag)
Expand All @@ -60,6 +62,7 @@ S3method(print,arx_class)
S3method(print,arx_fcast)
S3method(print,canned_epipred)
S3method(print,cdc_baseline_fcast)
S3method(print,check_enough_train_data)
S3method(print,epi_recipe)
S3method(print,epi_workflow)
S3method(print,flat_fcast)
Expand Down Expand Up @@ -104,6 +107,7 @@ S3method(snap,default)
S3method(snap,dist_default)
S3method(snap,dist_quantiles)
S3method(snap,distribution)
S3method(tidy,check_enough_train_data)
S3method(tidy,frosting)
S3method(tidy,layer)
S3method(update,layer)
Expand All @@ -127,6 +131,7 @@ export(arx_forecaster)
export(bake)
export(cdc_baseline_args_list)
export(cdc_baseline_forecaster)
export(check_enough_train_data)
export(create_layer)
export(default_epi_recipe_blueprint)
export(detect_layer)
Expand Down Expand Up @@ -191,6 +196,12 @@ import(epiprocess)
import(parsnip)
import(recipes)
importFrom(cli,cli_abort)
importFrom(dplyr,across)
importFrom(dplyr,all_of)
importFrom(dplyr,group_by)
importFrom(dplyr,n)
importFrom(dplyr,summarise)
importFrom(dplyr,ungroup)
importFrom(epiprocess,growth_rate)
importFrom(generics,augment)
importFrom(generics,fit)
Expand Down Expand Up @@ -225,6 +236,7 @@ importFrom(stats,residuals)
importFrom(tibble,as_tibble)
importFrom(tibble,is_tibble)
importFrom(tibble,tibble)
importFrom(tidyr,drop_na)
importFrom(vctrs,as_list_of)
importFrom(vctrs,field)
importFrom(vctrs,new_rcrd)
Expand Down
55 changes: 30 additions & 25 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,44 +1,49 @@
# epipredict (development)

# epipredict 0.0.8

- add `check_enough_train_data` that will error if training data is too small
- added `check_enough_train_data` to `arx_forecaster`

# epipredict 0.0.7

* simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`
- simplify `layer_residual_quantiles()` to avoid timesuck in `utils::methods()`

# epipredict 0.0.6

* rename the `dist_quantiles()` to be more descriptive, breaking change)
* removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
* add `pivot_quantiles_wider()` for easier plotting
* add complement `pivot_quantiles_longer()`
* add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`
- rename the `dist_quantiles()` to be more descriptive, breaking change)
- removes previous `pivot_quantiles()` (now `*_wider()`, breaking change)
- add `pivot_quantiles_wider()` for easier plotting
- add complement `pivot_quantiles_longer()`
- add `cdc_baseline_forecaster()` and `flusight_hub_formatter()`

# epipredict 0.0.5

* add `smooth_quantile_reg()`
* improved printing of various methods / internals
* canned forecasters get a class
* fixed quantile bug in `flatline_forecaster()`
* add functionality to output the unfit workflow from the canned forecasters
- add `smooth_quantile_reg()`
- improved printing of various methods / internals
- canned forecasters get a class
- fixed quantile bug in `flatline_forecaster()`
- add functionality to output the unfit workflow from the canned forecasters

# epipredict 0.0.4

* add quantile_reg()
* clean up documentation bugs
* add smooth_quantile_reg()
* add classifier
* training window step debugged
* `min_train_window` argument removed from canned forecasters
- add quantile_reg()
- clean up documentation bugs
- add smooth_quantile_reg()
- add classifier
- training window step debugged
- `min_train_window` argument removed from canned forecasters

# epipredict 0.0.3

* add forecasters
* implement postprocessing
* vignettes avaliable
* arx_forecaster
* pkgdown
- add forecasters
- implement postprocessing
- vignettes avaliable
- arx_forecaster
- pkgdown

# epipredict 0.0.0.9000

* Publish public for easy navigation
* Two simple forecasters as test beds
* Working vignette
- Publish public for easy navigation
- Two simple forecasters as test beds
- Working vignette
29 changes: 27 additions & 2 deletions R/arx_classifier.R
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,21 @@ arx_class_epi_workflow <- function(
role = "outcome"
) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
Expand Down Expand Up @@ -228,6 +242,11 @@ arx_class_epi_workflow <- function(
#' @param additional_gr_args List. Optional arguments controlling growth rate
#' calculation. See [epiprocess::growth_rate()] and the related Vignette for
#' more details.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
#' column names on which to group the data and check threshold within each
#' group. Useful if training per group (for example, per geo_value).
#'
#' @return A list containing updated parameter choices with class `arx_clist`.
#' @export
Expand All @@ -251,6 +270,8 @@ arx_class_args_list <- function(
log_scale = FALSE,
additional_gr_args = list(),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
rlang::check_dots_empty()
.lags <- lags
Expand All @@ -275,6 +296,8 @@ arx_class_args_list <- function(
)
)
}
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

breaks <- sort(breaks)
if (min(breaks) > -Inf) breaks <- c(-Inf, breaks)
Expand All @@ -296,7 +319,9 @@ arx_class_args_list <- function(
method,
log_scale,
additional_gr_args,
nafill_buffer
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
class = c("arx_class", "alist")
)
Expand Down
29 changes: 27 additions & 2 deletions R/arx_forecaster.R
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,21 @@ arx_fcast_epi_workflow <- function(
r <- r %>%
step_epi_ahead(!!outcome, ahead = args_list$ahead) %>%
step_epi_naomit() %>%
step_training_window(n_recent = args_list$n_training)
step_training_window(n_recent = args_list$n_training) %>%
{
if (!is.null(args_list$check_enough_data_n)) {
check_enough_train_data(
.,
all_predictors(),
!!outcome,
n = args_list$check_enough_data_n,
epi_keys = args_list$check_enough_data_epi_keys,
drop_na = FALSE
)
} else {
.
}
}

forecast_date <- args_list$forecast_date %||% max(epi_data$time_value)
target_date <- args_list$target_date %||% forecast_date + args_list$ahead
Expand Down Expand Up @@ -199,6 +213,11 @@ arx_fcast_epi_workflow <- function(
#' create a prediction. For this reason, setting `nafill_buffer < min(lags)`
#' will be treated as _additional_ allowed recent data rather than the
#' total amount of recent data to examine.
#' @param check_enough_data_n Integer. A lower limit for the number of rows per
#' epi_key that are required for training. If `NULL`, this check is ignored.
#' @param check_enough_data_epi_keys Character vector. A character vector of
#' column names on which to group the data and check threshold within each
#' group. Useful if training per group (for example, per geo_value).
#' @param ... Space to handle future expansions (unused).
#'
#'
Expand All @@ -220,6 +239,8 @@ arx_args_list <- function(
nonneg = TRUE,
quantile_by_key = character(0L),
nafill_buffer = Inf,
check_enough_data_n = NULL,
check_enough_data_epi_keys = NULL,
...) {
# error checking if lags is a list
rlang::check_dots_empty()
Expand All @@ -236,6 +257,8 @@ arx_args_list <- function(
arg_is_pos(n_training)
if (is.finite(n_training)) arg_is_pos_int(n_training)
if (is.finite(nafill_buffer)) arg_is_pos_int(nafill_buffer, allow_null = TRUE)
arg_is_pos(check_enough_data_n, allow_null = TRUE)
arg_is_chr(check_enough_data_epi_keys, allow_null = TRUE)

max_lags <- max(lags)
structure(
Expand All @@ -250,7 +273,9 @@ arx_args_list <- function(
nonneg,
max_lags,
quantile_by_key,
nafill_buffer
nafill_buffer,
check_enough_data_n,
check_enough_data_epi_keys
),
class = c("arx_fcast", "alist")
)
Expand Down
Loading

0 comments on commit 6ddffb2

Please sign in to comment.