generated from greta-dev/greta.template
Add discrete_lognormal and discrete_normal #16
Open: hrlai wants to merge 20 commits into greta-dev:main from hrlai:main
Commits
336c601 match roxygen version and import R6 (hrlai)
3ec4b1e first pass at moving, slighly refactoring, and documenting the discre… (hrlai)
b53e09f implement test for discrete_lognormal (hrlai)
1c35e6c add checks for the breaks and dim arguments (hrlai)
5c6cf8b create home for .onLoad functions, copied from https://github.com/gre… (hrlai)
a3a4498 copy reticulate functions from https://github.com/greta-dev/greta/blo… (hrlai)
fcb31cf remove check dims for now, I think we need to use the check_dims func… (hrlai)
8a70b8b add internal ::: call back to discrete_lognormal; is this ok or any b… (hrlai)
bcc9d84 add the necessary tf function for discrete lognormal; naming script i… (hrlai)
7b9393f switch from floor to round for a more general purpose use; this will … (hrlai)
77a76a2 add discrete_normal and fix a typo in discrete_lognormal (hrlai)
048f22e add note to link Stan URL in case their specification is helpful (hrlai)
e2ae5b6 remove greta::: we don't seem to need it (hrlai)
15cc772 remove hard fix of lower and upper bounds in tf_safe_cdf; remove pmax… (hrlai)
49f6d0e replace round in the sample function with tf$gather that discretise c… (hrlai)
8c646a8 define tf_as_integer to avoid it being an unexported function (hrlai)
b9a412f move check functions to a dedicated file (hrlai)
8676719 experimenting with using edges, instead of breaks, for the integratio… (hrlai)
9610368 use tf$greater_equal rather than tf$greater for the upper bound of CDFs (hrlai)
cd0f1f7 small fixes from merging into main (njtierney)
@@ -0,0 +1,17 @@
# check functions


check_break_length <- function(breaks) {
  if (length(breaks) <= 1) {
    msg <- cli::format_error(
      c(
        "{.var breaks} must be a vector with at least two break points",
        "but {.var breaks} has length {length(breaks)}"
      )
    )
    stop(
      msg,
      call. = FALSE
    )
  }
}

@@ -0,0 +1,18 @@
# try to activate the "greta-env" conda environment; return NULL on failure
use_greta_conda_env <- function() {
  tryCatch(
    expr = reticulate::use_condaenv("greta-env", required = TRUE),
    error = function(e) NULL
  )
}

# is the Python that reticulate discovered the one inside "greta-env"?
using_greta_conda_env <- function() {
  config <- reticulate::py_discover_config()
  grepl("greta-env", config$python)
}

# does a conda environment called "greta-env" exist?
have_greta_conda_env <- function() {
  tryCatch(
    expr = "greta-env" %in% reticulate::conda_list()$name,
    error = function(e) FALSE
  )
}

@@ -0,0 +1,107 @@
#' @name discrete_lognormal
#' @title Discrete lognormal distribution
#'
#' @description a discretised lognormal distribution (i.e. sampled by applying
#'   the round operation to samples from a lognormal). Due to the numerical
#'   instability of integrating across the distribution, a vector of breaks
#'   must be defined and the observations will be treated as censored
#'   within those breaks
#'
#' @param meanlog unconstrained parameters giving the mean of the distribution
#'   on the log scale
#' @param sdlog positive parameters giving the standard deviation of the
#'   distribution on the log scale
#' @param breaks a vector of breaks; observations will be treated as censored
#'   within those breaks
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_lognormal <- function(meanlog, sdlog, breaks, dim = NULL) {
  distrib("discrete_lognormal", meanlog, sdlog, breaks, dim)
}

# define the discrete lognormal distribution
discrete_lognormal_distribution <- R6Class(
  classname = "discrete_lognormal_distribution",
  inherit = distribution_node,
  public = list(

    breaks = NA,
    lower_bounds = NA,
    upper_bounds = NA,

    initialize = function(meanlog, sdlog, breaks, dim) {

      meanlog <- as.greta_array(meanlog)
      sdlog <- as.greta_array(sdlog)

      # check length of breaks
      check_break_length(breaks)

      # handle gradient issue between sdlog and 0s
      breaks <- pmax(breaks, .Machine$double.eps)
      self$breaks <- breaks
      self$lower_bounds <- breaks[-length(breaks)]
      self$upper_bounds <- breaks[-1]

      # add the nodes as parents and parameters
      dim <- check_dims(meanlog, sdlog, target_dim = dim)
      super$initialize("discrete_lognormal", dim, discrete = TRUE)
      self$add_parameter(meanlog, "meanlog")
      self$add_parameter(sdlog, "sdlog")

    },

    tf_distrib = function(parameters, dag) {

      meanlog <- parameters$meanlog
      sdlog <- parameters$sdlog

      tf_breaks <- fl(self$breaks)
      tf_lower_bounds <- fl(self$lower_bounds)
      tf_upper_bounds <- fl(self$upper_bounds)

      # tf_safe_cdf() requires the bounds of the supported range; assumed
      # here to be the natural support of the lognormal, (0, Inf)
      tf_lower_bound <- fl(0)
      tf_upper_bound <- fl(Inf)

      log_prob <- function(x) {

        # build distribution object
        d <- tfp$distributions$LogNormal(
          loc = meanlog,
          scale = sdlog
        )

        # for those lumped into groups,
        # compute the bounds of the observed groups
        # and get tensors for the bounds in the format expected by TFP
        x_safe <- tf$math$maximum(x, fl(.Machine$double.eps))
        tf_idx <- tfp$stats$find_bins(x_safe, tf_breaks)
        tf_idx_int <- tf_as_integer(tf_idx)
        tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
        tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

        # compute the probability mass over the observed groups
        low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
        up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
        log_density <- log(up - low)

      }

      sample <- function(seed) {

        d <- tfp$distributions$LogNormal(
          loc = meanlog,
          scale = sdlog
        )
        continuous <- d$sample(seed = seed)
        # tf$floor(continuous)
        tf$round(continuous)

      }

      list(log_prob = log_prob, sample = sample)

    }

  )
)

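The `log_prob` function in the diff above treats each observation as interval-censored: it looks up the bin the observation falls into and takes the log of the lognormal mass between that bin's lower and upper bounds. A minimal base-R sketch of the same calculation follows. It rests on assumptions: `findInterval()` stands in for `tfp$stats$find_bins()`, `plnorm()` stands in for the TFP `LogNormal` CDF, and `binned_lognormal_log_prob` is a hypothetical name that does not appear in the PR.

```r
# Base-R sketch of the interval-censored log-probability (hypothetical helper,
# not part of the PR).
binned_lognormal_log_prob <- function(x, meanlog, sdlog, breaks) {
  lower_bounds <- breaks[-length(breaks)]
  upper_bounds <- breaks[-1]

  # 1-based index of the bin [breaks[i], breaks[i + 1]) containing each x;
  # the TensorFlow code uses 0-based indices with tf$gather instead
  idx <- findInterval(x, breaks, rightmost.closed = TRUE)

  # probability mass of that bin under the continuous lognormal
  low <- plnorm(lower_bounds[idx], meanlog, sdlog)
  up <- plnorm(upper_bounds[idx], meanlog, sdlog)
  log(up - low)
}

# one observation in each of the four bins defined by these breaks
breaks <- c(0, 1, 2, 3, Inf)
x <- c(0.5, 1.5, 2.5, 10)
lp <- binned_lognormal_log_prob(x, meanlog = 0, sdlog = 1, breaks = breaks)
```

Because the four observations cover every bin exactly once, the bin masses `exp(lp)` sum to one, which is a quick sanity check on the bounds bookkeeping.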
@@ -0,0 +1,142 @@
#' @name discrete_normal
#' @title Discrete normal distribution
#'
#' @description a discretised normal distribution (i.e. sampled by binning
#'   samples from a normal between consecutive edges and returning the break
#'   value that lies between those edges). Due to the numerical
#'   instability of integrating across the distribution, a vector of breaks
#'   must be defined and the observations will be treated as censored
#'   within those breaks
#'
#' @param mean unconstrained parameters giving the mean of the distribution
#' @param sd positive parameters giving the standard deviation of the
#'   distribution
#' @param breaks a vector of breaks; observations will be treated as censored
#'   within those breaks
#' @param edges a vector of edges; length needs to be length(breaks)+1;
#'   observations between two consecutive edges will be discretised to the
#'   break value between the corresponding edges
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @importFrom R6 R6Class
#' @export

discrete_normal <- function(mean, sd, breaks, edges, dim = NULL) {
  distrib("discrete_normal", mean, sd, breaks, edges, dim)
}

# define the discrete normal distribution
discrete_normal_distribution <- R6Class(
  classname = "discrete_normal_distribution",
  inherit = distribution_node,
  public = list(

    breaks = NA,
    edges = NA,
    lower_bounds = NA,
    upper_bounds = NA,
    lower_bound = NA,
    upper_bound = NA,

    initialize = function(mean, sd, breaks, edges, dim) {

      mean <- as.greta_array(mean)
      sd <- as.greta_array(sd)

      # check length of breaks
      check_break_length(breaks)

      # store the breaks; tf_distrib() reads self$breaks and sample() gathers
      # the discretised values from them, so they must not be left as NA
      self$breaks <- breaks

      # EXPERIMENTAL:
      # use edges, rather than breaks, to bin continuous values. A sample that
      # falls between two consecutive edges is mapped to the break value
      # between those edges, which avoids round or floor (these assume that
      # the breaks are all the integers).
      # Perhaps more crucially, we use edges, not breaks, to integrate the
      # CDFs, because the likelihood of observing a break value is the
      # probability mass between the lower and upper edges surrounding it.
      # Edges could instead be generated from the midpoints between breaks,
      # padded with -Inf and Inf:
      # breaks_diff <- diff(breaks)
      # edges <- c(-Inf, breaks[-length(breaks)] + breaks_diff / 2, Inf)
      self$edges <- edges
      self$lower_bounds <- edges[-length(edges)]
      self$upper_bounds <- edges[-1]
      self$lower_bound <- min(edges)
      self$upper_bound <- max(edges)

      # add the nodes as parents and parameters
      dim <- check_dims(mean, sd, target_dim = dim)
      super$initialize("discrete_normal", dim, discrete = TRUE)
      self$add_parameter(mean, "mean")
      self$add_parameter(sd, "sd")

    },

    tf_distrib = function(parameters, dag) {

      mean <- parameters$mean
      sd <- parameters$sd

      tf_breaks <- fl(self$breaks)
      tf_edges <- fl(self$edges)
      tf_lower_bounds <- fl(self$lower_bounds)
      tf_upper_bounds <- fl(self$upper_bounds)
      tf_lower_bound <- fl(self$lower_bound)
      tf_upper_bound <- fl(self$upper_bound)

      log_prob <- function(x) {

        # build distribution object
        d <- tfp$distributions$Normal(
          loc = mean,
          scale = sd
        )

        # for those lumped into groups,
        # compute the bounds of the observed groups
        # and get tensors for the bounds in the format expected by TFP
        tf_idx <- tfp$stats$find_bins(x, tf_edges)
        tf_idx_int <- tf_as_integer(tf_idx)
        tf_lower_vec <- tf$gather(tf_lower_bounds, tf_idx_int)
        tf_upper_vec <- tf$gather(tf_upper_bounds, tf_idx_int)

        # compute the probability mass over the observed groups
        # note-to-self: this looks like
        # https://mc-stan.org/docs/2_29/stan-users-guide/bayesian-measurement-error-model.html#rounding
        low <- tf_safe_cdf(tf_lower_vec, d, tf_lower_bound, tf_upper_bound)
        up <- tf_safe_cdf(tf_upper_vec, d, tf_lower_bound, tf_upper_bound)
        log_density <- log(up - low)

      }

      sample <- function(seed) {

        d <- tfp$distributions$Normal(
          loc = mean,
          scale = sd
        )
        continuous <- d$sample(seed = seed)

        # find the bin (pair of consecutive edges) each continuous sample
        # falls into, then return the break value for that bin
        tf_edges_idx <- tfp$stats$find_bins(continuous, tf_edges)
        tf_edges_idx_int <- tf_as_integer(tf_edges_idx)
        tf$gather(tf_breaks, tf_edges_idx_int)

      }

      list(log_prob = log_prob, sample = sample)

    }

  )
)

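The edges idea in the diff above can be tried out in base R without TensorFlow. The sketch below builds edges from the midpoints between breaks, as the commented-out lines in `initialize` suggest, and then discretises continuous draws by looking up the bin between two edges and returning the matching break. `midpoint_edges` and `discretise` are hypothetical names, and `findInterval()` is assumed to be a reasonable stand-in for `tfp$stats$find_bins()`.

```r
# hypothetical helper: edges at the midpoints between consecutive breaks,
# padded with -Inf and Inf so that every real value falls in some bin
midpoint_edges <- function(breaks) {
  c(-Inf, breaks[-length(breaks)] + diff(breaks) / 2, Inf)
}

# hypothetical helper: map each continuous value to the break whose
# surrounding edges contain it
discretise <- function(continuous, breaks, edges = midpoint_edges(breaks)) {
  stopifnot(length(edges) == length(breaks) + 1)
  # 1-based bin index; tf$gather in the PR uses 0-based indices instead
  idx <- findInterval(continuous, edges, rightmost.closed = TRUE)
  breaks[idx]
}

breaks <- c(0, 1, 2, 5)
edges <- midpoint_edges(breaks)
discretised <- discretise(c(-3, 0.4, 0.5, 1.7, 3.4, 100), breaks)
```

With unevenly spaced breaks such as these, `round()` would return values like 3 that are not breaks at all, which is the motivation the PR gives for gathering from `breaks` instead.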
@@ -0,0 +1,25 @@
# tensorflow functions

# CDF of the provided distribution, handling 0s and Infs: values below
# lower_bound get a CDF of 0, values at or above upper_bound get a CDF of 1
tf_safe_cdf <- function(x, distribution, lower_bound, upper_bound) {

  # prepare to handle values outside the supported range
  too_low <- tf$less(x, lower_bound)
  too_high <- tf$greater_equal(x, upper_bound)
  supported <- !too_low & !too_high
  ones <- tf$ones_like(x)
  zeros <- tf$zeros_like(x)

  # run cdf on supported values (unsupported ones are swapped for a dummy
  # value of 1 first), and fill others in with the appropriate value
  x_clean <- tf$where(supported, x, ones)
  cdf_clean <- distribution$cdf(x_clean)
  mask <- tf$where(supported, ones, zeros)
  add <- tf$where(too_high, ones, zeros)
  cdf_clean * mask + add

}

# cast to integer
tf_as_integer <- function(x) {
  tf$cast(x, tf$int32)
}

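A rough base-R analogue of `tf_safe_cdf` may help when reasoning about the masking without a TensorFlow session. In this sketch `safe_cdf` is a hypothetical name, `ifelse()` and logical arithmetic stand in for `tf$where()`, and the CDF is passed in as an ordinary R function. It illustrates the idea only and is not the package's implementation.

```r
# hypothetical base-R analogue of tf_safe_cdf()
safe_cdf <- function(x, cdf, lower_bound, upper_bound) {
  too_low <- x < lower_bound
  too_high <- x >= upper_bound
  supported <- !too_low & !too_high

  # evaluate the CDF only at supported values; unsupported values are
  # replaced by a dummy value of 1 so the CDF never sees them
  x_clean <- ifelse(supported, x, 1)

  # supported: the CDF value; too low: 0; too high: 1
  cdf(x_clean) * supported + too_high
}

out <- safe_cdf(c(-1, 0.5, 1, Inf), plnorm, lower_bound = 0, upper_bound = Inf)
```

In plain R, `plnorm()` already copes with `-1` and `Inf`, so the masking changes nothing here. The likely reason for it in the TensorFlow version is to keep out-of-support values away from the CDF and its gradient, in line with the "gradient issue" comment in the lognormal code.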
@@ -1,2 +1,2 @@
 # load tf probability
-tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
+tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)

When the simulated data has high variance, e.g. `rnorm(..., sd = 10)`, I often get this error:

The problem may lie in `log_density <- log(up - low)`. Gut feeling is that high variance generates data that are more widely spaced, so when `breaks` are regularly spaced integers (see `tf_idx`), some bins have zero elements, then `tf_safe_cdf` has nothing to evaluate? Then we end up taking the log of zero... ???
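One mechanism that can turn `log(up - low)` into `log(0)`, whether or not it is the cause of the error here, is floating-point saturation. For a bin far out in the tail both CDF values round to exactly 1, so their difference is exactly 0. The base-R sketch below reproduces this for a normal with `sd = 10` and shows one possible workaround, which is to do the subtraction on the log scale in whichever tail keeps the values small. `naive_log_mass` and `stable_log_mass` are hypothetical names, and the workaround is a suggestion rather than something taken from the PR.

```r
# log of the normal probability mass in [low, up), computed as in the PR
naive_log_mass <- function(low, up, mean = 0, sd = 1) {
  log(pnorm(up, mean, sd) - pnorm(low, mean, sd))
}

# hypothetical alternative: subtract on the log scale, using survival
# functions for bins above the mean and CDFs otherwise
stable_log_mass <- function(low, up, mean = 0, sd = 1) {
  upper_tail <- low > mean
  # a is the log of the larger term, b the log of the smaller one
  a <- ifelse(upper_tail,
              pnorm(low, mean, sd, lower.tail = FALSE, log.p = TRUE),
              pnorm(up, mean, sd, log.p = TRUE))
  b <- ifelse(upper_tail,
              pnorm(up, mean, sd, lower.tail = FALSE, log.p = TRUE),
              pnorm(low, mean, sd, log.p = TRUE))
  # log(exp(a) - exp(b)) without leaving the log scale
  a + log1p(-exp(b - a))
}

# a bin roughly ten standard deviations above the mean
naive <- naive_log_mass(99, 100, mean = 0, sd = 10)
stable <- stable_log_mass(99, 100, mean = 0, sd = 10)
```

Here `naive` is `-Inf` while `stable` is finite, and near the centre of the distribution the two functions agree. A TensorFlow version would need the distribution's log-CDF and log-survival methods, which this sketch does not attempt.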