diff --git a/DESCRIPTION b/DESCRIPTION
index a918b45..949df22 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -32,7 +32,8 @@ Imports:
     progress,
     R6,
     tensorflow (== 2.16.0),
-    rlang
+    rlang,
+    reticulate (>= 1.40.0)
 Suggests:
     coda,
     covr,
diff --git a/NAMESPACE b/NAMESPACE
index b5e6eab..29f55fe 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,5 +1,7 @@
 # Generated by roxygen2: do not edit by hand
 
+export(discrete_lognormal)
+export(discrete_normal)
 export(zero_inflated_negative_binomial)
 export(zero_inflated_poisson)
 importFrom(R6,R6Class)
diff --git a/R/checks.R b/R/checks.R
new file mode 100644
index 0000000..5050d72
--- /dev/null
+++ b/R/checks.R
@@ -0,0 +1,36 @@
+# check functions
+
+# error unless `breaks` has at least two elements; bins are built from
+# consecutive pairs of breaks, so fewer than two defines no bin
+check_break_length <- function(breaks) {
+  if (length(breaks) <= 1) {
+    msg <- cli::format_error(
+      c(
+        "{.var breaks} must be a vector with at least two break points",
+        "but {.var breaks} has length {length(breaks)}"
+      )
+    )
+    stop(
+      msg,
+      call. = FALSE
+    )
+  }
+}
+
+# error unless `edges` has exactly one more element than `breaks`, so that
+# each break value sits between one pair of consecutive edges
+check_edges_length <- function(edges, breaks) {
+  if (length(edges) != length(breaks) + 1) {
+    msg <- cli::format_error(
+      c(
+        "{.var edges} must have one more element than {.var breaks}",
+        "but {.var edges} has length {length(edges)}",
+        "and {.var breaks} has length {length(breaks)}"
+      )
+    )
+    stop(
+      msg,
+      call. = FALSE
+    )
+  }
+}
\ No newline at end of file
diff --git a/R/conda_greta_env.R b/R/conda_greta_env.R
new file mode 100644
index 0000000..a2b2cf0
--- /dev/null
+++ b/R/conda_greta_env.R
@@ -0,0 +1,25 @@
+# try to activate the "greta-env" conda environment. This is best-effort:
+# if reticulate signals an error it is discarded and NULL is returned
+# instead
+use_greta_conda_env <- function() {
+  tryCatch(
+    expr = reticulate::use_condaenv("greta-env", required = TRUE),
+    error = function(e) NULL
+  )
+}
+
+# does the path of the Python that reticulate discovers contain "greta-env"?
+# always a single TRUE/FALSE, including when no Python path is reported
+using_greta_conda_env <- function() {
+  config <- reticulate::py_discover_config()
+  isTRUE(grepl("greta-env", config$python, fixed = TRUE))
+}
+
+# is there a conda environment named "greta-env"? FALSE when listing the
+# conda environments fails
+have_greta_conda_env <- function() {
+  tryCatch(
+    expr = "greta-env" %in% reticulate::conda_list()$name,
+    error = function(e) FALSE
+  )
+}
diff --git a/R/discrete_lognormal.R b/R/discrete_lognormal.R
new file mode 100644
index 0000000..92e8704
--- /dev/null
+++ b/R/discrete_lognormal.R
@@ -0,0 +1,114 @@
+#' @name discrete_lognormal
+#' @title Discrete lognormal distribution
+#'
+#' @description a discretised lognormal distribution (i.e. sampled by applying
+#'   the round operation to samples from a lognormal). Due to the numerical
+#'   instability of integrating across the distribution, a vector of breaks
+#'   must be defined and the observations will be treated as censored
+#'   within those breaks
+#'
+#' @param meanlog unconstrained parameters giving the mean of the distribution
+#'   on the log scale
+#' @param sdlog unconstrained parameters giving the standard deviation of the
+#'   distribution on the log scale
+#' @param breaks a vector of breaks; observations will be treated as censored
+#'   within those breaks
+#' @param dim a scalar giving the number of rows in the resulting greta array
+#'
+#' @importFrom R6 R6Class
+#' @export
+
+discrete_lognormal <- function(meanlog, sdlog, breaks, dim = NULL) {
+  distrib("discrete_lognormal", meanlog, sdlog, breaks, dim)
+}
+
+# define the discrete lognormal distribution: the probability of an
+# observation is the lognormal mass between the two consecutive breaks that
+# enclose it
+discrete_lognormal_distribution <- R6Class(
+  classname = "discrete_lognormal_distribution",
+  inherit = distribution_node,
+  public = list(
+
+    breaks = NA,
+    lower_bounds = NA,
+    upper_bounds = NA,
+
+    initialize = function(meanlog, sdlog, breaks, dim) {
+
+      meanlog <- as.greta_array(meanlog)
+      sdlog <- as.greta_array(sdlog)
+
+      # check length of breaks
+      check_break_length(breaks)
+
+      # handle gradient issue between sdlog and 0s
+      breaks <- pmax(breaks, .Machine$double.eps)
+      self$breaks <- breaks
+      self$lower_bounds <- breaks[-length(breaks)]
+      self$upper_bounds <- breaks[-1]
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(meanlog, sdlog, target_dim = dim)
+      super$initialize("discrete_lognormal", dim, discrete = TRUE)
+      self$add_parameter(meanlog, "meanlog")
+      self$add_parameter(sdlog, "sdlog")
+
+    },
+
+    tf_distrib = function(parameters, dag) {
+
+      meanlog <- parameters$meanlog
+      sdlog <- parameters$sdlog
+
+      tf_breaks <- fl(self$breaks)
+      tf_lower_bounds <- fl(self$lower_bounds)
+      tf_upper_bounds <- fl(self$upper_bounds)
+
+      # tf_safe_cdf() has no defaults for its bounds, so pass the support of
+      # the lognormal explicitly: the CDF is taken as 0 below 0 and 1 at Inf
+      tf_lower_bound <- fl(0)
+      tf_upper_bound <- fl(Inf)
+
+      log_prob <- function(x) {
+
+        # build distribution object
+        d <- tfp$distributions$LogNormal(
+          loc = meanlog,
+          scale = sdlog
+        )
+
+        # for those lumped into groups,
+        # compute the bounds of the observed groups
+        # and get tensors for the bounds in the format expected by TFP
+        x_safe <- tf$math$maximum(x, fl(.Machine$double.eps))
+        tf_idx <- tfp$stats$find_bins(x_safe, tf_breaks)
+        tf_idx_int <- tf_as_integer(tf_idx)
+        tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
+        tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)
+
+        # compute the density over the observed groups
+        low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
+        up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
+        log(up - low)
+
+      }
+
+      sample <- function(seed) {
+
+        d <- tfp$distributions$LogNormal(
+          loc = meanlog,
+          scale = sdlog
+        )
+        continuous <- d$sample(seed = seed)
+        # tf$floor(continuous)
+        tf$round(continuous)
+
+      }
+
+      list(log_prob = log_prob, sample = sample)
+
+    }
+
+  )
+)
\ No newline at end of file
diff --git a/R/discrete_normal.R b/R/discrete_normal.R
new file mode 100644
index 0000000..b8d3f1b
--- /dev/null
+++ b/R/discrete_normal.R
@@ -0,0 +1,133 @@
+#' @name discrete_normal
+#' @title Discrete normal distribution
+#'
+#' @description a discretised normal distribution (i.e. sampled by applying
+#'   the round operation to samples from a normal). Due to the numerical
+#'   instability of integrating across the distribution, a vector of breaks
+#'   must be defined and the observations will be treated as censored
+#'   within those breaks
+#'
+#' @param mean unconstrained parameters giving the mean of the distribution
+#' @param sd unconstrained parameters giving the standard deviation of the
+#'   distribution
+#' @param breaks a vector of breaks; observations will be treated as censored
+#'   within those breaks
+#' @param edges a vector of edges; length needs to be length(breaks)+1;
+#'   observations between two consecutive edges will be discretised to the
+#'   break value between the corresponding edges
+#' @param dim a scalar giving the number of rows in the resulting greta array
+#'
+#' @importFrom R6 R6Class
+#' @export
+
+discrete_normal <- function(mean, sd, breaks, edges, dim = NULL) {
+  distrib("discrete_normal", mean, sd, breaks, edges, dim)
+}
+
+# define the discrete normal distribution
+discrete_normal_distribution <- R6Class(
+  classname = "discrete_normal_distribution",
+  inherit = distribution_node,
+  public = list(
+
+    breaks = NA,
+    edges = NA,
+    lower_bounds = NA,
+    upper_bounds = NA,
+    lower_bound = NA,
+    upper_bound = NA,
+
+    initialize = function(mean, sd, breaks, edges, dim) {
+
+      mean <- as.greta_array(mean)
+      sd <- as.greta_array(sd)
+
+      # check length of breaks, and that there is one more edge than breaks
+      check_break_length(breaks)
+      check_edges_length(edges, breaks)
+
+      # store the breaks: sampling maps each continuous draw to the break
+      # value of the bin it falls in, so tf_distrib() needs them
+      self$breaks <- breaks
+
+      # EXPERIMENTAL:
+      # edges, not breaks, are used to bin samples and to integrate the CDFs:
+      # the likelihood of observing a break value is the sum of probability
+      # from the lower to the upper edge surrounding that break. This avoids
+      # round or floor, which assume that the breaks are all the integers.
+      # Edges are supplied by the caller for now; one option is midpoints
+      # between breaks padded with -Inf and Inf:
+      # breaks_diff <- diff(breaks)
+      # edges <- c(-Inf, breaks[-length(breaks)] + breaks_diff/2, Inf)
+      self$edges <- edges
+      self$lower_bounds <- edges[-length(edges)]
+      self$upper_bounds <- edges[-1]
+      self$lower_bound <- min(edges)
+      self$upper_bound <- max(edges)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(mean, sd, target_dim = dim)
+      super$initialize("discrete_normal", dim, discrete = TRUE)
+      self$add_parameter(mean, "mean")
+      self$add_parameter(sd, "sd")
+
+    },
+
+    tf_distrib = function(parameters, dag) {
+
+      mean <- parameters$mean
+      sd <- parameters$sd
+
+      tf_breaks <- fl(self$breaks)
+      tf_edges <- fl(self$edges)
+      tf_lower_bounds <- fl(self$lower_bounds)
+      tf_upper_bounds <- fl(self$upper_bounds)
+      tf_lower_bound <- fl(self$lower_bound)
+      tf_upper_bound <- fl(self$upper_bound)
+
+      log_prob <- function(x) {
+
+        # build distribution object
+        d <- tfp$distributions$Normal(
+          loc = mean,
+          scale = sd
+        )
+
+        # for those lumped into groups,
+        # compute the bounds of the observed groups
+        # and get tensors for the bounds in the format expected by TFP
+        tf_idx <- tfp$stats$find_bins(x, tf_edges)
+        tf_idx_int <- tf_as_integer(tf_idx)
+        tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
+        tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)
+
+        # compute the density over the observed groups
+        # note-to-self: this looks like https://mc-stan.org/docs/2_29/stan-users-guide/bayesian-measurement-error-model.html#rounding
+        low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
+        up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
+        log(up - low)
+
+      }
+
+      sample <- function(seed) {
+
+        d <- tfp$distributions$Normal(
+          loc = mean,
+          scale = sd
+        )
+        continuous <- d$sample(seed = seed)
+
+        # gather samples at the breaks to convert them to rounded values
+        # ditto from what we did to breaks above
+        tf_edges_idx <- tfp$stats$find_bins(continuous, tf_edges)
+        tf_edges_idx_int <- tf_as_integer(tf_edges_idx)
+        tf$gather(tf_breaks, tf_edges_idx_int)
+
+      }
+
+      list(log_prob = log_prob, sample = sample)
+
+    }
+
+  )
+)
\ No newline at end of file
diff --git a/R/tf_functions.R b/R/tf_functions.R
new file mode 100644
index 0000000..b53bde7
--- /dev/null
+++ b/R/tf_functions.R
@@ -0,0 +1,33 @@
+# tensorflow functions
+
+# CDF of the provided distribution, handling 0s and Infs
+#
+# x: tensor of points at which to evaluate the CDF
+# distribution: object with a cdf() method (callers pass TFP distributions)
+# lower_bound, upper_bound: limits of the supported range; the result is 0
+#   where x < lower_bound and 1 where x >= upper_bound, without calling
+#   cdf() on those values
+tf_safe_cdf <- function(x, distribution, lower_bound, upper_bound) {
+
+  # prepare to handle values outside the supported range
+  too_low <- tf$less(x, lower_bound)
+  too_high <- tf$greater_equal(x, upper_bound)
+  supported <- !too_low & !too_high
+  ones <- tf$ones_like(x)
+  zeros <- tf$zeros_like(x)
+
+  # run cdf on supported values, and fill others in with the appropriate value
+  # (unsupported values are swapped for 1 before cdf() sees them, then the
+  # result there is masked to 0 and 1 is added back where x was too high)
+  x_clean <- tf$where(supported, x, ones)
+  cdf_clean <- distribution$cdf(x_clean)
+  mask <- tf$where(supported, ones, zeros)
+  add <- tf$where(too_high, ones, zeros)
+  cdf_clean * mask + add
+
+}
+
+# cast to integer
+tf_as_integer <- function(x) {
+  tf$cast(x, tf$int32)
+}
\ No newline at end of file
diff --git a/R/zzz.R b/R/zzz.R
index f5f8ae3..26415d1 100644
--- a/R/zzz.R
+++ b/R/zzz.R
@@ -1,2 +1,2 @@
 # load tf probability
-tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
+tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
\ No newline at end of file
diff --git a/man/discrete_lognormal.Rd b/man/discrete_lognormal.Rd
new file mode 100644
index 0000000..2a16f6e
--- /dev/null
+++ b/man/discrete_lognormal.Rd
@@ -0,0 +1,27 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/discrete_lognormal.R
+\name{discrete_lognormal}
+\alias{discrete_lognormal} +\title{Discrete lognormal distribution} +\usage{ +discrete_lognormal(meanlog, sdlog, breaks, dim = NULL) +} +\arguments{ +\item{meanlog}{unconstrained parameters giving the mean of the distribution +on the log scale} + +\item{sdlog}{unconstrained parameters giving the standard deviation of the +distribution on the log scale} + +\item{breaks}{a vector of breaks; observations will be treated as censored +within those breaks} + +\item{dim}{a scalar giving the number of rows in the resulting greta array} +} +\description{ +a discretised lognormal distribution (i.e. sampled by applying +the round operation to samples from a lognormal). Due to the numerical +instability of integrating across the distribution, a vector of breaks +must be defined and the observations will be treated as censored +within those breaks +} diff --git a/man/discrete_normal.Rd b/man/discrete_normal.Rd new file mode 100644 index 0000000..64215f0 --- /dev/null +++ b/man/discrete_normal.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/discrete_normal.R +\name{discrete_normal} +\alias{discrete_normal} +\title{Discrete normal distribution} +\usage{ +discrete_normal(mean, sd, breaks, edges, dim = NULL) +} +\arguments{ +\item{mean}{unconstrained parameters giving the mean of the distribution} + +\item{sd}{unconstrained parameters giving the standard deviation of the +distribution} + +\item{breaks}{a vector of breaks; observations will be treated as censored +within those breaks} + +\item{edges}{a vector of edges; length needs to be length(breaks)+1; +observations between two consecutive edges will be discretised to the +break value between the corresponding edges} + +\item{dim}{a scalar giving the number of rows in the resulting greta array} +} +\description{ +a discretised normal distribution (i.e. sampled by applying +the round operation to samples from a normal). 
Due to the numerical
+instability of integrating across the distribution, a vector of breaks
+must be defined and the observations will be treated as censored
+within those breaks
+}
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 7e1fac5..0debc91 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -949,6 +949,8 @@ run_opt <- function(
   )
 }
 
