Skip to content

Commit

Permalink
Add support for changing likelihood type
Browse files Browse the repository at this point in the history
  • Loading branch information
richfitz committed Sep 18, 2024
1 parent e6c564e commit abe721a
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 2 deletions.
26 changes: 25 additions & 1 deletion inst/dust/cows.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,27 @@
#include <dust2/common.hpp>
#include <numeric>

enum likelihood_type { INCIDENCE, SURVIVAL };

likelihood_type read_likelihood_type(cpp11::list pars, const char * name) {
cpp11::sexp r_likelihood_choice = pars[name];
if (r_likelihood_choice == R_NilValue) {
return INCIDENCE;
}
if (TYPEOF(r_likelihood_choice) != STRSXP || LENGTH(r_likelihood_choice) != 1) {
cpp11::stop("Expected '%s' to be a string", name);
}
std::string likelihood_choice = cpp11::as_cpp<std::string>(r_likelihood_choice);
if (likelihood_choice == "incidence") {
return INCIDENCE;
} else if (likelihood_choice == "survival") {
return SURVIVAL;
} else {
cpp11::stop("Invalid value for '%s': '%s'",
name, likelihood_choice.c_str());
}
}

template <typename real_type>
void sum_over_regions(real_type *cows,
const size_t n_herds,
Expand Down Expand Up @@ -61,6 +82,7 @@ class cows {
real_type alpha;
real_type time_test;
real_type n_test;
likelihood_type likelihood_choice;
std::vector<size_t> region_start;
std::vector<size_t> herd_to_region_lookup;
std::vector<real_type> p_region_export;
Expand Down Expand Up @@ -349,10 +371,12 @@ class cows {
dust2::r::read_real_vector(pars, n_regions, asc_rate.data(), "asc_rate", true);
}

const auto likelihood_choice = read_likelihood_type(pars, "likelihood_choice");

const bool outbreak_detection_proportion_only = dust2::r::read_bool(pars, "outbreak_detection_proportion_only", false);
const auto outbreak_detection_parameters{outbreak_detection_proportion_only};

return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters};
return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, likelihood_choice, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, asc_rate, dispersion, condition_on_export, outbreak_detection_parameters};
}

static internal_state build_internal(const shared_state& shared) {
Expand Down
26 changes: 25 additions & 1 deletion src/cows.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 14 additions & 0 deletions tests/testthat/test-cowflu.R
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,17 @@ test_that("basic epi dynamics are reasonable", {
y <- sys$packer_state$unpack(s)
y$outbreak_region
})


test_that("likelihood choice is validated", {
pars <- test_toy_inputs()
n_particles <- 5
pars$likelihood_choice <- TRUE
expect_error(
dust2::dust_system_create(cows(), pars, n_particles = 1),
"Expected 'likelihood_choice' to be a string")
pars$likelihood_choice <- "banana"
expect_error(
dust2::dust_system_create(cows(), pars, n_particles = 1),
"Invalid value for 'likelihood_choice': 'banana'")
})

0 comments on commit abe721a

Please sign in to comment.