Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add discrete_lognormal and discrete_normal #16

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
336c601
match roxygen version and import R6
hrlai Jul 26, 2022
3ec4b1e
first pass at moving, slighly refactoring, and documenting the discre…
hrlai Jul 26, 2022
b53e09f
implement test for discrete_lognormal
hrlai Jul 27, 2022
1c35e6c
add checks for the breaks and dim arguments
hrlai Jul 27, 2022
5c6cf8b
create home for .onLoad functions, copied from https://github.com/gre…
hrlai Jul 27, 2022
a3a4498
copy reticulate functions from https://github.com/greta-dev/greta/blo…
hrlai Jul 27, 2022
fcb31cf
remove check dims for now, I think we need to use the check_dims func…
hrlai Jul 27, 2022
8a70b8b
add internal ::: call back to discrete_lognormal; is this ok or any b…
hrlai Jul 27, 2022
bcc9d84
add the necessary tf function for discrete lognormal; naming script i…
hrlai Jul 27, 2022
7b9393f
switch from floor to round for a more general purpose use; this will …
hrlai Jul 28, 2022
77a76a2
add discrete_normal and fix a typo in discrete_lognormal
hrlai Jul 28, 2022
048f22e
add note to link Stan URL in case their specification is helpful
hrlai Aug 1, 2022
e2ae5b6
remove greta::: we don't seem to need it
hrlai Aug 1, 2022
15cc772
remove hard fix of lower and upper bounds in tf_safe_cdf; remove pmax…
hrlai Aug 3, 2022
49f6d0e
replace round in the sample function with tf$gather that discretise c…
hrlai Aug 3, 2022
8c646a8
define tf_as_integer to avoid it being an unexported function
hrlai Aug 3, 2022
b9a412f
move check functions to a dedicated file
hrlai Aug 3, 2022
8676719
experimenting with using edges, instead of breaks, for the integratio…
hrlai Aug 4, 2022
9610368
use tf$greater_equal rather than tf$greater for the upper bound of CDFs
hrlai Aug 4, 2022
cd0f1f7
small fixes from merging into main
njtierney Dec 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,8 @@ Imports:
progress,
R6,
tensorflow (== 2.16.0),
rlang
rlang,
reticulate (>= 1.40.0)
Suggests:
coda,
covr,
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Generated by roxygen2: do not edit by hand

export(discrete_lognormal)
export(discrete_normal)
export(zero_inflated_negative_binomial)
export(zero_inflated_poisson)
importFrom(R6,R6Class)
Expand Down
17 changes: 17 additions & 0 deletions R/checks.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# check functions


check_break_length <- function(breaks){
if (length(breaks) <= 1) {
msg <- cli::format_error(
c(
"{.var breaks} must be a vector with at least two break points",
"but {.var breaks} has length {length(breaks)}"
)
)
stop(
msg,
call. = FALSE
)

Check warning on line 15 in R/checks.R

View check run for this annotation

Codecov / codecov/patch

R/checks.R#L5-L15

Added lines #L5 - L15 were not covered by tests
}
}
18 changes: 18 additions & 0 deletions R/conda_greta_env.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
use_greta_conda_env <- function() {
tryCatch(
expr = reticulate::use_condaenv("greta-env", required = TRUE),
error = function(e) NULL
)

Check warning on line 5 in R/conda_greta_env.R

View check run for this annotation

Codecov / codecov/patch

R/conda_greta_env.R#L2-L5

Added lines #L2 - L5 were not covered by tests
}

using_greta_conda_env <- function() {
config <- reticulate::py_discover_config()
grepl("greta-env", config$python)

Check warning on line 10 in R/conda_greta_env.R

View check run for this annotation

Codecov / codecov/patch

R/conda_greta_env.R#L9-L10

Added lines #L9 - L10 were not covered by tests
}

