diff --git a/inst/dust/cows.cpp b/inst/dust/cows.cpp index 1ed01a7..7257799 100644 --- a/inst/dust/cows.cpp +++ b/inst/dust/cows.cpp @@ -1,6 +1,27 @@ #include #include +enum likelihood_type { INCIDENCE, SURVIVAL }; + +likelihood_type read_likelihood_type(cpp11::list pars, const char * name) { + cpp11::sexp r_likelihood_choice = pars[name]; + if (r_likelihood_choice == R_NilValue) { + return INCIDENCE; + } + if (TYPEOF(r_likelihood_choice) != STRSXP || LENGTH(r_likelihood_choice) != 1) { + cpp11::stop("Expected '%s' to be a string", name); + } + std::string likelihood_choice = cpp11::as_cpp(r_likelihood_choice); + if (likelihood_choice == "incidence") { + return INCIDENCE; + } else if (likelihood_choice == "survival") { + return SURVIVAL; + } else { + cpp11::stop("Invalid value for '%s': '%s'", + name, likelihood_choice.c_str()); + } +} + template void sum_over_regions(real_type *cows, const size_t n_herds, @@ -61,6 +82,7 @@ class cows { real_type alpha; real_type time_test; real_type n_test; + likelihood_type likelihood_choice; std::vector region_start; std::vector herd_to_region_lookup; std::vector p_region_export; @@ -349,10 +371,12 @@ class cows { dust2::r::read_real_vector(pars, n_regions, asc_rate.data(), "asc_rate", true); } + const auto likelihood_choice = read_likelihood_type(pars, "likelihood_choice"); + const bool outbreak_detection_proportion_only = dust2::r::read_bool(pars, "outbreak_detection_proportion_only", false); const auto outbreak_detection_parameters{outbreak_detection_proportion_only}; - return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters}; + return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, likelihood_choice, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters}; } static internal_state build_internal(const shared_state& shared) { diff --git a/src/cows.cpp b/src/cows.cpp index 7100f37..5b95801 100644 --- a/src/cows.cpp +++ b/src/cows.cpp @@ -3,6 +3,27 @@ #include #include +enum likelihood_type { INCIDENCE, SURVIVAL }; + +likelihood_type read_likelihood_type(cpp11::list pars, const char * name) { + cpp11::sexp r_likelihood_choice = pars[name]; + if (r_likelihood_choice == R_NilValue) { + return INCIDENCE; + } + if (TYPEOF(r_likelihood_choice) != STRSXP || LENGTH(r_likelihood_choice) != 1) { + cpp11::stop("Expected '%s' to be a string", name); + } + std::string likelihood_choice = cpp11::as_cpp(r_likelihood_choice); + if (likelihood_choice == "incidence") { + return INCIDENCE; + } else if (likelihood_choice == "survival") { + return SURVIVAL; + } else { + cpp11::stop("Invalid value for '%s': '%s'", + name, likelihood_choice.c_str()); + } +} + template void sum_over_regions(real_type *cows, const size_t n_herds, @@ -63,6 +84,7 @@ class cows { real_type alpha; real_type time_test; real_type n_test; + likelihood_type likelihood_choice; std::vector region_start; std::vector herd_to_region_lookup; std::vector p_region_export; @@ -351,10 +373,12 @@ class cows { dust2::r::read_real_vector(pars, n_regions, asc_rate.data(), "asc_rate", true); } + const auto likelihood_choice = read_likelihood_type(pars, "likelihood_choice"); + const bool outbreak_detection_proportion_only = dust2::r::read_bool(pars, "outbreak_detection_proportion_only", false); const auto outbreak_detection_parameters{outbreak_detection_proportion_only}; - return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters}; + return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, likelihood_choice, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters}; } static internal_state build_internal(const shared_state& shared) { diff --git a/tests/testthat/test-cowflu.R b/tests/testthat/test-cowflu.R index 10f3365..32d8eb1 100644 --- a/tests/testthat/test-cowflu.R +++ b/tests/testthat/test-cowflu.R @@ -44,3 +44,17 @@ test_that("basic epi dynamics are reasonable", { y <- sys$packer_state$unpack(s) y$outbreak_region }) + + +test_that("likelihood choice is validated", { + pars <- test_toy_inputs() + n_particles <- 5 + pars$likelihood_choice <- TRUE + expect_error( + dust2::dust_system_create(cows(), pars, n_particles = 1), + "Expected 'likelihood_choice' to be a string") + pars$likelihood_choice <- "banana" + expect_error( + dust2::dust_system_create(cows(), pars, n_particles = 1), + "Invalid value for 'likelihood_choice': 'banana'") +})