Add discrete_lognormal and discrete_normal #16

Open
wants to merge 20 commits into main
Changes from 19 commits
336c601
match roxygen version and import R6
hrlai Jul 26, 2022
3ec4b1e
first pass at moving, slighly refactoring, and documenting the discre…
hrlai Jul 26, 2022
b53e09f
implement test for discrete_lognormal
hrlai Jul 27, 2022
1c35e6c
add checks for the breaks and dim arguments
hrlai Jul 27, 2022
5c6cf8b
create home for .onLoad functions, copied from https://github.com/gre…
hrlai Jul 27, 2022
a3a4498
copy reticulate functions from https://github.com/greta-dev/greta/blo…
hrlai Jul 27, 2022
fcb31cf
remove check dims for now, I think we need to use the check_dims func…
hrlai Jul 27, 2022
8a70b8b
add internal ::: call back to discrete_lognormal; is this ok or any b…
hrlai Jul 27, 2022
bcc9d84
add the necessary tf function for discrete lognormal; naming script i…
hrlai Jul 27, 2022
7b9393f
switch from floor to round for a more general purpose use; this will …
hrlai Jul 28, 2022
77a76a2
add discrete_normal and fix a typo in discrete_lognormal
hrlai Jul 28, 2022
048f22e
add note to link Stan URL in case their specification is helpful
hrlai Aug 1, 2022
e2ae5b6
remove greta::: we don't seem to need it
hrlai Aug 1, 2022
15cc772
remove hard fix of lower and upper bounds in tf_safe_cdf; remove pmax…
hrlai Aug 3, 2022
49f6d0e
replace round in the sample function with tf$gather that discretise c…
hrlai Aug 3, 2022
8c646a8
define tf_as_integer to avoid it being an unexported function
hrlai Aug 3, 2022
b9a412f
move check functions to a dedicated file
hrlai Aug 3, 2022
8676719
experimenting with using edges, instead of breaks, for the integratio…
hrlai Aug 4, 2022
9610368
use tf$greater_equal rather than tf$greater for the upper bound of CDFs
hrlai Aug 4, 2022
cd0f1f7
small fixes from merging into main
njtierney Dec 9, 2024
2 changes: 1 addition & 1 deletion DESCRIPTION
@@ -44,7 +44,7 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.2.0
RoxygenNote: 7.2.1
SystemRequirements: Python (>= 2.7.0) with header files and shared
library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow
Probability (v0.7.0; https://www.tensorflow.org/probability/)
4 changes: 4 additions & 0 deletions NAMESPACE
@@ -1,4 +1,8 @@
# Generated by roxygen2: do not edit by hand

export(discrete_lognormal)
export(discrete_normal)
import(R6)
importFrom(R6,R6Class)
importFrom(greta,.internals)
importFrom(tensorflow,tf)
17 changes: 17 additions & 0 deletions R/checks.R
@@ -0,0 +1,17 @@
# check functions


check_break_length <- function(breaks){
if (length(breaks) <= 1) {
msg <- cli::format_error(
c(
"{.var breaks} must be a vector with at least two break points",
"but {.var breaks} has length {length(breaks)}"
)
)
stop(
msg,
call. = FALSE
)
}
}
18 changes: 18 additions & 0 deletions R/conda_greta_env.R
@@ -0,0 +1,18 @@
use_greta_conda_env <- function() {
tryCatch(
expr = reticulate::use_condaenv("greta-env", required = TRUE),
error = function(e) NULL
)
}

using_greta_conda_env <- function() {
config <- reticulate::py_discover_config()
grepl("greta-env", config$python)
}

have_greta_conda_env <- function(){
tryCatch(
expr = "greta-env" %in% reticulate::conda_list()$name,
error = function(e) FALSE
)
}
107 changes: 107 additions & 0 deletions R/discrete_lognormal.R
@@ -0,0 +1,107 @@
#' @name discrete_lognormal
#' @title Discrete lognormal distribution
#'
#' @description a discretised lognormal distribution (i.e. sampled by applying
#' the round operation to samples from a lognormal). Due to the numerical
#' instability of integrating across the distribution, a vector of breaks
#' must be defined and the observations will be treated as censored
#' within those breaks
#'
#' @param meanlog unconstrained parameters giving the mean of the distribution
#' on the log scale
#' @param sdlog unconstrained parameters giving the standard deviation of the
#' distribution on the log scale
#' @param breaks a vector of breaks; observations will be treated as censored
#' within those breaks
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_lognormal <- function(meanlog, sdlog, breaks, dim = NULL) {
distrib("discrete_lognormal", meanlog, sdlog, breaks, dim)
}

# define the discrete lognormal distribution
discrete_lognormal_distribution <- R6Class(
classname = "discrete_lognormal_distribution",
inherit = distribution_node,
public = list(

breaks = NA,
lower_bounds = NA,
upper_bounds = NA,

initialize = function(meanlog, sdlog, breaks, dim) {

meanlog <- as.greta_array(meanlog)
sdlog <- as.greta_array(sdlog)

# check length of breaks
check_break_length(breaks)

# handle gradient issue between sdlog and 0s
breaks <- pmax(breaks, .Machine$double.eps)
self$breaks <- breaks
self$lower_bounds <- breaks[-length(breaks)]
self$upper_bounds <- breaks[-1]

# add the nodes as parents and parameters
dim <- check_dims(meanlog, sdlog, target_dim = dim)
super$initialize("discrete_lognormal", dim, discrete = TRUE)
self$add_parameter(meanlog, "meanlog")
self$add_parameter(sdlog, "sdlog")

},

tf_distrib = function(parameters, dag) {

meanlog <- parameters$meanlog
sdlog <- parameters$sdlog

tf_breaks <- fl(self$breaks)
tf_lower_bounds <- fl(self$lower_bounds)
tf_upper_bounds <- fl(self$upper_bounds)

log_prob <- function(x) {

# build distribution object
d <- tfp$distributions$LogNormal(
loc = meanlog,
scale = sdlog
)

# for those lumped into groups,
# compute the bounds of the observed groups
# and get tensors for the bounds in the format expected by TFP
x_safe <- tf$math$maximum(x, fl(.Machine$double.eps))
tf_idx <- tfp$stats$find_bins(x_safe, tf_breaks)
tf_idx_int <- tf_as_integer(tf_idx)
tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

# compute the density over the observed groups
low <- tf_safe_cdf(tf_lower_vec, d)
up <- tf_safe_cdf(tf_upper_vec, d)
log_density <- log(up - low)
log_density

}

sample <- function(seed) {

d <- tfp$distributions$LogNormal(
loc = meanlog,
scale = sdlog
)
continuous <- d$sample(seed = seed)
# tf$floor(continuous)
tf$round(continuous)

}

list(log_prob = log_prob, sample = sample)

}

)
)
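The log_prob above computes an interval-censored likelihood: each observation's probability is the CDF mass between the lower and upper break surrounding it, i.e. log(F(upper) - F(lower)) under a lognormal F. A minimal Python sketch of that same computation (an illustration only, not the PR's R/TensorFlow code; `np.digitize` plays the role of `tfp$stats$find_bins` and the `eps` clamp mirrors `pmax(breaks, .Machine$double.eps)`):

```python
import math
import numpy as np
from statistics import NormalDist

def lognorm_cdf(x, meanlog, sdlog):
    # LogNormal CDF: Phi((log(x) - meanlog) / sdlog)
    return NormalDist(meanlog, sdlog).cdf(math.log(x))

def discrete_lognormal_logpdf(x, meanlog, sdlog, breaks):
    eps = np.finfo(float).eps
    breaks = np.maximum(np.asarray(breaks, float), eps)  # mirrors pmax(breaks, eps)
    lower, upper = breaks[:-1], breaks[1:]
    # which interval each observation falls in (mirrors find_bins)
    idx = np.digitize(np.maximum(x, eps), breaks) - 1
    # mass of the lognormal between the two bounds of each bin
    return np.array([
        math.log(lognorm_cdf(u, meanlog, sdlog) - lognorm_cdf(l, meanlog, sdlog))
        for l, u in zip(lower[idx], upper[idx])
    ])

breaks = np.arange(0.5, 10.5)  # bin edges bracketing the integers 1..9
lp = discrete_lognormal_logpdf(np.array([1.0, 2.0, 5.0]), 0.0, 1.0, breaks)
```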
142 changes: 142 additions & 0 deletions R/discrete_normal.R
@@ -0,0 +1,142 @@
#' @name discrete_normal
#' @title Discrete normal distribution
#'
#' @description a discretised normal distribution (i.e. sampled by applying
#' the round operation to samples from a normal). Due to the numerical
#' instability of integrating across the distribution, a vector of breaks
#' must be defined and the observations will be treated as censored
#' within those breaks
#'
#' @param mean unconstrained parameters giving the mean of the distribution
#' @param sd unconstrained parameters giving the standard deviation of the
#' distribution
#' @param breaks a vector of breaks; observations will be treated as censored
#' within those breaks
#' @param edges a vector of edges; its length must be length(breaks) + 1;
#' observations between two consecutive edges will be discretised to the
#' break value between the corresponding edges
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_normal <- function(mean, sd, breaks, edges, dim = NULL) {
distrib("discrete_normal", mean, sd, breaks, edges, dim)
}

# define the discrete normal distribution
discrete_normal_distribution <- R6Class(
classname = "discrete_normal_distribution",
inherit = distribution_node,
public = list(

breaks = NA,
edges = NA,
lower_bounds = NA,
upper_bounds = NA,
lower_bound = NA,
upper_bound = NA,

initialize = function(mean, sd, breaks, edges, dim) {

mean <- as.greta_array(mean)
sd <- as.greta_array(sd)

# check length of breaks
check_break_length(breaks)

# add breaks, vector of lower and upper bounds, and the lower and upper
# bound of supported values
# self$breaks <- breaks
# self$lower_bounds <- breaks[-length(breaks)]
# self$upper_bounds <- breaks[-1]
# self$lower_bound <- min(breaks)
# self$upper_bound <- max(breaks)

# EXPERIMENTAL:
# convert breaks to edges, which
# will be used to gather samples at the breaks and convert them to
# rounded values
# while avoiding the use of round or floor, which assumes that
# the breaks are all the integers
# perhaps more crucially, we use edges, not breaks, to integrate the CDFs
# because we say that the likelihood of observing a break value is the sum
# of probability from the lower to upper edges surrounding a break
# first, create the lower and upper bounds that will bin a continuous
# variable into breaks
# using midpoint between breaks for now (from diff)
# generate edges using midpoint, -Inf, and Inf
# breaks_diff <- diff(breaks)
# edges <- c(-Inf, breaks[-length(breaks)] + breaks_diff/2, Inf)
self$edges <- edges
self$lower_bounds <- edges[-length(edges)]
self$upper_bounds <- edges[-1]
self$lower_bound <- min(edges)
self$upper_bound <- max(edges)

# add the nodes as parents and parameters
dim <- check_dims(mean, sd, target_dim = dim)
super$initialize("discrete_normal", dim, discrete = TRUE)
self$add_parameter(mean, "mean")
self$add_parameter(sd, "sd")

},

tf_distrib = function(parameters, dag) {

mean <- parameters$mean
sd <- parameters$sd

tf_breaks <- fl(self$breaks)
tf_edges <- fl(self$edges)
tf_lower_bounds <- fl(self$lower_bounds)
tf_upper_bounds <- fl(self$upper_bounds)
tf_lower_bound <- fl(self$lower_bound)
tf_upper_bound <- fl(self$upper_bound)

log_prob <- function(x) {

# build distribution object
d <- tfp$distributions$Normal(
loc = mean,
scale = sd
)

# for those lumped into groups,
# compute the bounds of the observed groups
# and get tensors for the bounds in the format expected by TFP
tf_idx <- tfp$stats$find_bins(x, tf_edges)
tf_idx_int <- tf_as_integer(tf_idx)
tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

# compute the density over the observed groups
# note-to-self: this looks like https://mc-stan.org/docs/2_29/stan-users-guide/bayesian-measurement-error-model.html#rounding
low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
log_density <- log(up - low)
log_density
Contributor Author:
When the simulated data has high variance, e.g., rnorm(..., sd = 10), I often get this error:

> draws <- mcmc(mod)
Error: Could not find reasonable starting values after 20 attempts.
Please specify initial values manually via the `initial_values` argument

The problem may lie in log_density <- log(up - low). Gut feeling: high variance generates data that are more widely spread, so when the breaks are regularly spaced integers (see tf_idx), some observations land in bins so far into the tails that the two CDF values are numerically identical; up - low then underflows to zero and we end up taking the log of zero.
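A small numeric demonstration of this failure mode (an illustration, not part of the PR): in the far tail of a Normal CDF the difference of two CDF values underflows to exactly 0, so the log is -inf; clamping with a tiny epsilon is one common, if crude, guard. The threshold `1e-300` is an arbitrary choice here.

```python
import math
from statistics import NormalDist

d = NormalDist(mu=0.0, sigma=1.0)
low, up = d.cdf(40.0), d.cdf(41.0)      # a bin far into the upper tail
# naive log-density: up - low underflows to 0, so log would be -inf
naive = -math.inf if up - low == 0 else math.log(up - low)
# clamped version stays finite (assumed epsilon, not from the PR)
safe = math.log(max(up - low, 1e-300))
```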


}

sample <- function(seed) {

d <- tfp$distributions$Normal(
loc = mean,
scale = sd
)
continuous <- d$sample(seed = seed)

# gather samples at the breaks to convert them to rounded values
# ditto from what we did to breaks above
tf_edges_idx <- tfp$stats$find_bins(continuous, tf_edges)
tf_edges_idx_int <- tf_as_integer(tf_edges_idx)
tf$gather(tf_breaks, tf_edges_idx_int)

}

list(log_prob = log_prob, sample = sample)

}

)
)
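The sample function above snaps each continuous draw to the break value whose pair of surrounding edges contains it, instead of rounding. A NumPy sketch of that find_bins-plus-gather step (illustration only, not the PR's R/TensorFlow code):

```python
import numpy as np

def discretise(continuous, breaks, edges):
    # edges has length len(breaks) + 1; bin i spans (edges[i], edges[i+1])
    idx = np.digitize(continuous, edges) - 1   # mirrors tfp$stats$find_bins
    return np.asarray(breaks)[idx]             # mirrors tf$gather(tf_breaks, idx)

breaks = np.array([0.0, 1.0, 2.0, 3.0])
edges = np.array([-np.inf, 0.5, 1.5, 2.5, np.inf])
x = discretise(np.array([-3.2, 0.7, 1.49, 9.0]), breaks, edges)
# x is array([0., 1., 1., 3.])
```

Note that with -Inf and Inf as the outermost edges, every continuous draw maps to some break, so no sample falls outside the support.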
1 change: 1 addition & 0 deletions R/package.R
@@ -7,6 +7,7 @@
#'
#' @importFrom tensorflow tf
#' @importFrom greta .internals
#' @import R6
#'
#' @examples
#'
25 changes: 25 additions & 0 deletions R/tf_functions.R
@@ -0,0 +1,25 @@
# tensorflow functions

# CDF of the provided distribution, handling 0s and Infs
tf_safe_cdf <- function(x, distribution, lower_bound, upper_bound) {

# prepare to handle values outside the supported range
too_low <- tf$less(x, lower_bound)
too_high <- tf$greater_equal(x, upper_bound)
supported <- !too_low & !too_high
ones <- tf$ones_like(x)
zeros <- tf$zeros_like(x)

# run cdf on supported values, and fill others in with the appropriate value
x_clean <- tf$where(supported, x, ones)
cdf_clean <- distribution$cdf(x_clean)
mask <- tf$where(supported, ones, zeros)
add <- tf$where(too_high, ones, zeros)
cdf_clean * mask + add

}

# cast to integer
tf_as_integer <- function(x) {
tf$cast(x, tf$int32)
}
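The masking semantics of tf_safe_cdf can be stated plainly: below lower_bound the CDF is forced to 0, at or above upper_bound it is forced to 1, and the real CDF is evaluated only on supported values (out-of-range entries are first swapped for a dummy 1.0, mirroring tf$where(supported, x, ones), so the CDF never sees a problematic input). A plain-Python analogue for illustration (an assumption-laden sketch, not the PR's code):

```python
import numpy as np
from statistics import NormalDist

def safe_cdf(x, dist, lower_bound, upper_bound):
    too_low = x < lower_bound
    too_high = x >= upper_bound
    supported = ~too_low & ~too_high
    # swap unsupported entries for a harmless dummy before calling the CDF
    x_clean = np.where(supported, x, 1.0)
    cdf_clean = np.array([dist.cdf(v) for v in x_clean])
    # zero out unsupported entries, then add 1 wherever x was too high
    return cdf_clean * supported + too_high.astype(float)

d = NormalDist(0.0, 1.0)
out = safe_cdf(np.array([-np.inf, 0.0, np.inf]), d, -5.0, 5.0)
# out[0] == 0.0, out[1] == 0.5, out[2] == 1.0
```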
31 changes: 31 additions & 0 deletions R/zzz.R
@@ -0,0 +1,31 @@
# load tf probability
tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)

# create the node list object whenever the package is loaded
.onLoad <- function(libname, pkgname) { # nolint

# unset reticulate python environment, for more details, see:
# https://github.com/greta-dev/greta/issues/444
Sys.unsetenv("RETICULATE_PYTHON")

if (have_greta_conda_env()) {
use_greta_conda_env()
}

# silence TF's CPU instructions message
Sys.setenv(TF_CPP_MIN_LOG_LEVEL = 2)

# silence messages about deprecation etc.
# disable_tensorflow_logging()

# warn if TF version is bad
# check_tf_version("startup")

# switch back to 0-based extraction in tensorflow, and don't warn about
# indexing with tensors
options(tensorflow.one_based_extract = FALSE)
options(tensorflow.extract.warn_tensors_passed_asis = FALSE)

# default float type
options(greta_tf_float = "float64")
}
Collaborator:

I'll need to test this, but I assume this isn't needed since this is called in greta, and greta.distributions has a dependency on greta, so effectively does library(greta) when it is loaded.

I believe that you will need the tfp part however, as that is not exported from greta