have_greta_conda_env <- function(){
tryCatch(
expr = "greta-env" %in% reticulate::conda_list()$name,
error = function(e) FALSE
)

Check warning on line 17 in R/conda_greta_env.R

View check run for this annotation

Codecov / codecov/patch

R/conda_greta_env.R#L14-L17

Added lines #L14 - L17 were not covered by tests
}
107 changes: 107 additions & 0 deletions R/discrete_lognormal.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
#' @name discrete_lognormal
#' @title Discrete lognormal distribution
#'
#' @description a discretised lognormal distribution (i.e. sampled by applying
#' the round operation to samples from a lognormal). Due to the numerical
#' instability of integrating across the distribution, a vector of breaks
#' must be defined and the observations will be treated as censored
#' within those breaks
#'
#' @param meanlog unconstrained parameters giving the mean of the distribution
#' on the log scale
#' @param sdlog unconstrained parameters giving the standard deviation of the
#' distribution on the log scale
#' @param breaks a vector of breaks; observations will be treated as censored
#' within those breaks
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_lognormal <- function(meanlog, sdlog, breaks, dim = NULL) {
distrib("discrete_lognormal", meanlog, sdlog, breaks, dim)
}

# define the discrete lognormal distribution
discrete_lognormal_distribution <- R6Class(
classname = "discrete_lognormal_distribution",
inherit = distribution_node,
public = list(

breaks = NA,
lower_bounds = NA,
upper_bounds = NA,

initialize = function(meanlog, sdlog, breaks, dim) {

meanlog <- as.greta_array(meanlog)
sdlog <- as.greta_array(sdlog)

# check length of breaks
check_break_length(breaks)

# handle gradient issue between sdlog and 0s
breaks <- pmax(breaks, .Machine$double.eps)
self$breaks <- breaks

Check warning on line 45 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L45

Added line #L45 was not covered by tests
self$lower_bounds <- breaks[-length(breaks)]
self$upper_bounds <- breaks[-1]

# add the nodes as parents and parameters
dim <- check_dims(meanlog, sdlog, target_dim = dim)
super$initialize("discrete_lognormal", dim, discrete = TRUE)

Check warning on line 51 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L47-L51

Added lines #L47 - L51 were not covered by tests
self$add_parameter(meanlog, "meanlog")
self$add_parameter(sdlog, "sdlog")

Check warning on line 53 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L53

Added line #L53 was not covered by tests

},

tf_distrib = function(parameters, dag) {

meanlog <- parameters$meanlog
sdlog <- parameters$sdlog

tf_breaks <- fl(self$breaks)
tf_lower_bounds <- fl(self$lower_bounds)
tf_upper_bounds <- fl(self$upper_bounds)

Check warning on line 64 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L63-L64

Added lines #L63 - L64 were not covered by tests

log_prob <- function(x) {

Check warning on line 66 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L66

Added line #L66 was not covered by tests

# build distribution object
d <- tfp$distributions$LogNormal(
loc = meanlog,
scale = sdlog
)

# for those lumped into groups,
# compute the bounds of the observed groups

Check warning on line 75 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L69-L75

Added lines #L69 - L75 were not covered by tests
# and get tensors for the bounds in the format expected by TFP
x_safe <- tf$math$maximum(x, fl(.Machine$double.eps))
tf_idx <- tfp$stats$find_bins(x_safe, tf_breaks)
tf_idx_int <- tf_as_integer(tf_idx)
tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

Check warning on line 81 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L77-L81

Added lines #L77 - L81 were not covered by tests

# compute the density over the observed groups
low <- tf_safe_cdf(tf_lower_vec, d)
up <- tf_safe_cdf(tf_upper_vec, d)
log_density <- log(up - low)

}

Check warning on line 88 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L84-L88

Added lines #L84 - L88 were not covered by tests

sample <- function(seed) {

d <- tfp$distributions$LogNormal(
loc = meanlog,
scale = sdlog
)
continuous <- d$sample(seed = seed)
# tf$floor(continuous)
tf$round(continuous)

Check warning on line 98 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L90-L98

Added lines #L90 - L98 were not covered by tests

}

list(log_prob = log_prob, sample = sample)

Check warning on line 103 in R/discrete_lognormal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_lognormal.R#L100-L103

Added lines #L100 - L103 were not covered by tests
}

)
)
142 changes: 142 additions & 0 deletions R/discrete_normal.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
#' @name discrete_normal
#' @title Discrete normal distribution
#'
#' @description a discretised normal distribution (i.e. sampled by applying
#' the round operation to samples from a normal). Due to the numerical
#' instability of integrating across the distribution, a vector of breaks
#' must be defined and the observations will be treated as censored
#' within those breaks
#'
#' @param mean unconstrained parameters giving the mean of the distribution
#' @param sd unconstrained parameters giving the standard deviation of the
#' distribution
#' @param breaks a vector of breaks; observations will be treated as censored
#' within those breaks
#' @param edges a vector of edges; length needs to be length(breaks)+1;
#' observations between two consecutive edges will be discretised to the
#' break value between the corresponding edges
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_normal <- function(mean, sd, breaks, edges, dim = NULL) {
distrib("discrete_normal", mean, sd, breaks, edges, dim)
}

