Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor check_tbl_col_ascending #190

Merged
merged 13 commits into from
Jan 17, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 52 additions & 66 deletions R/check_tbl_value_col_ascending.R
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
#' If not, the check is skipped and a `<message/check_info>` condition class
#' object is returned.
#'
#' @inherit check_tbl_colnames params
#' @inherit check_tbl_values params
#' @inherit check_tbl_col_types return
#' @export
check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) {
check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id,
derived_task_ids = get_hub_derived_task_ids(hub_path)) {
check_output_types <- intersect(c("cdf", "quantile"), unique(tbl[["output_type"]]))

# Exit early if there are no values to check
no_values_to_check <- all(!c("cdf", "quantile") %in% tbl[["output_type"]])
if (no_values_to_check) {
if (length(check_output_types) == 0L) {
return(
capture_check_info(
file_path,
Expand All @@ -25,30 +26,23 @@ check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) {
)
}

# create a model output table subset to only the CDF and or quantile values
# regardless of whether they are optional or required
config_tasks <- hubUtils::read_config(hub_path, "tasks")
round_output_types <- get_round_output_type_names(config_tasks, round_id)
only_cdf_or_quantile <- intersect(c("cdf", "quantile"), round_output_types)
reference_tbl <- expand_model_out_grid(
config_tasks = config_tasks,
round_id = round_id,
all_character = FALSE,
force_output_types = TRUE,
output_types = only_cdf_or_quantile
)

# FIX for <https://github.com/hubverse-org/hubValidations/issues/78>
# sort the table by config by merging from config ----------------
tbl_sorted <- order_output_type_ids(tbl, reference_tbl)
# TODO: return an informative error or message if the table has no rows
# If this is the case, this likely means that there are invalid combinations
# of values.
output_type_tbl <- split_cdf_quantile(tbl_sorted)
if (!is.null(derived_task_ids)) {
tbl[derived_task_ids] <- NA_character_
}

# Check that values are non-decreasing for each output type separately to reduce
# memory pressure
error_tbl <- purrr::map(
output_type_tbl,
check_values_ascending
check_output_types,
\(.x) {
check_values_ascending_by_output_type(
.x, tbl,
config_tasks, round_id,
derived_task_ids
)
}
) %>%
purrr::list_rbind()

Expand All @@ -73,7 +67,41 @@ check_tbl_value_col_ascending <- function(tbl, file_path, hub_path, round_id) {
)
}

#' Check that values for each model task in specific output types are ascending
#'
#' This function allows us to map over individual output types one at a time to
#' reduce memory pressure.
#' @param output_type the output type(s) to check. Must be a character vector
#' @noRd
check_values_ascending_by_output_type <- function(output_type, tbl,
config_tasks, round_id,
derived_task_ids) {
# FIX for <https://github.com/hubverse-org/hubValidations/issues/78>
# This function splits the table by model task (via
# `expand_model_out_grid(bind_model_tasks = FALSE)`) and then performs an
# inner join to auto-sort for this particular output type regardless if the
# output type is inherently sortable.
model_task_tbls <- match_tbl_to_model_task(
tbl,
config_tasks = config_tasks,
round_id = round_id,
output_types = output_type,
derived_task_ids = derived_task_ids
) %>%
purrr::compact()

purrr::map(model_task_tbls, check_values_ascending) %>%
purrr::list_rbind()
}

#' Check that values for each model task are ascending
#'
#' @param tbl an all character table with a single output type
#' @return
#' - If the check succeeds, and all values are non-decreasing: NULL
#' - If the check fails, a summary table showing the model tasks that
#' had decreasing values for this output type
#' @noRd
check_values_ascending <- function(tbl) {
group_cols <- names(tbl)[!names(tbl) %in% hubUtils::std_colnames]
tbl[["value"]] <- as.numeric(tbl[["value"]])
Expand All @@ -93,45 +121,3 @@ check_values_ascending <- function(tbl) {
dplyr::ungroup() %>%
dplyr::mutate(.env$output_type)
}

split_cdf_quantile <- function(tbl) {
split(tbl, tbl[["output_type"]])[c("cdf", "quantile")] %>%
purrr::compact()
}

#' Order the output type ids in the order of the config
#'
#' This function uses the output from [expand_model_out_grid()] to create
#' a lookup table that contains the correct ordering for all of the output type
#' IDs. Performing an inner join with this lookup table as the reference will
#' auto sort the model output by the output type ID.
#'
#' @param tbl a model output table
#' @param reference_tbl output from [expand_model_out_grid()]
#'
#' @note
#' 1. this assumes that the output_type_id values in the `tbl` are complete,
#' which is explicitly checked by the [check_tbl_values_required()]
#' 2. this assumes that both `tbl` and `reference_tbl` have the same column
#' types
#' @noRd
#' @examples
#' reference_tbl <- data.frame(
#' target = c(rep("a", 3), rep("b", 5)),
#' output_type = rep("quantile", 8),
#' output_type_id = c("0", "0.5", "1", "0", "0.25", "0.5", "0.75", "1")
#' )
#' tbl <- reference_tbl
#' tbl$value <- c(
#' seq(from = 0, to = 1, length.out = 3),
#' seq(from = 0, to = 1, length.out = 5)
#' )
#' order_output_type_ids(tbl[sample(nrow(tbl)), ] reference_tbl)
order_output_type_ids <- function(tbl, reference_tbl) {
group_cols <- names(tbl)[!names(tbl) %in% hubUtils::std_colnames]
join_by <- c(group_cols, "output_type", "output_type_id")
lookup <- unique(reference_tbl[join_by])
tbl$output_type_id <- as.character(tbl$output_type_id)
lookup$output_type_id <- as.character(lookup$output_type_id)
dplyr::inner_join(lookup, tbl, by = join_by)
}
2 changes: 1 addition & 1 deletion R/match_tbl_to_model_task.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ match_tbl_to_model_task <- function(tbl, config_tasks, round_id,
config_tasks,
round_id = round_id,
required_vals_only = FALSE,
all_character = TRUE,
all_character = all_character,
as_arrow_table = FALSE,
bind_model_tasks = FALSE,
output_types = output_types,
Expand Down
5 changes: 3 additions & 2 deletions R/validate_model_data.R
Original file line number Diff line number Diff line change
Expand Up @@ -211,10 +211,11 @@ validate_model_data <- function(hub_path, file_path, round_id_col = NULL,

checks$value_col_non_desc <- try_check(
check_tbl_value_col_ascending(
tbl,
tbl_chr,
file_path = file_path,
hub_path = hub_path,
round_id = round_id
round_id = round_id,
derived_task_ids = derived_task_ids
), file_path
)

Expand Down
1 change: 1 addition & 0 deletions hubValidations.Rproj
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
Version: 1.0
ProjectId: b97d02d3-a7d4-40df-a852-afaa4ff9371e

RestoreWorkspace: No
SaveWorkspace: No
Expand Down
15 changes: 13 additions & 2 deletions man/check_tbl_value_col_ascending.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions tests/testthat/_snaps/check_tbl_value_col_ascending.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@
$ parent : NULL
$ where : chr "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
$ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame)
..$ origin_date: Date[1:1], format: "2022-10-08"
..$ origin_date: chr "2022-10-08"
..$ target : chr "wk inc flu hosp"
..$ horizon : int 1
..$ horizon : chr "1"
..$ location : chr "US"
..$ output_type: chr "quantile"
$ call : chr "check_tbl_value_col_ascending"
Expand All @@ -57,9 +57,9 @@
$ parent : NULL
$ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet"
$ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame)
..$ forecast_date: Date[1:1], format: "2023-05-08"
..$ horizon : int 1
..$ forecast_date: chr "2023-05-08"
..$ target : chr "wk ahead inc flu hosp"
..$ horizon : chr "1"
..$ location : chr "US"
..$ output_type : chr "quantile"
$ call : chr "check_tbl_value_col_ascending"
Expand All @@ -78,9 +78,9 @@
$ parent : NULL
$ where : chr "hub-ensemble/2023-05-08-hub-ensemble.parquet"
$ error_tbl : tibble [1 x 5] (S3: tbl_df/tbl/data.frame)
..$ forecast_date: Date[1:1], format: "2023-05-08"
..$ horizon : int 1
..$ forecast_date: chr "2023-05-08"
..$ target : chr "wk ahead inc flu hosp"
..$ horizon : chr "1"
..$ location : chr "US"
..$ output_type : chr "quantile"
$ call : chr "check_tbl_value_col_ascending"
Expand Down
51 changes: 25 additions & 26 deletions tests/testthat/test-check_tbl_value_col_ascending.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ test_that("check_tbl_value_col_ascending works", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")

expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Expand All @@ -12,7 +12,7 @@ test_that("check_tbl_value_col_ascending works", {
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
file_meta <- parse_file_name(file_path)

tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")

expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Expand All @@ -22,7 +22,7 @@ test_that("check_tbl_value_col_ascending works", {
test_that("check_tbl_value_col_ascending works when output type IDs not ordered", {
hub_path <- test_path("testdata/hub-unordered/")
file_path <- "ISI-NotOrdered/2024-01-10-ISI-NotOrdered.csv"
tbl <- read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")
file_meta <- parse_file_name(file_path)
expect_snapshot(
check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Expand All @@ -33,7 +33,7 @@ test_that("check_tbl_value_col_ascending errors correctly", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")

tbl$value[c(1, 10)] <- 150

Expand All @@ -44,7 +44,7 @@ test_that("check_tbl_value_col_ascending errors correctly", {
hub_path <- system.file("testhubs/flusight", package = "hubUtils")
file_path <- "hub-ensemble/2023-05-08-hub-ensemble.parquet"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")
tbl_error <- tbl
# TODO: 2025-01-07 investigate the purpose of adding an invalid target, which
# causes the test to fail
Expand Down Expand Up @@ -72,7 +72,7 @@ test_that("check_tbl_value_col_ascending skips correctly", {
hub_path <- system.file("testhubs/simple", package = "hubValidations")
file_path <- "team1-goodmodel/2022-10-08-team1-goodmodel.csv"
file_meta <- parse_file_name(file_path)
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")
tbl <- tbl[tbl$output_type == "mean", ]

expect_snapshot(
Expand Down Expand Up @@ -116,9 +116,10 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn'
convert_to_cdf <- function(x) {
ifelse(x == "quantile", "cdf", x)
}
tbl <- hubValidations::read_model_out_file(file_path, hub_path) %>%
tbl <- read_model_out_file(file_path, hub_path) %>%
dplyr::mutate(output_type_id = make_unsortable(.data[["output_type_id"]])) %>%
dplyr::mutate(output_type = convert_to_cdf(.data[["output_type"]]))
dplyr::mutate(output_type = convert_to_cdf(.data[["output_type"]])) %>%
hubData::coerce_to_character()

# validating when it is sorted -----------------------------------------
res <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
Expand Down Expand Up @@ -149,9 +150,9 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn'
file_meta$round_id
)
expected <- tibble::tibble(
origin_date = as.Date("2022-10-08"),
origin_date = "2022-10-08",
target = "wk inc flu hosp",
horizon = 1,
horizon = "1",
location = "US",
output_type = "cdf"
)
Expand All @@ -161,30 +162,28 @@ test_that("(#78) check_tbl_value_col_ascending will sort even if the data doesn'
expect_equal(actual, expected, ignore_attr = TRUE)
})


test_that("(#78) check_tbl_value_col_ascending works when output type IDs differ by target", {
hub_path <- test_path("testdata/hub-diff-otid-per-task/")
file_path <- "ISI-NotOrdered/2024-01-10-ILI-model.csv"
tbl <- hubValidations::read_model_out_file(file_path, hub_path)
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")
file_meta <- parse_file_name(file_path)

res_ok <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
expect_s3_class(res_ok, "check_success")
expect_null(res_ok$error_tbl)
})

test_that("(#78) order_output_type_ids() can handle separate model tasks", {
# <https://github.com/hubverse-org/hubValidations/pull/105/files#r1904460868>
reference_tbl <- data.frame(
target = c(rep("a", 3), rep("b", 5)),
output_type = rep("quantile", 8),
output_type_id = c("0", "0.5", "1", "0", "0.25", "0.5", "0.75", "1")
)
tbl <- reference_tbl
tbl$value <- c(
seq(from = 0, to = 1, length.out = 3),
seq(from = 0, to = 1, length.out = 5)
)
expect_null(check_values_ascending(tbl))
expect_null(order_output_type_ids(tbl, reference_tbl) |> check_values_ascending())
test_that("(#189) check_tbl_value_col_ascending ignores derived task IDs", {
hub_path <- test_path("testdata/hub-177")
file_path <- "FluSight-baseline/2024-12-14-FluSight-baseline.parquet"
tbl <- read_model_out_file(file_path, hub_path, coerce_types = "chr")
file_meta <- parse_file_name(file_path)

# Introduce invalid value to derived task id that should be ignored when using
# `derived_task_ids`.
tbl[1, "target_end_date"] <- "random_date"

res_ok <- check_tbl_value_col_ascending(tbl, file_path, hub_path, file_meta$round_id)
expect_s3_class(res_ok, "check_success")
expect_null(res_ok$error_tbl)
})
Loading