diff --git a/R/check_tbl_value_col_ascending.R b/R/check_tbl_value_col_ascending.R index a9e311de..7db040d2 100644 --- a/R/check_tbl_value_col_ascending.R +++ b/R/check_tbl_value_col_ascending.R @@ -8,14 +8,15 @@ #' If not, the check is skipped and a `` condition class #' object is returned. #' -#' @inherit check_tbl_colnames params +#' @inherit check_tbl_values params #' @inherit check_tbl_col_types return #' @export -check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) { +check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id, + derived_task_ids = get_hub_derived_task_ids(hub_path)) { + check_output_types <- intersect(c("cdf", "quantile"), unique(tbl[["output_type"]])) # Exit early if there are no values to check - no_values_to_check <- all(!c("cdf", "quantile") %in% tbl[["output_type"]]) - if (no_values_to_check) { + if (length(check_output_types) == 0L) { return( capture_check_info( file_path, @@ -25,30 +26,23 @@ check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) { ) } - # create a model output table subset to only the CDF and or quantile values - # regardless of whether they are optional or required config_tasks <- hubUtils::read_config(hub_path, "tasks") - round_output_types <- get_round_output_type_names(config_tasks, round_id) - only_cdf_or_quantile <- intersect(c("cdf", "quantile"), round_output_types) - reference_tbl <- expand_model_out_grid( - config_tasks = config_tasks, - round_id = round_id, - all_character = FALSE, - force_output_types = TRUE, - output_types = only_cdf_or_quantile - ) - # FIX for - # sort the table by config by merging from config ---------------- - tbl_sorted <- order_output_type_ids(tbl, reference_tbl) - # TODO: return an informative error or message if the table has no rows - # If this is the case, this likely means that there are invalid combinations - # of values. - output_type_tbl <- split_cdf_quantile(tbl_sorted) + if (!is.null(derived_task_ids)) { + tbl[derived_task_ids] <- NA_character_ + } + # Check that values are non-decreasing for each output type separately to reduce + # memory pressure error_tbl <- purrr::map( - output_type_tbl, - check_values_ascending + check_output_types, + \(.x) { + check_values_ascending_by_output_type( + .x, tbl, + config_tasks, round_id, + derived_task_ids + ) + } ) %>% purrr::list_rbind() @@ -73,7 +67,41 @@ check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) { ) } +#' Check that values for each model task in specific output types are ascending +#' +#' This function allows us to map over individual output types one at a time to +#' reduce memory pressure. +#' @param output_type the output type(s) to check. Must be a character vector +#' @noRd +check_values_ascending_by_output_type <- function(output_type, tbl, + config_tasks, round_id, + derived_task_ids) { + # FIX for + # This function splits the table by model task (via + # `expand_model_out_grid(bind_model_tasks = FALSE)`) and then performs an + # inner join to auto-sort for this particular output type regardless if the + # output type is inherently sortable. + model_task_tbls <- match_tbl_to_model_task( + tbl, + config_tasks = config_tasks, + round_id = round_id, + output_types = output_type, + derived_task_ids = derived_task_ids + ) %>% + purrr::compact() + + purrr::map(model_task_tbls, check_values_ascending) %>% + purrr::list_rbind() +} +#' Check that values for each model task are ascending +#' +#' @param tbl an all character table with a single output type +#' @return +#' - If the check succeeds, and all values are non-decreasing: NULL +#' - If the check fails, a summary table showing the model tasks that +#' had decreasing values for this output type +#' @noRd check_values_ascending <- function(tbl) { group_cols <- names(tbl)[!names(tbl) %in% hubUtils::std_colnames] tbl[["value"]] <- as.numeric(tbl[["value"]]) @@ -93,45 +121,3 @@ check_values_ascending <- function(tbl) { dplyr::ungroup() %>% dplyr::mutate(.env$output_type) } - -split_cdf_quantile <- function(tbl) { - split(tbl, tbl[["output_type"]])[c("cdf", "quantile")] %>% - purrr::compact() -} - -#' Order the output type ids in the order of the config -#' -#' This function uses the output from [expand_model_out_grid()] to create -#' a lookup table that contains the correct ordering for all of the output type -#' IDs. Performing an inner join with this lookup table as the reference will -#' auto sort the model output by the output type ID. -#' -#' @param tbl a model output table -#' @param reference_tbl output from [expand_model_out_grid()] -#' -#' @note -#' 1. this assumes that the output_type_id values in the `tbl` are complete, -#' which is explicitly checked by the [check_tbl_values_required()] -#' 2. this assumes that both `tbl` and `reference_tbl` have the same column -#' types -#' @noRd -#' @examples -#' reference_tbl <- data.frame( -#' target = c(rep("a", 3), rep("b", 5)), -#' output_type = rep("quantile", 8), -#' output_type_id = c("0", "0.5", "1", "0", "0.25", "0.5", "0.75", "1") -#' ) -#' tbl <- reference_tbl -#' tbl$value <- c( -#' seq(from = 0, to = 1, length.out = 3), -#' seq(from = 0, to = 1, length.out = 5) -#' ) -#' order_output_type_ids(tbl[sample(nrow(tbl)), ] reference_tbl) -order_output_type_ids <- function(tbl, reference_tbl) { - group_cols <- names(tbl)[!names(tbl) %in% hubUtils::std_colnames] - join_by <- c(group_cols, "output_type", "output_type_id") - lookup <- unique(reference_tbl[join_by]) - tbl$output_type_id <- as.character(tbl$output_type_id) - lookup$output_type_id <- as.character(lookup$output_type_id) - dplyr::inner_join(lookup, tbl, by = join_by) -} diff --git a/R/match_tbl_to_model_task.R b/R/match_tbl_to_model_task.R index 089192e7..64132855 100644 --- a/R/match_tbl_to_model_task.R +++ b/R/match_tbl_to_model_task.R @@ -36,7 +36,7 @@ match_tbl_to_model_task <- function(tbl, config_tasks, round_id, config_tasks, round_id = round_id, required_vals_only = FALSE, - all_character = TRUE, + all_character = all_character, as_arrow_table = FALSE, bind_model_tasks = FALSE, output_types = output_types, diff --git a/R/validate_model_data.R b/R/validate_model_data.R index 8f4963b6..c86e43e5 100644 --- a/R/validate_model_data.R +++ b/R/validate_model_data.R @@ -211,10 +211,11 @@ validate_model_data <- function(hub_path, file_path, round_id_col = NULL, checks$value_col_non_desc <- try_check( check_tbl_value_col_ascending( - tbl, + tbl_chr, file_path = file_path, hub_path = hub_path, - round_id = round_id + round_id = round_id, + derived_task_ids = derived_task_ids ), file_path ) diff --git a/hubValidations.Rproj b/hubValidations.Rproj index 69fafd4b..be951660 100644 --- a/hubValidations.Rproj +++ b/hubValidations.Rproj @@ -1,4 +1,5 @@ Version: 1.0 +ProjectId: b97d02d3-a7d4-40df-a852-afaa4ff9371e RestoreWorkspace: No SaveWorkspace: No diff --git a/man/check_tbl_value_col_ascending.Rd b/man/check_tbl_value_col_ascending.Rd index 3679cae2..a93ed343 100644 --- a/man/check_tbl_value_col_ascending.Rd +++ b/man/check_tbl_value_col_ascending.Rd @@ -5,10 +5,16 @@ \title{Check that \code{quantile} and \code{cdf} output type values of model output data are non-descending} \usage{ -check_tbl_value_col_ascending(tbl, file_path, hub_path, round_id) +check_tbl_value_col_ascending( + tbl, + file_path, + hub_path, + round_id, + derived_task_ids = get_hub_derived_task_ids(hub_path) +) } \arguments{ -\item{tbl}{a tibble/data.frame of the contents of the file being validated.} +\item{tbl}{a tibble/data.frame of the contents of the file being validated. Column types must \strong{all be character}.} \item{file_path}{character string. Path to the file being validated relative to the hub's model-output directory.} @@ -24,6 +30,11 @@ The hub must be fully configured with valid \code{admin.json} and \code{tasks.js files within the \code{hub-config} directory.} \item{round_id}{character string. The round identifier.} + +\item{derived_task_ids}{Character vector of derived task ID names (task IDs whose +values depend on other task IDs) to ignore. Columns for such task ids will +contain \code{NA}s. Defaults to extracting derived task IDs from hub \code{task.json}. See +\code{\link[=get_hub_derived_task_ids]{get_hub_derived_task_ids()}} for more details.} } \value{ Depending on whether validation has succeeded, one of: diff --git a/tests/testthat/_snaps/check_tbl_value_col_ascending.md b/tests/testthat/_snaps/check_tbl_value_col_ascending.md index eaa917eb..9ca877ae 100644 --- a/tests/testthat/_snaps/check_tbl_value_col_ascending.md +++ b/tests/testthat/_snaps/check_tbl_value_col_ascending.md @@ -36,9 +36,9 @@ $ parent : NULL $ where : chr "team1-goodmodel/2022-10-08-team1-goodmodel.csv" $ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame) - ..$ origin_date: Date[1:1], format: "2022-10-08" + ..$ origin_date: chr "2022-10-08" ..$ target : chr "wk inc flu hosp" - ..$ horizon : int 1 + ..$ horizon : chr "1" ..$ location : chr "US" ..$ output_type: chr "quantile" $ call : chr "check_tbl_value_col_ascending" @@ -57,9 +57,9 @@ $ parent : NULL $ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet" $ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame) - ..$ forecast_date: Date[1:1], format: "2023-05-08" - ..$ horizon : int 1 + ..$ forecast_date: chr "2023-05-08" ..$ target : chr "wk ahead inc flu hosp" + ..$ horizon : chr "1" ..$ location : chr "US" ..$ output_type : chr "quantile" $ call : chr "check_tbl_value_col_ascending" @@ -78,9 +78,9 @@ $ parent : NULL $ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet" $ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame) - ..$ forecast_date: Date[1:1], format: "2023-05-08" - ..$ horizon : int 1 + ..$ forecast_date: chr "2023-05-08" ..$ target : chr "wk ahead inc flu hosp" + ..$ horizon : chr "1" ..$ location : chr "US" ..$ output_type : chr "quantile" $ call : chr "check_tbl_value_col_ascending" diff --git a/tests/testthat/test-check_tbl_value_col_ascending.R b/tests/testthat/test-check_tbl_value_col_ascending.R index 32216f45..9cb1f274 100644 --- a/tests/testthat/test-check_tbl_value_col_ascending.R +++ b/tests/testthat/test-check_tbl_value_col_ascending.R @@ -2,7 +2,7 @@ test_that("check_tbl_value_col_ascending works", { hub_path <- system.file("testhubs/simple", package = "hubValidations") file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv" file_meta <- parse_file_name(file_path) - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") expect_snapshot( check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) @@ -12,7 +12,7 @@ test_that("check_tbl_value_col_ascending works", { file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet" file_meta <- parse_file_name(file_path) - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") expect_snapshot( check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) @@ -22,7 +22,7 @@ test_that("check_tbl_value_col_ascending works", { test_that("check_tbl_value_col_ascending works when output type IDs not ordered", { hub_path <- test_path("testdata/hub-unordered/") file_path <- "ISI-NotOrdered/2024-01-10-ISI-NotOrdered.csv" - tbl <- read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") file_meta <- parse_file_name(file_path) expect_snapshot( check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) @@ -33,7 +33,7 @@ test_that("check_tbl_value_col_ascending errors correctly", { hub_path <- system.file("testhubs/simple", package = "hubValidations") file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv" file_meta <- parse_file_name(file_path) - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") tbl$value[c(1, 10)] <- 150 @@ -44,7 +44,7 @@ test_that("check_tbl_value_col_ascending errors correctly", { hub_path <- system.file("testhubs/flusight", package = "hubUtils") file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet" file_meta <- parse_file_name(file_path) - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") tbl_error <- tbl # TODO: 2025-01-07 investigate the purpose of adding an invalid target, which # causes the test to fail @@ -72,7 +72,7 @@ test_that("check_tbl_value_col_ascending skips correctly", { hub_path <- system.file("testhubs/simple", package = "hubValidations") file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv" file_meta <- parse_file_name(file_path) - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") tbl <- tbl[tbl$output_type == "mean", ] expect_snapshot( @@ -116,9 +116,10 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn' convert_to_cdf <- function(x) { ifelse(x == "quantile", "cdf", x) } - tbl <- hubValidations::read_model_out_file(file_path, hub_path) %>% + tbl <- read_model_out_file(file_path, hub_path) %>% dplyr::mutate(output_type_id = make_unsortable(.data[["output_type_id"]])) %>% - dplyr::mutate(output_type = convert_to_cdf(.data[["output_type"]])) + dplyr::mutate(output_type = convert_to_cdf(.data[["output_type"]])) %>% + hubData::coerce_to_character() # validating when it is sorted ----------------------------------------- res <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) @@ -149,9 +150,9 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn' file_meta$round_id ) expected <- tibble::tibble( - origin_date = as.Date("2022-10-08"), + origin_date = "2022-10-08", target = "wk inc flu hosp", - horizon = 1, + horizon = "1", location = "US", output_type = "cdf" ) @@ -161,11 +162,10 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn' expect_equal(actual, expected, ignore_attr = TRUE) }) - test_that("(#78) check_tbl_value_col_ascending works when output type IDs differ by target", { hub_path <- test_path("testdata/hub-diff-otid-per-task/") file_path <- "ISI-NotOrdered/2024-01-10-ILI-model.csv" - tbl <- hubValidations::read_model_out_file(file_path, hub_path) + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") file_meta <- parse_file_name(file_path) res_ok <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) @@ -173,18 +173,17 @@ test_that("(#78) check_tbl_value_col_ascending works when output type IDs differ expect_null(res_ok$error_tbl) }) -test_that("(#78) order_output_type_ids() can handle separate model tasks", { - # - reference_tbl <- data.frame( - target = c(rep("a", 3), rep("b", 5)), - output_type = rep("quantile", 8), - output_type_id = c("0", "0.5", "1", "0", "0.25", "0.5", "0.75", "1") - ) - tbl <- reference_tbl - tbl$value <- c( - seq(from = 0, to = 1, length.out = 3), - seq(from = 0, to = 1, length.out = 5) - ) - expect_null(check_values_ascending(tbl)) - expect_null(order_output_type_ids(tbl, reference_tbl) |> check_values_ascending()) +test_that("(#189) check_tbl_value_col_ascending ignores derived task IDs", { + hub_path <- test_path("testdata/hub-177") + file_path <- "FluSight-baseline/2024-12-14-FluSight-baseline.parquet" + tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr") + file_meta <- parse_file_name(file_path) + + # Introduce invalid value to derived task id that should be ignored when using + # `derived_task_ids`. + tbl[1, "target_end_date"] <- "random_date" + + res_ok <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id) + expect_s3_class(res_ok, "check_success") + expect_null(res_ok$error_tbl) })