# define the discrete normal distribution
discrete_normal_distribution <- R6Class(
classname = "discrete_normal_distribution",
inherit = distribution_node,
public = list(

breaks = NA,
edges = NA,
lower_bounds = NA,
upper_bounds = NA,
lower_bound = NA,
upper_bound = NA,

initialize = function(mean, sd, breaks, edges, dim) {

mean <- as.greta_array(mean)
sd <- as.greta_array(sd)

# check length of breaks

Check warning on line 45 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L45

Added line #L45 was not covered by tests
check_break_length(breaks)

# add breaks, vector of lower and upper bounds, and the lower and upper
# bound of supported values
# self$breaks <- breaks
# self$lower_bounds <- breaks[-length(breaks)]

Check warning on line 51 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L47-L51

Added lines #L47 - L51 were not covered by tests
# self$upper_bounds <- breaks[-1]
# self$lower_bound <- min(breaks)

Check warning on line 53 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L53

Added line #L53 was not covered by tests
# self$upper_bound <- max(breaks)

# EXPERIMENTAL:
# convert breaks to edges, which
# will be used to gather samples at the breaks and convert them to
# rounded values
# while avoiding the use of round or floor, which assumes that
# the breaks are all the integers
# perhaps more crucially, we use edges, not breaks, to integrate the CDFs
# because we say that the likelihood of observing a break value is the sum
# of probability from the lower to upper edges surrounding a break

Check warning on line 64 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L63-L64

Added lines #L63 - L64 were not covered by tests
# first, create the lower and upper bounds that will bin a continuous
# variable into breaks

Check warning on line 66 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L66

Added line #L66 was not covered by tests
# using midpoint between breaks for now (from diff)
# generate edges using midpoint, -Inf, and Inf
# breaks_diff <- diff(breaks)
# edges <- c(-Inf, breaks[-length(breaks)] + breaks_diff/2, Inf)
self$edges <- edges
self$lower_bounds <- edges[-length(edges)]
self$upper_bounds <- edges[-1]
self$lower_bound <- min(edges)
self$upper_bound <- max(edges)

Check warning on line 75 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L69-L75

Added lines #L69 - L75 were not covered by tests

# add the nodes as parents and parameters
dim <- check_dims(mean, sd, target_dim = dim)
super$initialize("discrete_normal", dim, discrete = TRUE)
self$add_parameter(mean, "mean")
self$add_parameter(sd, "sd")

Check warning on line 81 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L77-L81

Added lines #L77 - L81 were not covered by tests

},

tf_distrib = function(parameters, dag) {

mean <- parameters$mean
sd <- parameters$sd

Check warning on line 88 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L84-L88

Added lines #L84 - L88 were not covered by tests

tf_breaks <- fl(self$breaks)
tf_edges <- fl(self$edges)
tf_lower_bounds <- fl(self$lower_bounds)
tf_upper_bounds <- fl(self$upper_bounds)
tf_lower_bound <- fl(self$lower_bound)
tf_upper_bound <- fl(self$upper_bound)

log_prob <- function(x) {

Check warning on line 98 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L90-L98

Added lines #L90 - L98 were not covered by tests
# build distribution object
d <- tfp$distributions$Normal(
loc = mean,
scale = sd
)

Check warning on line 103 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L100-L103

Added lines #L100 - L103 were not covered by tests

# for those lumped into groups,
# compute the bounds of the observed groups
# and get tensors for the bounds in the format expected by TFP
tf_idx <- tfp$stats$find_bins(x, tf_edges)
tf_idx_int <- tf_as_integer(tf_idx)
tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

Check warning on line 111 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L108-L111

Added lines #L108 - L111 were not covered by tests

# compute the density over the observed groups
# note-to-self: this looks like https://mc-stan.org/docs/2_29/stan-users-guide/bayesian-measurement-error-model.html#rounding
low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
log_density <- log(up - low)

Check warning on line 117 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L115-L117

Added lines #L115 - L117 were not covered by tests
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When the simulated data has high variance, e.g., rnorm(..., sd = 10), I often get this error:

> draws <- mcmc(mod)
Error: Could not find reasonable starting values after 20 attempts.
Please specify initial values manually via the `initial_values` argument

The problem may lie in log_density <- log(up - low). Gut feeling is that high variance generates data that are more widely spaced, so when breaks are regularly spaced integers (see tf_idx), some bins have zero elements, then tf_safe_cdf has nothing to evaluate? Then we end up taking the log of zero... ???


}

Check warning on line 119 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L119

Added line #L119 was not covered by tests

sample <- function(seed) {

Check warning on line 121 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L121

Added line #L121 was not covered by tests

d <- tfp$distributions$Normal(
loc = mean,
scale = sd
)
continuous <- d$sample(seed = seed)

Check warning on line 127 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L123-L127

Added lines #L123 - L127 were not covered by tests

# gather samples at the breaks to convert them to rounded values
# ditto from what we did to breaks above
tf_edges_idx <- tfp$stats$find_bins(continuous, tf_edges)
tf_edges_idx_int <- tf_as_integer(tf_edges_idx)
tf$gather(tf_breaks, tf_edges_idx_int)

Check warning on line 133 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L131-L133

Added lines #L131 - L133 were not covered by tests

}

Check warning on line 135 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L135

Added line #L135 was not covered by tests

list(log_prob = log_prob, sample = sample)

Check warning on line 137 in R/discrete_normal.R

View check run for this annotation

Codecov / codecov/patch

R/discrete_normal.R#L137

Added line #L137 was not covered by tests

}

)
)
25 changes: 25 additions & 0 deletions R/tf_functions.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# tensorflow functions

