diff --git a/R/schedule.R b/R/schedule.R index 28b85843..24ff5330 100644 --- a/R/schedule.R +++ b/R/schedule.R @@ -83,11 +83,22 @@ get_tune_schedule <- function(wflow, param, grid) { if (has_submodels) { sched <- grid %>% dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") - # Note: multi_predict() should only be triggered for a submodel parameter if + # Note 1: multi_predict() should only be triggered for a submodel parameter if # there are multiple rows in the `predict_stage` list column. i.e. the submodel # column will always be there but we only multipredict when there are 2+ # values to predict. - first_loop_info <- min_grid(model_spec, grid) + + # Note 2: The purpose of min_grid() is to determine the minimum grid for + # preprocessing and model parameters to fit. We compute it here and ignore + # any postprocessing tuning parmeters (if any). The postprocessing parameters + # will still be in the schedule since we schedule those before the results + # that use min_grid() are merged in. See issue #975 for an example and + # discussion. + first_loop_info <- + min_grid(model_spec, + grid %>% + dplyr::select(-dplyr::any_of(post_id)) %>% + dplyr::distinct()) } else { sched <- grid %>% dplyr::group_nest(!!!symbs$fits, .key = "predict_stage") diff --git a/tests/testthat/test-schedule.R b/tests/testthat/test-schedule.R index 3a0262a9..2e519aa4 100644 --- a/tests/testthat/test-schedule.R +++ b/tests/testthat/test-schedule.R @@ -645,31 +645,6 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu prm_used_pre_model_post, grid_pre_model_post) - # TODO trees seems to have an extra row: - - # # A tibble: 4 × 3 - # min_n predict_stage trees - # - # 1 2 1 - # 2 21 1 - # 3 40 1000 #<- shouldn't this row and the one below be combined? - # 4 40 1 - - # sched_pre_model_post$model_stage[[1]] %>% - # select(-trees) %>% - # unnest(predict_stage) %>% - # unnest(post_stage) %>% - # arrange(min_n, trees, lower_limit) - - # tibble::tribble( - # ~min_n, ~trees, ~lower_limit, ~trees0, - # 2L, 1L, 0, 1L, - # 21L, 1L, 0.5, 1L, - # 40L, 1L, 1, 1000L, - # 40L, 1L, 1, 1L, - # 40L, 1000L, 0, 1000L, - # 40L, 1000L, 0, 1L - # ) expect_named(sched_pre_model_post, c("threshold", "disp_df", "model_stage")) expect_equal( @@ -677,43 +652,61 @@ test_that("grid processing schedule - recipe + model + tailor, submodels, irregu grid_pre %>% arrange(threshold, disp_df) ) - # for (i in seq_along(sched_pre_model_post$model_stage)) { - # model_i <- sched_pre_model_post$model_stage[[i]] - # expect_named(model_i, c("min_n", "predict_stage", "trees")) - # expect_equal( - # model_i %>% select(min_n, trees) %>% arrange(min_n), - # grid_model$data[[i]] - # ) - # - # for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) { - # predict_j <- model_i$predict_stage[[j]] - # - # # We need to figure out the trees that need predicting for the current - # # set of other parameters. - # - # # Get the settings that have already be resolved: - # other_ij <- - # model_i %>% - # select(-predict_stage, -trees) %>% - # slice(j) %>% - # vctrs::vec_cbind( - # sched_pre_model_post %>% - # select(threshold, disp_df) %>% - # slice(i) - # ) - # # What are the matching values from the grid? - # trees_ij <- - # grid_pre_model_post %>% - # inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) %>% - # select(trees) - # - # - # expect_equal( - # predict_j %>% select(trees) %>% arrange(trees), - # trees_ij %>% arrange(trees) - # ) - # } - # } + for (i in seq_along(sched_pre_model_post$model_stage)) { + model_i <- sched_pre_model_post$model_stage[[i]] + + # Get the current set of preproc parameters to remove + other_i <- + sched_pre_model_post[i,] %>% + dplyr::select(-model_stage) + + # We expect to evaulate these specific models for this set of preprocessors + exp_i <- + grid_pre_model_post %>% + inner_join(other_i, by = c("threshold", "disp_df")) %>% + arrange(trees, min_n, lower_limit) %>% + select(trees, min_n, lower_limit) + + # What we will evaluate: + subgrid_i <- + model_i %>% + select(-trees) %>% + unnest(predict_stage) %>% + unnest(post_stage) %>% + arrange(trees, min_n, lower_limit) %>% + select(trees, min_n, lower_limit) + + expect_equal(subgrid_i, exp_i) + + # for (j in seq_along(sched_pre_model_post$model_stage[[i]]$predict_stage)) { + # predict_j <- model_i$predict_stage[[j]] + # + # # We need to figure out the trees that need predicting for the current + # # set of other parameters. + # + # # Get the settings that have already be resolved: + # other_ij <- + # model_i %>% + # select(-predict_stage, -trees) %>% + # slice(j) %>% + # vctrs::vec_cbind( + # sched_pre_model_post %>% + # select(threshold, disp_df) %>% + # slice(i) + # ) + # # What are the matching values from the grid? + # trees_ij <- + # grid_pre_model_post %>% + # inner_join(other_ij, by = c("min_n", "threshold", "disp_df")) %>% + # select(trees) + # + # + # expect_equal( + # predict_j %>% select(trees) %>% arrange(trees), + # trees_ij %>% arrange(trees) + # ) + # } + } expect_s3_class( sched_pre_model_post,