Skip to content

Commit

Permalink
Merge pull request #7 from mrc-ide/cows-4
Browse files Browse the repository at this point in the history
Improve migration
  • Loading branch information
richfitz authored Aug 16, 2024
2 parents 38fca37 + 903b194 commit 963c7bb
Show file tree
Hide file tree
Showing 4 changed files with 135 additions and 22 deletions.
53 changes: 45 additions & 8 deletions inst/dust/cows.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ class cows {
std::vector<real_type> movement_matrix;
real_type start_count;
size_t start_herd;
bool condition_on_export;
};

struct internal_state {
Expand Down Expand Up @@ -151,13 +152,39 @@ class cows {
// Above, we change the populations (we do this BEFORE calculating import/exports)
for (size_t i = 0; i < shared.n_herds; ++i) {
const auto j = shared.herd_to_region_lookup[i];
const auto export_cows = mcstate::random::random_real<real_type>(rng_state) < shared.p_region_export[j] * dt;
const auto export_cows = internal.N[i] > 0 && mcstate::random::random_real<real_type>(rng_state) < shared.p_region_export[j] * dt;
if (export_cows) {
const auto p_cow_export = shared.p_cow_export[j] * dt;
internal.export_S[i] = mcstate::random::binomial<real_type>(rng_state, S_next[i], p_cow_export);
internal.export_E[i] = mcstate::random::binomial<real_type>(rng_state, E_next[i], p_cow_export);
internal.export_I[i] = mcstate::random::binomial<real_type>(rng_state, I_next[i], p_cow_export);
internal.export_R[i] = mcstate::random::binomial<real_type>(rng_state, R_next[i], p_cow_export);
const auto p_cow_export = shared.p_cow_export[j] * dt; // TODO: proper conversion to probability needed
// Option 1: rejection sampling:
size_t n_exported = 0;
do {
internal.export_S[i] = mcstate::random::binomial<real_type>(rng_state, S_next[i], p_cow_export);
internal.export_E[i] = mcstate::random::binomial<real_type>(rng_state, E_next[i], p_cow_export);
internal.export_I[i] = mcstate::random::binomial<real_type>(rng_state, I_next[i], p_cow_export);
internal.export_R[i] = mcstate::random::binomial<real_type>(rng_state, R_next[i], p_cow_export);
n_exported = internal.export_S[i] + internal.export_E[i] + internal.export_I[i] + internal.export_R[i];
} while (shared.condition_on_export && n_exported == 0);
// Option 2:
//
// Sample the number of cows in each compartment from a
// beta-binomial, sharing the beta draw across the four draws,
// but redrawing each time around the rejection.
//
// Option 3:
//
// If p is very small, then sample from a conditioned binomial
// for the total over all cows, then draw SEIR allocation from
// a multivartiate hypergeometric, which is not actually
// implemented in mcstate2 yet.
}
}

// Convert N into cumulative counts within a region:
for (size_t i = 0; i < shared.n_regions; ++i) {
const size_t i_start = shared.region_start[i];
const size_t i_end = shared.region_start[i + 1];
for (size_t j = i_start + 1; j < i_end; ++j) {
internal.N[j] += internal.N[j - 1];
}
}

Expand All @@ -171,7 +198,15 @@ class cows {
const size_t i_region_dst = std::distance(p, std::upper_bound(p, p + shared.n_regions, u1));
const auto within_region = i_region_src == i_region_dst;
const real_type u2 = mcstate::random::random_real<real_type>(rng_state);
const size_t i_dst = shared.region_start[i_region_dst] + std::floor(u2 * (shared.region_start[i_region_dst + 1] - shared.region_start[i_region_dst]));

const auto i_region_start = shared.region_start[i_region_dst];
const auto i_region_end = shared.region_start[i_region_dst + 1];

const size_t n_herds_in_region = i_region_end - i_region_start;
const size_t n_cows_in_region = internal.N[i_region_end - 1];
const auto it_N = internal.N.begin() + i_region_start;
const size_t i_dst = std::distance(it_N, std::upper_bound(it_N, it_N + n_herds_in_region, u2 * n_cows_in_region));

const bool allow_movement = within_region || state_travel_allowed ||
mcstate::random::hypergeometric(rng_state, internal.export_I[i_src], export_N - internal.export_I[i_src], std::min(shared.n_test, internal.export_I[i_src])) == 0;
if (allow_movement) {
Expand Down Expand Up @@ -229,6 +264,8 @@ class cows {
std::vector<real_type> movement_matrix(n_regions * n_regions);
dust2::r::read_real_vector(pars, n_regions * n_regions, movement_matrix.data(), "movement_matrix", true);

const bool condition_on_export = dust2::r::read_bool(pars, "condition_on_export", false);

const real_type time_test = dust2::r::read_real(pars, "time_test", 30);
const real_type n_test = dust2::r::read_real(pars, "n_test", 30);

Expand All @@ -240,7 +277,7 @@ class cows {
const real_type alpha = dust2::r::read_real(pars, "alpha");
const real_type sigma = dust2::r::read_real(pars, "sigma");

return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd};
return shared_state{n_herds, n_regions, gamma, sigma, beta, alpha, time_test, n_test, region_start, herd_to_region_lookup, p_region_export, p_cow_export, n_cows_per_herd, movement_matrix, start_count, start_herd, condition_on_export};
}

static internal_state build_internal(const shared_state& shared) {
Expand Down
53 changes: 45 additions & 8 deletions src/cows.cpp

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 17 additions & 6 deletions tests/testthat/helper-cowflu.R
Original file line number Diff line number Diff line change
@@ -1,7 +1,18 @@
test_fixed_inputs <- function() {
inputs <- readRDS("inputs.rds")
cowflu_fixed_inputs(inputs$data$p_region_export,
inputs$data$p_cow_export,
inputs$movement_matrix,
which(inputs$data$name == "Texas"))
## Simple toy example, symmetric migration with three regions
test_toy_inputs <- function(alpha = 0.2, beta = 0.9, gamma = 0.1,
sigma = 0.125, start_count = 5) {
cowflu_inputs(
alpha = alpha,
beta = beta,
gamma = gamma,
sigma = sigma,
cowflu_fixed_inputs(
n_herds_per_region = c(3, 7, 11),
p_region_export = c(.5, .5, .5),
p_cow_export = c(0.2, 0.2, 0.2),
n_cows_per_herd = c(rep(200, 3), rep(1000, 7), rep(3000, 11)),
movement_matrix = cbind(c(.6, .2, .2), c(.2, .6, .2), c(.2, .2, .6)),
time_test = 10000,
start_herd = 4,
start_count = start_count))
}
28 changes: 28 additions & 0 deletions tests/testthat/test-cowflu.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
test_that("basic epi dynamics are reasonable", {
pars <- test_toy_inputs()
n_particles <- 3
times <- 0:75
sys <- dust2::dust_system_create(cows(), pars, n_particles = n_particles, dt = 0.25)
dust2::dust_system_set_state_initial(sys)
s0 <- array(dust2::dust_system_state(sys),
c(pars$n_herds + pars$n_regions, 4))
s <- dust2::dust_system_simulate(sys, times)
s1 <- array(s, c(pars$n_herds + pars$n_regions, 4, n_particles, length(times)))

s1_herds <- s1[-(22:24), , , ]
s1_total <- s1[22:24, , , ]

tot <- apply(s1_herds, 2:4, sum)
expect_true(all(diff(t(tot[1, , ])) <= 0)) # S decreases
expect_true(all(diff(t(tot[4, , ])) >= 0)) # R increases

i1 <- 1:3
i2 <- 4:10
i3 <- 11:21
expect_equal(apply(s1_herds[i1, , , ], 2:4, sum),
s1_total[1, , , ])
expect_equal(apply(s1_herds[i2, , , ], 2:4, sum),
s1_total[2, , , ])
expect_equal(apply(s1_herds[i3, , , ], 2:4, sum),
s1_total[3, , , ])
})

0 comments on commit 963c7bb

Please sign in to comment.