Add conditional bernoulli #9

Open: wants to merge 7 commits into base: main
1 change: 1 addition & 0 deletions NAMESPACE
@@ -1,5 +1,6 @@
# Generated by roxygen2: do not edit by hand

export(conditional_bernoulli)
export(zero_inflated_negative_binomial)
export(zero_inflated_poisson)
importFrom(R6,R6Class)
1 change: 1 addition & 0 deletions NEWS.md
@@ -1,3 +1,4 @@
# greta.distributions 0.0.0.9000

* Added a `NEWS.md` file to track changes to the package.
* Added `conditional_bernoulli` distribution (#5)
130 changes: 130 additions & 0 deletions R/conditional-bernoulli.R
@@ -0,0 +1,130 @@
#' @name conditional_bernoulli
#' @title Conditional Bernoulli distribution
#'
#' @description greta probability distribution over a K-dimensional vector of
#' binary variables, arising from independent Bernoulli draws that are each
#' conditioned on a single draw from another Bernoulli distribution.
#'
#' @details A compound distribution, where elements of the bernoulli vector
#' variable _y_ can only be 1 if a scalar latent bernoulli variable
#' _z_ takes value 1, i.e.:
#'
#' \deqn{ y_i ~ bernoulli(z * p_i)}
#' \deqn{z ~ bernoulli(psi)}
#' where
#' \deqn{p_i = p(y_i = 1 | z = 1)}
#' \deqn{psi = p(z = 1)}
#'
#' _p_ and _psi_ are distinguishable provided there are multiple
#' trials in each observation of _y_. The density of this compound
#' distribution can be calculated directly, explicitly integrating over the
#' latent variable _z_, as:
#'
#' \deqn{psi * prod((p ^ y) * (1 - p) ^ (1 - y)) + (1 - max(y)) * (1 - psi)}
#'
#' This formulation underpins the ecological imperfect-detection model of
#' MacKenzie et al. (2002), where _y_ and _p_ are vectors giving, for each
#' visit, whether a species was detected and the probability of detection
#' (which may vary between visits), and _z_ and _psi_ are scalars
#' giving whether the species was present (assumed to be the same at all
#' visits) and the probability of presence.
#'
#' @references MacKenzie, D. I., Nichols, J. D., Lachman, G. B., Droege, S.,
#' Andrew Royle, J., & Langtimm, C. A. (2002). Estimating site occupancy rates
#' when detection probabilities are less than one. _Ecology_, 83(8),
#' 2248-2255.
#'
#' @param p matrix (of dimension `dim` x K) of (conditional) probabilities
#' of success
#' @param psi scalar or column vector (of length `dim`) of probabilities
#' for the latent bernoulli variable
#' @param dim a scalar giving the number of rows in the resulting greta array
#'
#' @export
#' @examples
#' \dontrun{
#' cb <- conditional_bernoulli(
#' p = matrix(c(0.1, 0.9), ncol = 2),
#' psi = 0.9,
#' dim = 1
#' )
#' cb
#' }
conditional_bernoulli <- function(p, psi, dim = 1) {
distrib("conditional_bernoulli", p, psi, dim)
}

# conditional bernoulli distribution
conditional_bernoulli_distribution <- R6::R6Class(
classname = "conditional_bernoulli_distribution",
inherit = distribution_node,
public = list(
initialize = function(p, psi, dim) {

# keep the raw inputs so they can be checked as valid probabilities below
p_val <- p
psi_val <- psi

# coerce to greta arrays
p <- as.greta_array(p)
psi <- as.greta_array(psi)

# check dimensions of p
check_if_2d_array(p)
check_if_2d_gte_two_col(p)
check_if_2d_one_col(psi)
# compare possible dimensions
check_params_same_rows(p, psi)

check_dim_positive_scalar_int(dim)

# check p and psi are between 0 and 1
check_valid_probability(p_val, var_name = "p")
check_valid_probability(psi_val, var_name = "psi")

# coerce the parameter arguments to nodes and add as children and
# parameters
super$initialize("conditional_bernoulli",
dim = c(dim, ncol(p)),
discrete = TRUE
)
self$add_parameter(p, "p")
self$add_parameter(psi, "psi")
},
tf_distrib = function(parameters) {
p <- parameters$p
psi <- parameters$psi

# return a tf function, taking the binary vector and returning the density

log_prob <- function(x) {

# for each row, were all elements 0?
none <- tf_as_float(tf_rowsums(x, 1L) == 0)

one <- fl(1)

# log conditional probability
# cp <- (p ^ x) * (one - p) ^ (one - x)
# log_cp <- log(cp)
log_cp <- x * log(p) + (one - x) * log(one - p)

# log probability
# log(rowProds(cp) * psi + none * (1 - psi))
prob <- exp(tf_rowsums(log_cp, 1L) + log(psi)) + none * (one - psi)
log(prob)
}

list(
log_prob = log_prob,
sample = NULL,
cdf = NULL,
log_cdf = NULL
)
},

# no CDF for multivariate distributions
tf_cdf_function = NULL,
tf_log_cdf_function = NULL
)
)
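As a sanity check on the density that `tf_distrib` computes, here is a minimal base-R sketch of the same log-probability for a single row. It is not the greta/TensorFlow code: `conditional_bernoulli_logpmf` is a hypothetical name used only for illustration, and the sketch assumes every `p` lies strictly between 0 and 1 so that the logs are finite. Enumerating all 2^K outcomes should give probabilities that sum to one, which holds only if the `(1 - psi)` term is added when every element of `y` is zero.

```r
# Hypothetical reference implementation (not part of the package).
# y: binary vector; p: conditional success probabilities, assumed to lie
# strictly inside (0, 1); psi: probability that the latent z equals 1.
conditional_bernoulli_logpmf <- function(y, p, psi) {
  stopifnot(length(y) == length(p), all(y %in% c(0, 1)))
  # log P(y | z = 1): independent Bernoulli draws
  log_cp <- sum(y * log(p) + (1 - y) * log(1 - p))
  # indicator that every element of y is 0; only then is z = 0 possible
  none <- as.numeric(sum(y) == 0)
  log(psi * exp(log_cp) + none * (1 - psi))
}

# Enumerate all 2^K outcomes and total their probabilities
p <- c(0.1, 0.9, 0.4)
psi <- 0.7
outcomes <- as.matrix(expand.grid(rep(list(0:1), length(p))))
probs <- apply(
  outcomes, 1,
  function(y) exp(conditional_bernoulli_logpmf(y, p, psi))
)
total <- sum(probs)
```

With these inputs the all-zero outcome has probability `psi * prod(1 - p) + (1 - psi)`, and `total` should equal 1 up to floating-point error.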
1 change: 0 additions & 1 deletion R/internals.R
@@ -1,5 +1,4 @@
# need some internal greta functions accessible

check_dims <- .internals$checks$check_dims
check_in_family <- .internals$checks$check_in_family
check_positive <- .internals$checks$check_positive
123 changes: 123 additions & 0 deletions R/utils.R
@@ -0,0 +1,123 @@

n_dim <- function(x) length(dim(x))

check_if_2d_array <- function(x,
call = rlang::caller_env()){
if (n_dim(x) != 2) {
cli::cli_abort(
c(
"{.var x} must be a 2D array",
"but {.var x} has dimensions {paste(dim(x), collapse = 'x')}"
),
call = call
)
}
}

check_n_col_gte <- function(x,
n_col,
call = rlang::caller_env()){
if (ncol(x) < n_col) {
cli::cli_abort(
c(
"{.var x} must have at least {n_col} column{?s}",
"but {.var x} has {ncol(x)} columns"
),
call = call
)
}
}

check_valid_probability <- function(x,
var_name = "x",
call = rlang::caller_env()) {
if (any(x < 0) || any(x > 1)) {

first_line <- glue::glue(
"{.var [var_name]} must be a valid probability - between 0 and 1",
.open = "[",
.close = "]"
)
second_line <- glue::glue(
"We see {.var [var_name]} = {x}",
.open = "[",
.close = "]"
)
cli::cli_abort(
c(
first_line,
second_line
),
call = call
)
}
}

check_if_2d_gte_two_col <- function(p,
call = rlang::caller_env()){
# test dimensionality first so ncol() is only evaluated on a 2D array
does_not_have_at_least_two_cols <- length(dim(p)) != 2 || ncol(p) < 2
if (does_not_have_at_least_two_cols) {
cli::cli_abort(
c(
"{.var p} must be a 2D array with at least two columns",
"but {.var p} has dimensions {paste(dim(p), collapse = 'x')}"
),
call = call
)
}
}

check_if_2d_one_col <- function(psi,
call = rlang::caller_env()){
# check dimensions of psi
# test dimensionality first so ncol() is only evaluated on a 2D array
not_2d_or_one_col <- length(dim(psi)) != 2 || ncol(psi) != 1
if (not_2d_or_one_col) {
cli::cli_abort(
c(
"{.var psi} must be a 2D array with one column",
"but {.var psi} has dimensions {paste(dim(psi), collapse = 'x')}"
),
call = call
)
}
}

check_params_same_rows <- function(p,
psi,
call = rlang::caller_env()){
dim_p <- nrow(p)
dim_psi <- nrow(psi)

not_same_nrows <- dim_p != dim_psi

if (not_same_nrows) {
cli::cli_abort(
c(
"{.var p} and {.var psi} must have the same number of rows",
"But we see {.var p} and {.var psi} have:",
"{.var p}: {dim_p} {?row/rows}",
"{.var psi}: {dim_psi} {?row/rows}",
"Perhaps you need to coerce {.var p} or {.var psi} to an \\
appropriate matrix?"
),
call = call
)
}
}

check_dim_positive_scalar_int <- function(dim,
call = rlang::caller_env()){
# check dim is a positive scalar integer
dim_old <- dim
dim <- as.integer(dim)
not_scalar_positive_integer <- length(dim) > 1 || dim <= 0 || !is.finite(dim)
if (not_scalar_positive_integer) {
cli::cli_abort(
c(
"{.var dim} must be a scalar positive integer, but was:",
"{dim_old}"
),
call = call
)
}
}
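A note on `check_dim_positive_scalar_int()`: because `dim` is coerced with `as.integer()` before it is tested, a fractional value such as 1.5 is truncated to 1 and accepted, while 0.1 is rejected only because it truncates to 0. If that is not intended, the test could be applied to the uncoerced value instead. The following is a minimal sketch under that assumption; `is_positive_scalar_int` is a hypothetical helper, not something in this PR.

```r
# Hypothetical predicate (not part of the PR): TRUE only for a single,
# finite, positive whole number, tested before any integer coercion.
is_positive_scalar_int <- function(dim) {
  is.numeric(dim) &&
    length(dim) == 1 &&
    is.finite(dim) &&
    dim > 0 &&
    dim == round(dim)
}

# as.integer() truncates toward zero, so 1.5 becomes 1
truncated <- as.integer(1.5)

results <- c(
  one = is_positive_scalar_int(1),
  three_int = is_positive_scalar_int(3L),
  fractional = is_positive_scalar_int(1.5),
  small = is_positive_scalar_int(0.1),
  vector = is_positive_scalar_int(c(1, 2)),
  missing = is_positive_scalar_int(NA_real_)
)
```

Only the first two entries of `results` are `TRUE`; the existing `cli::cli_abort()` call could then be triggered whenever the predicate returns `FALSE`.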
63 changes: 63 additions & 0 deletions man/conditional_bernoulli.Rd


1 change: 1 addition & 0 deletions man/greta.distributions.Rd


6 changes: 2 additions & 4 deletions tests/spelling.R
@@ -1,9 +1,7 @@
if (requireNamespace("spelling", quietly = TRUE)) {
  spelling::spell_check_test(
    vignettes = TRUE,
    error = FALSE,
    skip_on_cran = TRUE
  )
}
39 changes: 39 additions & 0 deletions tests/testthat/_snaps/conditional_bernoulli.md
@@ -0,0 +1,39 @@
# conditional_bernoulli fails when given the wrong argument dimensions

`p` must be a 2D array with at least two columns
but `p` has dimensions 1x1

---

`p` and `psi` must have the same number of rows
But we see `p` and `psi` have:
`p`: 1 row
`psi`: 2 rows
Perhaps you need to coerce `p` or `psi` to an appropriate matrix?

---

`p` and `psi` must have the same number of rows
But we see `p` and `psi` have:
`p`: 1 row
`psi`: 2 rows
Perhaps you need to coerce `p` or `psi` to an appropriate matrix?

---

`p` and `psi` must have the same number of rows
But we see `p` and `psi` have:
`p`: 1 row
`psi`: 2 rows
Perhaps you need to coerce `p` or `psi` to an appropriate matrix?

---

`dim` must be a scalar positive integer, but was:
0.1

---

`dim` must be a scalar positive integer, but was:
0.1