# CDF of the provided distribution, handling 0s and Infs
tf_safe_cdf <- function(x, distribution, lower_bound, upper_bound) {

# prepare to handle values outside the supported range
too_low <- tf$less(x, lower_bound)
too_high <- tf$greater_equal(x, upper_bound)
supported <- !too_low & !too_high
ones <- tf$ones_like(x)
zeros <- tf$zeros_like(x)

Check warning on line 11 in R/tf_functions.R

View check run for this annotation

Codecov / codecov/patch

R/tf_functions.R#L7-L11

Added lines #L7 - L11 were not covered by tests

# run cdf on supported values, and fill others in with the appropriate value
x_clean <- tf$where(supported, x, ones)
cdf_clean <- distribution$cdf(x_clean)
mask <- tf$where(supported, ones, zeros)
add <- tf$where(too_high, ones, zeros)
cdf_clean * mask + add

Check warning on line 18 in R/tf_functions.R

View check run for this annotation

Codecov / codecov/patch

R/tf_functions.R#L14-L18

Added lines #L14 - L18 were not covered by tests

}

# cast to integer
tf_as_integer <- function(x) {
tf$cast(x, tf$int32)

Check warning on line 24 in R/tf_functions.R

View check run for this annotation

Codecov / codecov/patch

R/tf_functions.R#L24

Added line #L24 was not covered by tests
}
2 changes: 1 addition & 1 deletion R/zzz.R
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# load tf probability
tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
27 changes: 27 additions & 0 deletions man/discrete_lognormal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 30 additions & 0 deletions man/discrete_normal.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading