From 7a118d12df58e1fcb486e5887990af07cb1ab344 Mon Sep 17 00:00:00 2001 From: Paul Lietar Date: Wed, 10 Jul 2024 14:32:19 +0100 Subject: [PATCH] Try to optimize competing hazards --- R/competing_hazards.R | 57 +++++++++++++++------ R/disease_progression.R | 5 +- R/human_infection.R | 6 +-- R/processes.R | 3 +- tests/testthat/test-competing-hazards.R | 21 +++++--- tests/testthat/test-infection-integration.R | 2 +- tests/testthat/test-pev.R | 2 +- 7 files changed, 66 insertions(+), 30 deletions(-) diff --git a/R/competing_hazards.R b/R/competing_hazards.R index b40ff175..9d74e9e6 100644 --- a/R/competing_hazards.R +++ b/R/competing_hazards.R @@ -13,15 +13,25 @@ CompetingOutcome <- R6::R6Class( stop("size must be positive integer") } private$targeted_process <- targeted_process - self$rates <- rep(0, size) + + self$target <- individual::Bitset$new(size) + self$rates <- NULL }, - set_rates = function(rates){ + set_rates = function(target, rates){ + stopifnot(target$size() == length(rates)) + + # TODO: add an assign method to Bitset + self$target$or(target) self$rates <- rates }, execute = function(t, target){ private$targeted_process(t, target) - self$rates <- rep(0, length(self$rates)) }, + reset = function() { + self$target$clear() + self$rates <- NULL + }, + target = NULL, rates = NULL ) ) @@ -29,46 +39,59 @@ CompetingOutcome <- R6::R6Class( CompetingHazard <- R6::R6Class( "CompetingHazard", private = list( - outcomes = list(), size = NULL, + outcomes = list(), # RNG is passed in because mockery is not able to stub runif # TODO: change when fixed rng = NULL ), public = list( - initialize = function(outcomes, rng = runif){ + initialize = function(size, outcomes, rng = runif){ if (length(outcomes) == 0){ stop("At least one outcome must be provided") } if (!all(sapply(outcomes, function(x) inherits(x, "CompetingOutcome")))){ stop("All outcomes must be of class CompetingOutcome") } + private$size <- size private$outcomes <- outcomes - private$size <- length(outcomes[[1]]$rates) private$rng <- rng }, resolve = function(t){ - event_rates <- do.call( - 'cbind', - lapply(private$outcomes, function(x) x$rates) - ) + candidates <- individual::Bitset$new(private$size) + for (o in private$outcomes) { + candidates$or(o$target) + } + targets.vector <- candidates$to_vector() + + rates <- matrix(ncol = length(private$outcomes), nrow = candidates$size(), 0) + for (i in seq_along(private$outcomes)) { + idx <- match( + private$outcomes[[i]]$target$to_vector(), + targets.vector) - total_rates <- rowSums(event_rates) - probs <- rate_to_prob(total_rates) * (event_rates / total_rates) + rates[idx, i] <- private$outcomes[[i]]$rates + } + + total_rates <- rowSums(rates) + probs <- rate_to_prob(total_rates) * (rates / total_rates) probs[is.na(probs)] <- 0 - rng <- private$rng(private$size) + rng <- private$rng(candidates$size()) + + cumulative <- rep(0, candidates$size()) - cumulative <- rep(0, private$size) for (o in seq_along(private$outcomes)) { next_cumulative <- cumulative + probs[,o] - selected <- which((rng > cumulative) & (rng <= next_cumulative)) + selected <- (rng > cumulative) & (rng <= next_cumulative) cumulative <- next_cumulative - target <- individual::Bitset$new(private$size)$insert(selected) - if (target$size() > 0){ + # TODO: change bitset_at to accept logical array + target <- bitset_at(candidates, which(selected)) + if (target$size() > 0) { private$outcomes[[o]]$execute(t, target) } + private$outcomes[[o]]$reset() } } ) diff --git a/R/disease_progression.R b/R/disease_progression.R index 11f63065..3269b6c5 100644 --- a/R/disease_progression.R +++ b/R/disease_progression.R @@ -10,7 +10,10 @@ create_recovery_rates_process <- function( recovery_outcome ) { function(timestep){ - recovery_outcome$set_rates(variables$recovery_rates$get_values()) + target <- variables$state$get_index_of(c("U", "Tr")) + recovery_outcome$set_rates( + target, + variables$recovery_rates$get_values(target)) } } diff --git a/R/human_infection.R b/R/human_infection.R index d8a39d90..6f16a7bb 100644 --- a/R/human_infection.R +++ b/R/human_infection.R @@ -121,9 +121,9 @@ calculate_infections <- function( ) ## capture infection rates to resolve in competing hazards - infection_rates <- rep(0, length = parameters$human_population) - infection_rates[source_vector] <- prob_to_rate(prob) - infection_outcome$set_rates(infection_rates) + infection_outcome$set_rates( + source_humans, + prob_to_rate(prob)) } #' @title Assigns infections to appropriate human states diff --git a/R/processes.R b/R/processes.R index c416a334..22793b7a 100644 --- a/R/processes.R +++ b/R/processes.R @@ -118,7 +118,8 @@ create_processes <- function( # Resolve competing hazards of infection with disease progression CompetingHazard$new( - outcomes = list(infection_outcome, recovery_outcome) + outcomes = list(infection_outcome, recovery_outcome), + size = parameters$human_population )$resolve ) diff --git a/tests/testthat/test-competing-hazards.R b/tests/testthat/test-competing-hazards.R index e1240d27..8d12a176 100644 --- a/tests/testthat/test-competing-hazards.R +++ b/tests/testthat/test-competing-hazards.R @@ -5,6 +5,8 @@ test_that("hazard resolves two disjoint outcomes", { size <- 4 + population <- individual::Bitset$new(size)$not() + outcome_1_process <- mockery::mock() outcome_1 <- CompetingOutcome$new( targeted_process = outcome_1_process, @@ -17,12 +19,13 @@ test_that("hazard resolves two disjoint outcomes", { ) hazard <- CompetingHazard$new( + size = size, outcomes = list(outcome_1, outcome_2), rng = mockery::mock(c(.05, .3, .2, .5)) ) - outcome_1$set_rates(c(10, 0, 10, 0)) - outcome_2$set_rates(c(0, 10, 0, 10)) + outcome_1$set_rates(population, c(10, 0, 10, 0)) + outcome_2$set_rates(population, c(0, 10, 0, 10)) hazard$resolve(0) @@ -42,6 +45,8 @@ test_that("hazard resolves two disjoint outcomes", { test_that("hazard resolves two competing outcomes", { size <- 4 + population <- individual::Bitset$new(size)$not() + outcome_1_process <- mockery::mock() outcome_1 <- CompetingOutcome$new( targeted_process = outcome_1_process, @@ -54,12 +59,13 @@ test_that("hazard resolves two competing outcomes", { ) hazard <- CompetingHazard$new( + size = size, outcomes = list(outcome_1, outcome_2), rng = mockery::mock(c(.7, .3, .2, .6)) ) - outcome_1$set_rates(c(5, 5, 5, 5)) - outcome_2$set_rates(c(5, 5, 5, 5)) + outcome_1$set_rates(population, c(5, 5, 5, 5)) + outcome_2$set_rates(population, c(5, 5, 5, 5)) hazard$resolve(0) @@ -79,6 +85,8 @@ test_that("hazard resolves two competing outcomes", { test_that("hazard resolves partial outcomes", { size <- 4 + population <- individual::Bitset$new(size)$not() + outcome_1_process <- mockery::mock() outcome_1 <- CompetingOutcome$new( targeted_process = outcome_1_process, @@ -91,12 +99,13 @@ test_that("hazard resolves partial outcomes", { ) hazard <- CompetingHazard$new( + size = size, outcomes = list(outcome_1, outcome_2), rng = mockery::mock(c(.8, .4, .2, .6)) ) - outcome_1$set_rates(prob_to_rate(rep(0.5, size))) - outcome_2$set_rates(prob_to_rate(rep(0.5, size))) + outcome_1$set_rates(population, prob_to_rate(rep(0.5, size))) + outcome_2$set_rates(population, prob_to_rate(rep(0.5, size))) hazard$resolve(0) diff --git a/tests/testthat/test-infection-integration.R b/tests/testthat/test-infection-integration.R index d399a527..d73b6657 100644 --- a/tests/testthat/test-infection-integration.R +++ b/tests/testthat/test-infection-integration.R @@ -707,7 +707,7 @@ test_that('prophylaxis is considered for medicated humans', { targeted_process = function(timestep, target){ infection_outcome_process(timestep, target, variables, renderer, parameters) }, - size = parameters$human_population + size = 4 ) infection_rates <- calculate_infections( diff --git a/tests/testthat/test-pev.R b/tests/testthat/test-pev.R index 9a5f1175..f3c3d793 100644 --- a/tests/testthat/test-pev.R +++ b/tests/testthat/test-pev.R @@ -154,7 +154,7 @@ test_that('Infection considers pev efficacy', { targeted_process = function(timestep, target){ infection_process_resolved_hazard(timestep, target, variables, renderer, parameters) }, - size = parameters$human_population + size = 4 ) # remove randomness from infection sampling