Add functions for implementing continuous relaxations of discrete stochastic transitions #31
Comments
Here's a script with my current hacky functions, using the current greta_2 branch (which has time-varying parameters etc. and bug fixes), applied to a problem where the population goes extinct.

# do 1D version of stochastic dynamics in greta.dynamics
# simulate a discrete stochastic growth rate population model
set.seed(3)
n_times <- 50
# daily population growth rate
r_true <- 1.01
# initial population (just before timeseries)
pop_init_true <- 25
time <- seq_len(n_times)
pop_true <- rep(NA, n_times)
pop_previous <- pop_init_true
for (i in 1:n_times) {
pop_true[i] <- rpois(1, pop_previous * r_true)
pop_previous <- pop_true[i]
}
# add an observation process (binomial with fixed detection probability)
obs_prob <- 0.7
pop_obs <- rbinom(n_times, size = pop_true, prob = obs_prob)
# define functions for continuous relaxation and reparameterisation of the
# Poisson distributions
# given poisson rate parameter lambda and random uniform deviate u, a continuous
# relaxation of poisson random variable generation is computed using the inverse
# of the incomplete gamma function, i.e. igammainv(lambda, 1 - u) (exactly equal
# to qgamma(1 - u, lambda) in the R implementation) has approximately the same
# distribution as qpois(u, lambda) when u is uniform on (0, 1)
gamma_continuous_poisson <- function(lambda, u) {
igammainv(lambda, 1 - u)
}
# # check:
# lambda <- as_data(pi)
# u <- uniform(0, 1)
# y <- gamma_continuous_poisson(lambda, u)
# sims <- calculate(y, u, nsim = 1e5)
# max(abs(sims$y - qgamma(1 - sims$u, pi)))
# quantile(round(sims$y))
# quantile(rpois(1e5, pi))
# the inverse incomplete gamma function (the major part of the quantile function
# of a gamma distribution)
igammainv <- function(a, p) {
op <- greta::.internals$nodes$constructors$op
op("igammainv", a, p,
tf_operation = "tf_igammainv"
)
}
tf_igammainv <- function(a, p) {
tfp <- greta:::tfp
tfp$math$igammainv(a, p)
}
# given random variables z (with standard normal distribution a priori), and
# Poisson rate parameter lambda, return a strictly positive continuous random
# variable with the same mean and variance as a poisson random variable with
# rate lambda, by approximating the poisson as a lognormal distribution.
lognormal_continuous_poisson <- function(lambda, z) {
sigma <- sqrt(log1p(lambda / exp(2 * log(lambda))))
# sigma2 <- log1p(1 / lambda)
mu <- log(lambda) - sigma^2 / 2
exp(mu + z * sigma)
}
# Working: The lognormal mean and variance should both equal lambda. The
# lognormal mean and variance can both be expressed in terms of the parameters
# mu and sigma.
# mean = lambda = exp(mu + sigma^2 / 2)
# variance = lambda = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)
# solve to get sigma and mu as a function of lambda:
# mu = log(lambda) - sigma^2 / 2
# lambda = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)
# lambda = (exp(sigma ^ 2) - 1) * exp(2 * (log(lambda) - sigma^2 / 2) + sigma ^ 2)
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * (log(lambda) - sigma^2 / 2))
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * log(lambda) - sigma^2)
# lambda = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * exp(2 * log(lambda)) / exp(sigma^2)
# lambda / exp(2 * log(lambda)) = (exp(sigma ^ 2) - 1) * exp(sigma ^ 2) * 1 / exp(sigma^2)
# lambda / exp(2 * log(lambda)) = (exp(sigma ^ 2) - 1)
# log(lambda / exp(2 * log(lambda)) + 1) = sigma ^ 2
# sigma = sqrt(log(lambda / exp(2 * log(lambda)) + 1)) = sigma
# mu = log(lambda) - sigma^2 / 2
# # check these numerically
# library(tidyverse)
# compare <- tibble(
# lambda = seq(0.01, 1000, length.out = 100)
# ) %>%
# mutate(
# sigma = sqrt(log(lambda / exp(2 * log(lambda)) + 1)),
# mu = log(lambda) - sigma^2 / 2
# ) %>%
# mutate(
# mean = exp(mu + sigma^2 / 2),
# variance = (exp(sigma ^ 2) - 1) * exp(2 * mu + sigma ^ 2)
# ) %>%
# mutate(
# diff_mean_variance = abs(mean - variance),
# diff_mean_lambda = abs(mean - lambda),
# diff_variance_lambda = abs(variance - lambda)
# ) %>%
# summarise(
# across(
# starts_with("diff"),
# ~max(.x)
# )
# )
library(greta.dynamics)
#> Loading required package: greta
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply
# growth rate
r <- normal(1, 0.1, truncation = c(0, Inf))
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
# latent variable for stochastic transitions
latent_z_vec <- normal(0, 1, dim = n_times)
# latent_u_vec <- uniform(0, 1, dim = n_times)
# initial population size (with mean equal to truth)
init <- exponential(1 / pop_init_true)
# transition function for the population process as a difference equation in
# integer timestep
fun <- function(state, iter, r, latent_z) {
lambda <- state * r
state_new <- lognormal_continuous_poisson(lambda, latent_z)
# state_new <- gamma_continuous_poisson(lambda, latent_u)
state_new
}
# solve it (integer times)
pop_sim <- iterate_dynamic_function(
transition_function = fun,
initial_state = init,
niter = n_times,
tol = 0,
r = r,
latent_z = latent_z_vec,
parameter_is_time_varying = "latent_z",
# clamp the populations to reasonable values to avoid numerical under/overflow
state_limits = c(1e-3, 1e3)
)
# get the modelled true population
pop_modelled <- t(pop_sim$all_states)
# and the expected value of the observation distribution
pop_obs_expected <- pop_modelled * obs_prob
pop_obs_ga <- as_data(pop_obs)
distribution(pop_obs_ga) <- poisson(pop_obs_expected)
# fit this, reducing the stochasticity when finding the initial values
n_chains <- 4
inits <- replicate(n_chains,
initials(
init = pop_init_true,
latent_z_vec = rnorm(n_times, 0, 0.1)),
simplify = FALSE)
m <- model(r, latent_z_vec, init)
draws <- mcmc(m,
initial_values = inits)
#> running 4 chains simultaneously on up to 8 CPU cores
#>
#> warmup ====================================== 1000/1000 | eta: 0s
#> sampling ====================================== 1000/1000 | eta: 0s
# get posterior samples and plot summaries
sims <- calculate(pop_modelled, nsim = 1000, values = draws)
posterior_sims <- sims$pop_modelled[, , 1]
posterior_sims_discrete <- round(posterior_sims)
posterior_est <- colMeans(posterior_sims_discrete)
posterior_ci <- apply(posterior_sims_discrete, 2, quantile, c(0.025, 0.975))
# plot posterior draws, summary stats, and the truth
plot(posterior_est,
ylim = range(c(posterior_ci, pop_true)),
type = "n",
xlab = "day",
ylab = "population")
for (i in 1:50) {
lines(posterior_sims_discrete[i, ],
lwd = 0.1,
col = grey(0.4))
}
lines(posterior_ci[1, ], lty = 2)
lines(posterior_ci[2, ], lty = 2)
lines(posterior_est,
lwd = 1.5)
lines(pop_true,
col = "blue")
# noisy observations, naively adjusted for detection probability
points(pop_obs / obs_prob,
cex = 0.5)

Created on 2024-03-13 with reprex v2.0.2
Here's a demo estimating population trajectories and growth rates from data with discrete stochastic (Poisson + multinomial) population and extinction/invasion dynamics in a metapopulation:

# demo of stochastic dispersal and growth dynamics
# simulate a discrete stochastic growth rate population model
set.seed(1)
n_times <- 20
n_pops <- 4
# daily population growth rate in each location
r_true <- runif(n_pops, 1, 1.3)
# initial populations (just before timeseries)
pop_init_true <- c(15, 0, 0, 0)
# population locations and dispersal matrix
coords <- matrix(runif(n_pops * 2), nrow = n_pops)
dispersal_range <- 6
dispersal_weight_raw <- exp(-dispersal_range * as.matrix(dist(coords)))
# add a nugget effect (increased probability of not dispersing)
prob_dispersing <- 0.1
dispersal_weight <- dispersal_weight_raw * prob_dispersing +
diag(n_pops) * (1 - prob_dispersing)
# plot these
par(mfrow = c(1, 1))
plot(coords,
type = "n",
ylab = "",
xlab = "",
axes = FALSE)
for (i in 1:n_pops) {
for (j in i:n_pops) {
arrows(x0 = coords[i, 1],
y0 = coords[i, 2],
x1 = coords[j, 1],
y1 = coords[j, 2],
length = 0,
lwd = 10 * dispersal_weight[i, j])
}
}
points(coords,
pch = 21,
bg = grey(0.8),
cex = 2)
text(coords[, 1],
coords[, 2],
labels = paste("pop", seq_len(n_pops)),
pos = 3,
xpd = NA)
# normalise dispersal weights to get dispersal probabilities
dispersal_prob <- sweep(dispersal_weight,
1,
rowSums(dispersal_weight),
FUN = "/")
# simulate stochastic population dynamics and dispersal
time <- seq_len(n_times)
pop_true <- matrix(NA, nrow = n_times, ncol = n_pops)
pop_previous <- pop_init_true
for (i in 1:n_times) {
# innovate populations
pop_grown <- rpois(n_pops, pop_previous * r_true)
# do dispersal, with multinomial randomness
pop_dispersed <- matrix(NA, n_pops, n_pops)
for (pop in 1:n_pops) {
pop_dispersed[pop, ] <- rmultinom(1,
pop_grown[pop],
prob = dispersal_prob[pop, ])
}
# collate all the individuals staying, arriving, less those leaving
pop_new <- colSums(pop_dispersed)
# store the states
pop_previous <- pop_true[i, ] <- pop_new
}
# add an observation process (binomial with fixed detection probability)
obs_prob <- 0.8
pop_obs <- pop_true * NA
pop_obs[] <- rbinom(length(pop_true),
size = pop_true[],
prob = obs_prob)
# # plot true (lines) and observed (points) populations across these populations
# par(mfrow = n2mfrow(n_pops))
# for (i in seq_len(n_pops)) {
# plot(pop_true[, i] ~ time,
# type = "l",
# ylab = "population",
# ylim = range(pop_true),
# main = paste("pop", i))
#
# points(pop_obs[, i] ~ time,
# type = "b",
# pch = 21,
# bg = ifelse(pop_true[, i] > 0, grey(0.4, 0.5), NA))
# }
# build a greta model to infer latent populations, using stochastic transitions
# but a continuous relaxation of the discrete process
source("src/modelFunctions.R")
#> Warning in file(filename, "r", encoding = encoding): cannot open file
#> 'src/modelFunctions.R': No such file or directory
#> Error in file(filename, "r", encoding = encoding): cannot open the connection
library(greta.dynamics)
#> Loading required package: greta
#>
#> Attaching package: 'greta'
#> The following objects are masked from 'package:stats':
#>
#> binomial, cov2cor, poisson
#> The following objects are masked from 'package:base':
#>
#> %*%, apply, backsolve, beta, chol2inv, colMeans, colSums, diag,
#> eigen, forwardsolve, gamma, identity, rowMeans, rowSums, sweep,
#> tapply

# define functions for continuous relaxation and reparameterisation of the
# Poisson distributions
# given random variables z (with standard normal distribution a priori), and
# Poisson rate parameter lambda, return a strictly positive continuous random
# variable with the same mean and variance as a poisson random variable with
# rate lambda, by approximating the poisson as a lognormal distribution.
lognormal_continuous_poisson <- function(lambda, z) {
sigma <- sqrt(log1p(lambda / exp(2 * log(lambda))))
# sigma2 <- log1p(1 / lambda)
mu <- log(lambda) - sigma^2 / 2
exp(mu + z * sigma)
}
# dispersal parameters to be fixed for now, just learn the growth rates and the
# initial populations and transitions
r <- normal(1, 0.25,
truncation = c(0, Inf),
dim = n_pops)
#> ℹ Initialising python and checking dependencies, this may take a moment.
#> ✔ Initialising python and checking dependencies ... done!
#>
# latent variable for stochastic transitions in lognormal approximation to
# Poisson
latent_z_vec <- normal(0, 1, dim = c(n_times, n_pops))
# initial population size - with mean equal to truth, but small, instead of
# zero, elements
pop_init_prior <- pmax(pop_init_true, 1e-3)
init <- exponential(1 / pop_init_prior)
# transition function for the population process as a difference equation in
# integer timestep
fun <- function(state, iter, r, latent_z) {
# grow the populations
expected_pop_grown <- state * r
# disperse the populations
expected_pop <- t(dispersal_prob) %*% expected_pop_grown
# add stochasticity
state_new <- lognormal_continuous_poisson(expected_pop, latent_z)
state_new
}
# solve it (integer times)
pop_sim <- iterate_dynamic_function(
transition_function = fun,
initial_state = init,
niter = n_times,
tol = 0,
r = r,
latent_z = latent_z_vec,
parameter_is_time_varying = "latent_z",
# clamp the populations to reasonable values to avoid numerical under/overflow
state_limits = c(1e-3, 1e3)
)
# get the modelled true population
pop_modelled <- t(pop_sim$all_states)
# and the expected value of the observation distribution
pop_obs_expected <- pop_modelled * obs_prob
pop_obs_ga <- as_data(pop_obs)
distribution(pop_obs_ga) <- poisson(pop_obs_expected)
# set the initial values for the population trajectories to be near the
# (deterministic) expected values, so sampling the stochastic values doesn't
# initialise us in very weird parts of parameter space
n_chains <- 4
inits <- replicate(n_chains,
initials(
init = pop_init_prior,
latent_z_vec = matrix(rnorm(n_times * n_pops, 0, 0.1),
n_times,
n_pops)),
simplify = FALSE)
m <- model(r, latent_z_vec, init)
draws <- mcmc(m,
initial_values = inits)
#> running 4 chains simultaneously on up to 8 CPU cores
#>
#> warmup ====================================== 1000/1000 | eta: 0s
#> sampling ====================================== 1000/1000 | eta: 0s
# check convergence
summary(coda::gelman.diag(draws,
autoburnin = FALSE,
multivariate = FALSE)$psrf)
#> Point est. Upper C.I.
#> Min. :1.002 Min. :1.004
#> 1st Qu.:1.048 1st Qu.:1.123
#> Median :1.112 Median :1.298
#> Mean :1.211 Mean :1.538
#> 3rd Qu.:1.234 3rd Qu.:1.616
#> Max. :2.750 Max. :5.395
# check growth rate estimates. We would not expect the posterior to be
# data-informed at all for populations 2 and 3, where the population doesn't
# have a chance to grow. For 1 and 4 (established populations from near the
# start, no stochastic extinction) we would expect to estimate a positive
# growth rate somewhere in the correct ballpark
summary(calculate(r, values = draws))$statistics
#> Mean SD Naive SE Time-series SE
#> r[1,1] 1.0957812 0.03560134 0.0005629066 0.004680752
#> r[2,1] 0.7560150 0.24658097 0.0038987875 0.040579097
#> r[3,1] 0.9790061 0.12836447 0.0020296205 0.015262417
#> r[4,1] 1.1751366 0.06651551 0.0010517025 0.008744965
r_true
#> [1] 1.079653 1.111637 1.171856 1.272462
# posterior simulations
sims_posterior <- calculate(pop_modelled,
values = draws,
nsim = 100)
par(mfrow = n2mfrow(n_pops))
for (pop in 1:n_pops) {
trajectories_posterior <- round(sims_posterior[[1]][, , pop])
plot(trajectories_posterior[1, ] ~ time,
type = "n",
ylim = range(c(trajectories_posterior)),
ylab = "population",
main = paste("pop", pop))
apply(jitter(trajectories_posterior),
1,
lines,
col = grey(0.4, 0.2),
lwd = 2)
lines(pop_true[, pop] ~ time,
lwd = 2,
col = "blue")
points(pop_obs[, pop] / obs_prob ~ time,
bg = ifelse(pop_true[, pop] > 0, "skyblue", "white"),
cex = 1,
pch = 21)
}

Created on 2024-03-15 with reprex v2.0.2
Here are some ballpark run times (MCMC stage only, 4 chains, 2K warmup, 2K samples, not accounting for convergence) with varying dimensions of this example problem:
So the model run time is approximately linear in
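For reference, timings like these could be produced with something along these lines (a sketch only; build_model() is a hypothetical wrapper around the model-building code above, not an existing function):

# rough timing of the MCMC stage for different numbers of timesteps
timings <- sapply(c(10, 20, 40, 80), function(n_times) {
  m <- build_model(n_times = n_times, n_pops = 4)  # hypothetical helper
  system.time(
    mcmc(m, warmup = 2000, n_samples = 2000, chains = 4)
  )[["elapsed"]]
})
timings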
Background
Gradient-based inference (like the HMC greta uses) can only operate on continuous parameter spaces. That means it cannot learn the values of parameters with discrete support (e.g. no unobserved Poisson random variables).
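Roughly speaking, each leapfrog step in HMC updates the momentum using the gradient of the log posterior,

$$\rho \leftarrow \rho + \tfrac{\epsilon}{2}\, \nabla_{\theta} \log p(\theta \mid y), \qquad \theta \leftarrow \theta + \epsilon\, \rho$$

and $\nabla_{\theta} \log p(\theta \mid y)$ is not defined when $\theta$ is restricted to integer values, so discrete latent variables cannot be sampled directly.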
But demographic stochasticity, the discrete stochastic variation in population sizes between timesteps in models of populations and infections (e.g. the number of new infectees is Poisson, the number of individuals surviving is binomial), is often important, especially when populations reach low numbers and approach extinction. We cannot directly model the values of these discrete random variables, but we can apply continuous relaxations to approximate them: we keep the state values continuous and replace each discrete random variable with a continuous random variable that matches the mean and variance (and ideally the full shape of the distribution) of the random variable we would like to model as stochastic.
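As a quick numerical illustration of the moment-matching idea (this check is mine, in base R, using the lognormal relaxation from the scripts above):

# moment check: a lognormal with sigma^2 = log1p(1 / lambda) and
# mu = log(lambda) - sigma^2 / 2 has mean and variance both equal to lambda
lambda <- 4
sigma2 <- log1p(1 / lambda)
mu <- log(lambda) - sigma2 / 2
y <- exp(mu + sqrt(sigma2) * rnorm(1e6))
c(mean = mean(y), var = var(y), lambda = lambda)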
Continuous approximations to distributions
E.g. a Poisson random variable can be approximated with an appropriately shaped gamma distribution that exactly matches the PMF at discrete values, or by some other distribution that is a close-enough approximation:
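For intuition (this mirrors the commented check in the first script, and is not part of the original post): a gamma with shape lambda and rate 1 also has mean and variance lambda, and rounding draws from it gives something close to Poisson draws:

# distributional check: rounded gamma draws vs Poisson draws
lambda <- pi
u <- runif(1e5)
y_relaxed <- qgamma(1 - u, shape = lambda)  # the continuous relaxation
quantile(round(y_relaxed))
quantile(rpois(1e5, lambda))                # similar quantiles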
We might write a discrete stochastic growth-rate model like this:

$$x_t \sim Poisson(r \, x_{t-1}) \qquad (1)$$

where $x_t$ takes integer values, $Poisson(\lambda)$ is the Poisson distribution, and $r$ is a positive-valued growth rate parameter. To estimate the posterior over the values represented by $x$ in this model, but using HMC, we could instead fit the model:

$$y_t \sim \pi(r \, y_{t-1}) \qquad (2)$$

where $y_t$ is a (strictly positive) real-valued parameter, and $\pi(\lambda)$ is some probability distribution with support on positive real values that has similar moments (mean, variance, skewness, etc.) to $Poisson(\lambda)$.
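For example (an illustrative choice, not prescribed in the issue), taking $\pi(\lambda)$ to be a gamma distribution with shape $\lambda$ and rate 1 gives a relaxation with the right mean and variance; simulating models (1) and (2) side by side in base R:

# sketch: discrete model (1) vs a gamma-relaxed model (2)
set.seed(1)
n_times <- 30
r <- 1.05
x <- y <- numeric(n_times)
x_prev <- y_prev <- 20
for (t in 1:n_times) {
  x[t] <- rpois(1, r * x_prev)                     # discrete states
  y[t] <- rgamma(1, shape = r * y_prev, rate = 1)  # continuous relaxation
  x_prev <- x[t]
  y_prev <- y[t]
}
round(rbind(discrete = x, relaxed = y), 1)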
Reparameterisation
If we structure these probability distributions such that they can be reparameterised in terms of the parameter and some latent 'innovation' or noise term with a known distribution, we can significantly improve our ability to sample these models, since we can decorrelate the posterior distribution in a similar way to the reparameterisation trick for hierarchical models. This can also provide some computational advantages in greta/TensorFlow by working with arrays rather than scalars.
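A minimal sketch of the same trick outside any dynamics (plain R, equivalent in distribution): instead of sampling a lognormal state directly, sample standard normal innovations and transform them, so the quantity the sampler works on has a fixed, independent prior:

# centred vs non-centred parameterisation of the same lognormal variable
mu <- 1
sigma <- 0.5
z <- rnorm(1e5)                      # what the sampler would work on
y_noncentred <- exp(mu + sigma * z)  # derived state
y_centred <- rlnorm(1e5, mu, sigma)  # direct draws from the same distribution
rbind(quantile(y_noncentred), quantile(y_centred))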
If we know the quantile function of the continuous distribution (e.g. $q_{\pi}(p, \lambda)$ as the quantile function of $\pi(\lambda)$, with $p$ the probability argument), we can reparameterise the innovations as $u \sim U(0, 1)$, and then plug them into the quantile function to sample the (approximately Poisson) values. I.e. equation 2 above is equivalent to:

$$u_t \sim U(0, 1), \qquad y_t = q_{\pi}(u_t, r \, y_{t-1}) \qquad (3)$$
The vector of $u$ values can then be computed in advance, and passed into the solvers to be chopped up appropriately. The dependency structure in the model means $y_t$ depends on $y_{t-1}$, and so they are a priori (and therefore also a posteriori) correlated. But in this reparameterised formulation, HMC operates instead on $u$, and $u_t$ doesn't depend on $u_{t-1}$, so they are a priori uncorrelated, which removes a lot of correlation in the posterior and makes sampling much easier.
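Concretely (a sketch assuming the gamma relaxation, so that $q_{\pi}(u, \lambda)$ is qgamma(1 - u, shape = lambda)): the whole trajectory is a deterministic function of independent uniforms, even though successive states are strongly correlated:

# sketch: trajectory written in terms of precomputed, independent innovations u_t
set.seed(1)
n_times <- 30
r <- 1.05
u <- runif(n_times)   # a priori independent; this is what HMC would sample
y <- numeric(n_times)
y_prev <- 20
for (t in 1:n_times) {
  y[t] <- qgamma(1 - u[t], shape = r * y_prev)  # deterministic given u_t and y_{t-1}
  y_prev <- y[t]
}
plot(y, type = "l")   # serially correlated states from uncorrelated innovations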
Note that if the quantile function is expensive to compute, other approximations and reparameterisations may be more appealing. E.g. for the Poisson, the inverse of the incomplete gamma function gives the quantile of the gamma distribution whose PDF matches the Poisson PMF at discrete values, but the function has no analytic form and is expensive to compute. A lognormal approximation (with either uniform or standard normal innovations) is imperfect but much more computationally efficient.
Currently this reparameterisation trick is only applicable in the greta_2 branch with iterate_dynamic_function(), since it requires functionality for indexing time-varying parameters.

Implementation
We just need to provide functions for the relaxations, documentation, and examples of applying this approach. I have some; I just need to add them to the greta_2 branch and work out the neatest user interface.
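Purely as a strawman for that interface (the names below are hypothetical, not existing greta.dynamics functions), the exported helpers could be thin wrappers around the relaxations prototyped in the comments above:

# hypothetical interface sketch; relax_poisson_lognormal() and
# relax_poisson_gamma() are made-up names, not part of any release
relax_poisson_lognormal <- function(lambda, innovation) {
  # innovation ~ normal(0, 1)
  sigma2 <- log1p(1 / lambda)
  mu <- log(lambda) - sigma2 / 2
  exp(mu + sqrt(sigma2) * innovation)
}
relax_poisson_gamma <- function(lambda, innovation) {
  # innovation ~ uniform(0, 1); qgamma() here stands in for the
  # igammainv() op defined in the first comment
  qgamma(1 - innovation, shape = lambda)
}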