+# discrete lognormal
+
 # zero inflated poisson using distributional
 dist_zero_inflated_pois <- function(lambda, pi) {
 
diff --git a/tests/testthat/test-discrete_lognormal.R b/tests/testthat/test-discrete_lognormal.R
new file mode 100644
index 0000000..b6d1f5a
--- /dev/null
+++ b/tests/testthat/test-discrete_lognormal.R
@@ -0,0 +1,43 @@
+test_that(
+  "discrete_lognormal fails when given the wrong argument dimensions", {
+    # breaks must have length >= 2; use the real argument names so the
+    expect_snapshot(
+      error = TRUE,
+      discrete_lognormal(meanlog = 1, sdlog = 1, breaks = 1)
+    )
+
+    # snapshots capture the validation errors; dim = 0.1 is an invalid dim
+    expect_snapshot(
+      error = TRUE,
+      discrete_lognormal(meanlog = 1, sdlog = 1, breaks = c(1, 2), dim = 0.1)
+    )
+
+  })
+# TODO: disabled until r_fun and parameters match these distributions
+# test_that("discrete lognormal distribution has correct density", {
+#   skip_if_not(check_tf_version())
+#
+#   compare_distribution(
+#     greta_fun = discrete_lognormal,
+#     r_fun = extraDistr::ddnorm,
+#     parameters = list(lambda = 2, pi = 0.2),
+#     x = extraDistr::rdnorm(
+#       n = 100,
+#       mean = 0,
+#       sd = 1
+#     )
+#   )
+# })
+#
+# test_that("discrete normal distribution has correct density", {
+#   skip_if_not(check_tf_version())
+#
+#   compare_distribution(
+#     greta_fun = discrete_normal,
+#     r_fun = extraDistr::ddnorm,
+#     parameters = list(),
+#     x = extraDistr::rzinb(
+#       n = 100, size = 10, prob = 0.1, pi = 0.2
+#     )
+#   )
+# })