From 0defdce64a842e829511c0c0934f2156021b5f41 Mon Sep 17 00:00:00 2001
From: olivuntu <olivier.flores@univ-reunion.fr>
Date: Fri, 22 Jul 2022 06:57:22 +0400
Subject: [PATCH 01/19] Add ZIP and ZINB distributions

---
 R/probability_distributions.R       | 1743 +++++++++++++++++++++++++++
 man/distributions.Rd                |  254 ++++
 tests/testthat/helpers.R            |  212 ++--
 tests/testthat/test_distributions.R | 1178 ++++++++++++++++++
 4 files changed, 3291 insertions(+), 96 deletions(-)
 create mode 100644 R/probability_distributions.R
 create mode 100644 man/distributions.Rd
 create mode 100644 tests/testthat/test_distributions.R

diff --git a/R/probability_distributions.R b/R/probability_distributions.R
new file mode 100644
index 0000000..333530f
--- /dev/null
+++ b/R/probability_distributions.R
@@ -0,0 +1,1743 @@
+uniform_distribution <- R6Class(
+  "uniform_distribution",
+  inherit = distribution_node,
+  public = list(
+    min = NA,
+    max = NA,
+    initialize = function(min, max, dim) {
+      if (inherits(min, "greta_array") | inherits(max, "greta_array")) {
+        msg <- cli::format_error(
+          "{.arg min} and {.arg max} must be fixed, they cannot be another \\
+          greta array"
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      good_types <- is.numeric(min) && length(min) == 1 &&
+        is.numeric(max) && length(max) == 1
+
+      if (!good_types) {
+        msg <- cli::format_error(
+          c(
+            "{.arg min} and {.arg max} must be numeric vectors of length 1",
+            "They have class and length:",
+            "{.arg min}: {class(min)}, {length(min)}",
+            "{.arg max}: {class(max)}, {length(max)}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      if (!is.finite(min) | !is.finite(max)) {
+        msg <- cli::format_error(
+          c(
+            "{.arg min} and {.arg max} must finite scalars",
+            "Their values are:",
+            "{.arg min}: {min}",
+            "{.arg max}: {max}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      if (min >= max) {
+        msg <- cli::format_error(
+          c(
+            "{.arg max} must be greater than {.arg min}",
+            "Their values are:",
+            "{.arg min}: {min}",
+            "{.arg max}: {max}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      # store min and max as numeric scalars (needed in create_target, done in
+      # initialisation)
+      self$min <- min
+      self$max <- max
+      self$bounds <- c(min, max)
+
+      # initialize the rest
+      super$initialize("uniform", dim)
+
+      # add them as parents and greta arrays
+      min <- as.greta_array(min)
+      max <- as.greta_array(max)
+      self$add_parameter(min, "min")
+      self$add_parameter(max, "max")
+    },
+
+    # default value (ignore any truncation arguments)
+    create_target = function(...) {
+      vble(
+        truncation = c(self$min, self$max),
+        dim = self$dim
+      )
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Uniform(
+        low = parameters$min,
+        high = parameters$max
+      )
+    }
+  )
+)
+
+normal_distribution <- R6Class(
+  "normal_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(mean, sd, dim, truncation) {
+      mean <- as.greta_array(mean)
+      sd <- as.greta_array(sd)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(mean, sd, target_dim = dim)
+      super$initialize("normal", dim, truncation)
+      self$add_parameter(mean, "mean")
+      self$add_parameter(sd, "sd")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Normal(
+        loc = parameters$mean,
+        scale = parameters$sd
+      )
+    }
+  )
+)
+
+lognormal_distribution <- R6Class(
+  "lognormal_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(meanlog, sdlog, dim, truncation) {
+      meanlog <- as.greta_array(meanlog)
+      sdlog <- as.greta_array(sdlog)
+
+      dim <- check_dims(meanlog, sdlog, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("lognormal", dim, truncation)
+      self$add_parameter(meanlog, "meanlog")
+      self$add_parameter(sdlog, "sdlog")
+    },
+
+    # nolint start
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$LogNormal(
+        loc = parameters$meanlog,
+        scale = parameters$sdlog
+      )
+    }
+    # nolint end
+  )
+)
+
+bernoulli_distribution <- R6Class(
+  "bernoulli_distribution",
+  inherit = distribution_node,
+  public = list(
+    prob_is_logit = FALSE,
+    prob_is_probit = FALSE,
+    initialize = function(prob, dim) {
+      prob <- as.greta_array(prob)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(prob, target_dim = dim)
+      super$initialize("bernoulli", dim, discrete = TRUE)
+
+      if (has_representation(prob, "logit")) {
+        prob <- representation(prob, "logit")
+        self$prob_is_logit <- TRUE
+      } else if (has_representation(prob, "probit")) {
+        prob <- representation(prob, "probit")
+        self$prob_is_probit <- TRUE
+      }
+
+      self$add_parameter(prob, "prob")
+    },
+    tf_distrib = function(parameters, dag) {
+      if (self$prob_is_logit) {
+        tfp$distributions$Bernoulli(logits = parameters$prob)
+      } else if (self$prob_is_probit) {
+
+        # in the probit case, get the log probability of success and compute the
+        # log prob directly
+        probit <- parameters$prob
+        d <- tfp$distributions$Normal(fl(0), fl(1))
+        lprob <- d$log_cdf(probit)
+        lprobnot <- d$log_cdf(-probit)
+
+        log_prob <- function(x) {
+          x * lprob + (fl(1) - x) * lprobnot
+        }
+
+        list(log_prob = log_prob)
+      } else {
+        tfp$distributions$Bernoulli(probs = parameters$prob)
+      }
+    }
+  )
+)
+
+binomial_distribution <- R6Class(
+  "binomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    prob_is_logit = FALSE,
+    prob_is_probit = FALSE,
+    initialize = function(size, prob, dim) {
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(size, prob, target_dim = dim)
+      super$initialize("binomial", dim, discrete = TRUE)
+
+      if (has_representation(prob, "logit")) {
+        prob <- representation(prob, "logit")
+        self$prob_is_logit <- TRUE
+      } else if (has_representation(prob, "probit")) {
+        prob <- representation(prob, "probit")
+        self$prob_is_probit <- TRUE
+      }
+
+      self$add_parameter(prob, "prob")
+      self$add_parameter(size, "size")
+    },
+    tf_distrib = function(parameters, dag) {
+      if (self$prob_is_logit) {
+        tfp$distributions$Binomial(
+          total_count = parameters$size,
+          logits = parameters$prob
+        )
+      } else if (self$prob_is_probit) {
+
+        # in the probit case, get the log probability of success and compute the
+        # log prob directly
+        size <- parameters$size
+        probit <- parameters$prob
+        d <- tfp$distributions$Normal(fl(0), fl(1))
+        lprob <- d$log_cdf(probit)
+        lprobnot <- d$log_cdf(-probit)
+
+        log_prob <- function(x) {
+          log_choose <- tf$math$lgamma(size + fl(1)) -
+            tf$math$lgamma(x + fl(1)) -
+            tf$math$lgamma(size - x + fl(1))
+          log_choose + x * lprob + (size - x) * lprobnot
+        }
+
+        list(log_prob = log_prob)
+      } else {
+        tfp$distributions$Binomial(
+          total_count = parameters$size,
+          probs = parameters$prob
+        )
+      }
+    }
+  )
+)
+
+beta_binomial_distribution <- R6Class(
+  "beta_binomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(size, alpha, beta, dim) {
+      size <- as.greta_array(size)
+      alpha <- as.greta_array(alpha)
+      beta <- as.greta_array(beta)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(size, alpha, beta, target_dim = dim)
+      super$initialize("beta_binomial", dim, discrete = TRUE)
+      self$add_parameter(size, "size")
+      self$add_parameter(alpha, "alpha")
+      self$add_parameter(beta, "beta")
+    },
+    tf_distrib = function(parameters, dag) {
+      size <- parameters$size
+      alpha <- parameters$alpha
+      beta <- parameters$beta
+
+      log_prob <- function(x) {
+        tf_lchoose(size, x) +
+          tf_lbeta(x + alpha, size - x + beta) -
+          tf_lbeta(alpha, beta)
+      }
+
+      # generate a beta, then a binomial
+      sample <- function(seed) {
+        beta <- tfp$distributions$Beta(
+          concentration1 = alpha,
+          concentration0 = beta
+        )
+        probs <- beta$sample(seed = seed)
+        binomial <- tfp$distributions$Binomial(
+          total_count = size,
+          probs = probs
+        )
+        binomial$sample(seed = seed)
+      }
+
+      list(log_prob = log_prob, sample = sample)
+    }
+  )
+)
+
+poisson_distribution <- R6Class(
+  "poisson_distribution",
+  inherit = distribution_node,
+  public = list(
+    lambda_is_log = FALSE,
+    initialize = function(lambda, dim) {
+      lambda <- as.greta_array(lambda)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(lambda, target_dim = dim)
+      super$initialize("poisson", dim, discrete = TRUE)
+
+      if (has_representation(lambda, "log")) {
+        lambda <- representation(lambda, "log")
+        self$lambda_is_log <- TRUE
+      }
+      self$add_parameter(lambda, "lambda")
+    },
+    tf_distrib = function(parameters, dag) {
+      if (self$lambda_is_log) {
+        log_lambda <- parameters$lambda
+      } else {
+        log_lambda <- tf$math$log(parameters$lambda)
+      }
+
+      tfp$distributions$Poisson(log_rate = log_lambda)
+    }
+  )
+)
+
+negative_binomial_distribution <- R6Class(
+  "negative_binomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(size, prob, dim) {
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(size, prob, target_dim = dim)
+      super$initialize("negative_binomial", dim, discrete = TRUE)
+      self$add_parameter(size, "size")
+      self$add_parameter(prob, "prob")
+    },
+
+    # nolint start
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$NegativeBinomial(
+        total_count = parameters$size,
+        probs = fl(1) - parameters$prob
+      )
+    }
+    # nolint end
+  )
+)
+
+zero_inflated_poisson_distribution <- R6Class(
+  "zero_inflated_poisson_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(theta, lambda, dim) {
+      theta <- as.greta_array(theta)
+      lambda <- as.greta_array(lambda)
+      # add the nodes as parents and parameters
+      dim <- check_dims(theta, lambda, target_dim = dim)
+      super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(lambda, "lambda")
+    },
+
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      lambda <- parameters$lambda
+      log_prob <- function(x) {
+
+        # mixture density: theta * I(x = 0) + (1 - theta) * Poisson(x | lambda);
+        # for non-negative integer x, relu(1 - x) equals 1 when x = 0 and 0
+        # otherwise
+        tf$math$log(
+          theta * tf$nn$relu(fl(1) - x) +
+            (fl(1) - theta) * tf$pow(lambda, x) * tf$exp(-lambda) /
+              tf$exp(tf$math$lgamma(x + fl(1)))
+        )
+      }
+
+      sample <- function(seed) {
+
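+        # sample from the mixture: a Bernoulli(theta) draw flags the
+        # structural zeros, which then mask the Poisson draws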
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        pois <- tfp$distributions$Poisson(rate = lambda)
+
+        zi <- binom$sample(seed = seed)
+        lbd <- pois$sample(seed = seed)
+
+        (fl(1) - zi) * lbd
+
+      }
+
+      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
+    },
+
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
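+
+# for reference, the zero-inflated Poisson density on the R scale is
+#   p(x) = theta * (x == 0) + (1 - theta) * dpois(x, lambda)
+# which should match extraDistr::dzip with pi = theta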
+
+zero_inflated_negative_binomial_distribution <- R6Class(
+  "zero_inflated_negative_binomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(theta, size, prob, dim) {
+      theta <- as.greta_array(theta)
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+      # add the nodes as parents and parameters
+      dim <- check_dims(theta, size, prob, target_dim = dim)
+      super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(size, "size")
+      self$add_parameter(prob, "prob")
+    },
+
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      size <- parameters$size
+      p <- parameters$prob # probability of success
+      q <- fl(1) - parameters$prob # probability of failure
+      log_prob <- function(x) {
+
+        # mixture density: theta * I(x = 0) + (1 - theta) * NB(x | size, p),
+        # where NB(x | size, p) = gamma(x + size) / (gamma(size) * x!) *
+        # p^size * q^x; relu(1 - x) equals 1 when x = 0 and 0 otherwise
+        tf$math$log(
+          theta * tf$nn$relu(fl(1) - x) +
+            (fl(1) - theta) * tf$pow(p, size) * tf$pow(q, x) *
+              tf$exp(tf$math$lgamma(x + size)) /
+              tf$exp(tf$math$lgamma(size)) /
+              tf$exp(tf$math$lgamma(x + fl(1)))
+        )
+      }
+
+      sample <- function(seed) {
+
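+        # sample from the mixture: a Bernoulli(theta) draw flags the
+        # structural zeros, which then mask the negative binomial draws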
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        # as in negative_binomial above, TFP parameterises the negative
+        # binomial by the complement of the 'stats' success probability,
+        # hence probs = q
+        negbin <- tfp$distributions$NegativeBinomial(
+          total_count = size,
+          probs = q
+        )
+
+        zi <- binom$sample(seed = seed)
+        lbd <- negbin$sample(seed = seed)
+
+        (fl(1) - zi) * lbd
+
+      }
+
+      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
+    },
+
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
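+
+# for reference, the zero-inflated negative binomial density on the R scale is
+#   p(x) = theta * (x == 0) + (1 - theta) * dnbinom(x, size, prob)
+# which should match extraDistr::dzinb with pi = theta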
+
+hypergeometric_distribution <- R6Class(
+  "hypergeometric_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(m, n, k, dim) {
+      m <- as.greta_array(m)
+      n <- as.greta_array(n)
+      k <- as.greta_array(k)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(m, n, k, target_dim = dim)
+      super$initialize("hypergeometric", dim, discrete = TRUE)
+      self$add_parameter(m, "m")
+      self$add_parameter(n, "n")
+      self$add_parameter(k, "k")
+    },
+    tf_distrib = function(parameters, dag) {
+      m <- parameters$m
+      n <- parameters$n
+      k <- parameters$k
+
+      log_prob <- function(x) {
+        tf_lchoose(m, x) +
+          tf_lchoose(n, k - x) -
+          tf_lchoose(m + n, k)
+      }
+
+      list(log_prob = log_prob)
+    }
+  )
+)
+
+gamma_distribution <- R6Class(
+  "gamma_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(shape, rate, dim, truncation) {
+      shape <- as.greta_array(shape)
+      rate <- as.greta_array(rate)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(shape, rate, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("gamma", dim, truncation)
+      self$add_parameter(shape, "shape")
+      self$add_parameter(rate, "rate")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Gamma(
+        concentration = parameters$shape,
+        rate = parameters$rate
+      )
+    }
+  )
+)
+
+inverse_gamma_distribution <- R6Class(
+  "inverse_gamma_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(alpha, beta, dim, truncation) {
+      alpha <- as.greta_array(alpha)
+      beta <- as.greta_array(beta)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(alpha, beta, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("inverse_gamma", dim, truncation)
+      self$add_parameter(alpha, "alpha")
+      self$add_parameter(beta, "beta")
+    },
+
+    # nolint start
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$InverseGamma(
+        concentration = parameters$alpha,
+        rate = parameters$beta
+      )
+    }
+    # nolint end
+  )
+)
+
+weibull_distribution <- R6Class(
+  "weibull_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(shape, scale, dim, truncation) {
+      shape <- as.greta_array(shape)
+      scale <- as.greta_array(scale)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(shape, scale, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("weibull", dim, truncation)
+      self$add_parameter(shape, "shape")
+      self$add_parameter(scale, "scale")
+    },
+    tf_distrib = function(parameters, dag) {
+      a <- parameters$shape
+      b <- parameters$scale
+
+      # use the TFP Weibull CDF bijector
+      bijector <- tfp$bijectors$Weibull(scale = b, concentration = a)
+
+      log_prob <- function(x) {
+        log(a) - log(b) + (a - fl(1)) * (log(x) - log(b)) - (x / b)^a
+      }
+
+      cdf <- function(x) {
+        bijector$forward(x)
+      }
+
+      log_cdf <- function(x) {
+        log(cdf(x))
+      }
+
+      quantile <- function(x) {
+        bijector$inverse(x)
+      }
+
+      sample <- function(seed) {
+
+        # sample by pushing standard uniforms through the inverse cdf
+        u <- tf_randu(self$dim, dag)
+        quantile(u)
+      }
+
+      list(
+        log_prob = log_prob,
+        cdf = cdf,
+        log_cdf = log_cdf,
+        quantile = quantile,
+        sample = sample
+      )
+    }
+  )
+)
+
+exponential_distribution <- R6Class(
+  "exponential_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(rate, dim, truncation) {
+      rate <- as.greta_array(rate)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(rate, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("exponential", dim, truncation)
+      self$add_parameter(rate, "rate")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Exponential(rate = parameters$rate)
+    }
+  )
+)
+
+pareto_distribution <- R6Class(
+  "pareto_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(a, b, dim, truncation) {
+      a <- as.greta_array(a)
+      b <- as.greta_array(b)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(a, b, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("pareto", dim, truncation)
+      self$add_parameter(a, "a")
+      self$add_parameter(b, "b")
+    },
+    tf_distrib = function(parameters, dag) {
+
+      # a is shape, b is scale
+      tfp$distributions$Pareto(
+        concentration = parameters$a,
+        scale = parameters$b
+      )
+    }
+  )
+)
+
+student_distribution <- R6Class(
+  "student_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(df, mu, sigma, dim, truncation) {
+      df <- as.greta_array(df)
+      mu <- as.greta_array(mu)
+      sigma <- as.greta_array(sigma)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(df, mu, sigma, target_dim = dim)
+      super$initialize("student", dim, truncation)
+      self$add_parameter(df, "df")
+      self$add_parameter(mu, "mu")
+      self$add_parameter(sigma, "sigma")
+    },
+
+    # nolint start
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$StudentT(
+        df = parameters$df,
+        loc = parameters$mu,
+        scale = parameters$sigma
+      )
+    }
+    # nolint end
+  )
+)
+
+laplace_distribution <- R6Class(
+  "laplace_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(mu, sigma, dim, truncation) {
+      mu <- as.greta_array(mu)
+      sigma <- as.greta_array(sigma)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(mu, sigma, target_dim = dim)
+      super$initialize("laplace", dim, truncation)
+      self$add_parameter(mu, "mu")
+      self$add_parameter(sigma, "sigma")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Laplace(
+        loc = parameters$mu,
+        scale = parameters$sigma
+      )
+    }
+  )
+)
+
+beta_distribution <- R6Class(
+  "beta_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(shape1, shape2, dim, truncation) {
+      shape1 <- as.greta_array(shape1)
+      shape2 <- as.greta_array(shape2)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(shape1, shape2, target_dim = dim)
+      check_unit(truncation)
+      self$bounds <- c(0, 1)
+      super$initialize("beta", dim, truncation)
+      self$add_parameter(shape1, "shape1")
+      self$add_parameter(shape2, "shape2")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Beta(
+        concentration1 = parameters$shape1,
+        concentration0 = parameters$shape2
+      )
+    }
+  )
+)
+
+cauchy_distribution <- R6Class(
+  "cauchy_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(location, scale, dim, truncation) {
+      location <- as.greta_array(location)
+      scale <- as.greta_array(scale)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(location, scale, target_dim = dim)
+      super$initialize("cauchy", dim, truncation)
+      self$add_parameter(location, "location")
+      self$add_parameter(scale, "scale")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Cauchy(
+        loc = parameters$location,
+        scale = parameters$scale
+      )
+    }
+  )
+)
+
+chi_squared_distribution <- R6Class(
+  "chi_squared_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(df, dim, truncation) {
+      df <- as.greta_array(df)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(df, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("chi_squared", dim, truncation)
+      self$add_parameter(df, "df")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Chi2(df = parameters$df)
+    }
+  )
+)
+
+logistic_distribution <- R6Class(
+  "logistic_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(location, scale, dim, truncation) {
+      location <- as.greta_array(location)
+      scale <- as.greta_array(scale)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(location, scale, target_dim = dim)
+      super$initialize("logistic", dim, truncation)
+      self$add_parameter(location, "location")
+      self$add_parameter(scale, "scale")
+    },
+    tf_distrib = function(parameters, dag) {
+      tfp$distributions$Logistic(
+        loc = parameters$location,
+        scale = parameters$scale
+      )
+    }
+  )
+)
+
+f_distribution <- R6Class(
+  "f_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(df1, df2, dim, truncation) {
+      df1 <- as.greta_array(df1)
+      df2 <- as.greta_array(df2)
+
+      # add the nodes as parents and parameters
+      dim <- check_dims(df1, df2, target_dim = dim)
+      check_positive(truncation)
+      self$bounds <- c(0, Inf)
+      super$initialize("f", dim, truncation)
+      self$add_parameter(df1, "df1")
+      self$add_parameter(df2, "df2")
+    },
+    tf_distrib = function(parameters, dag) {
+      df1 <- parameters$df1
+      df2 <- parameters$df2
+
+      tf_lbeta <- function(a, b) {
+        tf$math$lgamma(a) + tf$math$lgamma(b) - tf$math$lgamma(a + b)
+      }
+
+      log_prob <- function(x) {
+        df1_x <- df1 * x
+        la <- df1 * log(df1_x) + df2 * log(df2)
+        lb <- (df1 + df2) * log(df1_x + df2)
+        lnumerator <- fl(0.5) * (la - lb)
+        lnumerator - log(x) - tf_lbeta(df1 / fl(2), df2 / fl(2))
+      }
+
+      cdf <- function(x) {
+        df1_x <- df1 * x
+        ratio <- df1_x / (df1_x + df2)
+        tf$math$betainc(df1 / fl(2), df2 / fl(2), ratio)
+      }
+
+      log_cdf <- function(x) {
+        log(cdf(x))
+      }
+
+      sample <- function(seed) {
+
+        # sample as the ratio of two scaled chi squared distributions
+        d1 <- tfp$distributions$Chi2(df = df1)
+        d2 <- tfp$distributions$Chi2(df = df2)
+
+        u1 <- d1$sample(seed = seed)
+        u2 <- d2$sample(seed = seed)
+
+        (u1 / df1) / (u2 / df2)
+      }
+
+      list(
+        log_prob = log_prob,
+        cdf = cdf,
+        log_cdf = log_cdf,
+        sample = sample
+      )
+    }
+  )
+)
+
+dirichlet_distribution <- R6Class(
+  "dirichlet_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(alpha, n_realisations, dimension) {
+      # coerce to greta arrays
+      alpha <- as.greta_array(alpha)
+
+      dim <- check_multivariate_dims(
+        vectors = list(alpha),
+        n_realisations = n_realisations,
+        dimension = dimension
+      )
+
+      # coerce the parameter arguments to nodes and add as parents and
+      # parameters
+      self$bounds <- c(0, Inf)
+      super$initialize("dirichlet", dim,
+        truncation = c(0, Inf),
+        multivariate = TRUE
+      )
+      self$add_parameter(alpha, "alpha")
+    },
+    create_target = function(truncation) {
+      simplex_greta_array <- simplex_variable(self$dim)
+
+      # return the node for the simplex
+      target_node <- get_node(simplex_greta_array)
+      target_node
+    },
+    tf_distrib = function(parameters, dag) {
+      alpha <- parameters$alpha
+      tfp$distributions$Dirichlet(concentration = alpha)
+    }
+  )
+)
+
+dirichlet_multinomial_distribution <- R6Class(
+  "dirichlet_multinomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(size, alpha, n_realisations, dimension) {
+
+      # coerce to greta arrays
+      size <- as.greta_array(size)
+      alpha <- as.greta_array(alpha)
+
+      dim <- check_multivariate_dims(
+        scalars = list(size),
+        vectors = list(alpha),
+        n_realisations = n_realisations,
+        dimension = dimension
+      )
+
+
+      # need to handle size as a vector!
+
+      # coerce the parameter arguments to nodes and add as parents and
+      # parameters
+      super$initialize("dirichlet_multinomial",
+        dim = dim,
+        discrete = TRUE,
+        multivariate = TRUE
+      )
+      self$add_parameter(size, "size", shape_matches_output = FALSE)
+      self$add_parameter(alpha, "alpha")
+    },
+
+    # nolint start
+    tf_distrib = function(parameters, dag) {
+      parameters$size <- tf_flatten(parameters$size)
+      distrib <- tfp$distributions$DirichletMultinomial
+      distrib(
+        total_count = parameters$size,
+        concentration = parameters$alpha
+      )
+    }
+    # nolint end
+  )
+)
+
+multinomial_distribution <- R6Class(
+  "multinomial_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(size, prob, n_realisations, dimension) {
+
+      # coerce to greta arrays
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+
+      dim <- check_multivariate_dims(
+        scalars = list(size),
+        vectors = list(prob),
+        n_realisations = n_realisations,
+        dimension = dimension
+      )
+
+      # need to make sure size is a column vector!
+
+      # coerce the parameter arguments to nodes and add as parents and
+      # parameters
+      super$initialize("multinomial",
+        dim = dim,
+        discrete = TRUE,
+        multivariate = TRUE
+      )
+      self$add_parameter(size, "size", shape_matches_output = FALSE)
+      self$add_parameter(prob, "prob")
+    },
+    tf_distrib = function(parameters, dag) {
+      parameters$size <- tf_flatten(parameters$size)
+      # scale probs to get absolute density correct
+      parameters$prob <- parameters$prob / tf_sum(parameters$prob)
+
+      tfp$distributions$Multinomial(
+        total_count = parameters$size,
+        probs = parameters$prob
+      )
+    }
+  )
+)
+
+categorical_distribution <- R6Class(
+  "categorical_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(prob, n_realisations, dimension) {
+
+      # coerce to greta arrays
+      prob <- as.greta_array(prob)
+
+      dim <- check_multivariate_dims(
+        vectors = list(prob),
+        n_realisations = n_realisations,
+        dimension = dimension
+      )
+
+      # coerce the parameter arguments to nodes and add as parents and
+      # parameters
+      super$initialize("categorical",
+        dim = dim,
+        discrete = TRUE,
+        multivariate = TRUE
+      )
+      self$add_parameter(prob, "prob")
+    },
+    tf_distrib = function(parameters, dag) {
+      # scale probs to get absolute density correct
+      probs <- parameters$prob
+      probs <- probs / tf_sum(probs)
+      tfp$distributions$Multinomial(
+        total_count = fl(1),
+        probs = probs
+      )
+    }
+  )
+)
+
+multivariate_normal_distribution <- R6Class(
+  "multivariate_normal_distribution",
+  inherit = distribution_node,
+  public = list(
+    sigma_is_cholesky = FALSE,
+    # nolint start
+    initialize = function(mean, Sigma, n_realisations, dimension) {
+      # nolint end
+      # coerce to greta arrays
+      mean <- as.greta_array(mean)
+      sigma <- as.greta_array(Sigma)
+
+      # check dim is a positive scalar integer
+      dim <- check_multivariate_dims(
+        vectors = list(mean),
+        squares = list(sigma),
+        n_realisations = n_realisations,
+        dimension = dimension
+      )
+
+      # check dimensions of Sigma
+      if (nrow(sigma) != ncol(sigma) |
+        length(dim(sigma)) != 2) {
+        msg <- cli::format_error(
+          c(
+            "{.arg Sigma} must be a square 2D greta array",
+            "However {.arg Sigma} has dimensions \\
+            {.val {paste(dim(sigma), collapse = 'x')}}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      # compare possible dimensions
+      dim_mean <- ncol(mean)
+      dim_sigma <- nrow(sigma)
+
+      if (dim_mean != dim_sigma) {
+        msg <- cli::format_error(
+          c(
+            "{.arg mean} and {.arg Sigma} must have the same dimensions",
+            "However they are different: {dim_mean} vs {dim_sigma}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      # coerce the parameter arguments to nodes and add as parents and
+      # parameters
+      super$initialize("multivariate_normal", dim, multivariate = TRUE)
+
+      if (has_representation(sigma, "cholesky")) {
+        sigma <- representation(sigma, "cholesky")
+        self$sigma_is_cholesky <- TRUE
+      }
+      self$add_parameter(mean, "mean")
+      self$add_parameter(sigma, "sigma")
+    },
+    tf_distrib = function(parameters, dag) {
+
+      # if Sigma is a cholesky factor, transpose it to the tensorflow
+      # expectation, otherwise decompose it
+
+      if (self$sigma_is_cholesky) {
+        l <- tf_transpose(parameters$sigma)
+      } else {
+        l <- tf$linalg$cholesky(parameters$sigma)
+      }
+
+      # add an extra dimension for the observation batch size (otherwise tfp
+      # will try to use the n_chains batch dimension)
+      l <- tf$expand_dims(l, 1L)
+
+      mu <- parameters$mean
+      # nolint start
+      tfp$distributions$MultivariateNormalTriL(
+        loc = mu,
+        scale_tril = l
+      )
+      # nolint end
+    }
+  )
+)
+
+wishart_distribution <- R6Class(
+  "wishart_distribution",
+  inherit = distribution_node,
+  public = list(
+
+    # set when defining the distribution
+    sigma_is_cholesky = FALSE,
+
+    # set when defining the graph
+    target_is_cholesky = FALSE,
+    initialize = function(df, Sigma) { # nolint
+      # add the nodes as parents and parameters
+
+      df <- as.greta_array(df)
+      sigma <- as.greta_array(Sigma)
+
+      # check dimensions of Sigma
+      if (nrow(sigma) != ncol(sigma) |
+        length(dim(sigma)) != 2) {
+        msg <- cli::format_error(
+          c(
+            "{.arg Sigma} must be a square 2D greta array",
+            "However, {.arg Sigma} has dimensions ",
+            "{.val {paste(dim(sigma), collapse = 'x')}}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      dim <- nrow(sigma)
+
+      # initialize with a cholesky factor
+      super$initialize("wishart", dim(sigma), multivariate = TRUE)
+
+      # set parameters
+      if (has_representation(sigma, "cholesky")) {
+        sigma <- representation(sigma, "cholesky")
+        self$sigma_is_cholesky <- TRUE
+      }
+      self$add_parameter(df, "df", shape_matches_output = FALSE)
+      self$add_parameter(sigma, "sigma")
+
+      # make the initial value PD (no idea whether this does anything)
+      self$value(unknowns(dims = c(dim, dim), data = diag(dim)))
+    },
+
+    # create a variable, and transform to a symmetric matrix (with cholesky
+    # factor representation)
+    create_target = function(truncation) {
+
+      # create cholesky factor variable greta array
+      chol_greta_array <- cholesky_variable(self$dim[1])
+
+      # reshape to a symmetric matrix (retaining cholesky representation)
+      matrix_greta_array <- chol2symm(chol_greta_array)
+
+      # return the node for the symmetric matrix
+      target_node <- get_node(matrix_greta_array)
+      target_node
+    },
+
+    # get a cholesky factor for the target if possible
+    get_tf_target_node = function() {
+      target <- self$target
+      if (has_representation(target, "cholesky")) {
+        chol <- representation(target, "cholesky")
+        target <- get_node(chol)
+        self$target_is_cholesky <- TRUE
+      }
+      target
+    },
+
+    # if the target is changed, make sure target_is_cholesky is reset to FALSE
+    # (it can be set again when the graph is defined)
+    reset_target_flags = function() {
+      self$target_is_cholesky <- FALSE
+    },
+    tf_distrib = function(parameters, dag) {
+
+      # this is messy, we want to use the tfp wishart, but can't define the
+      # density without expanding the dimension of x
+
+      log_prob <- function(x) {
+
+        # reshape the dimensions
+        df <- tf_flatten(parameters$df)
+        sigma <- tf$expand_dims(parameters$sigma, 1L)
+        x <- tf$expand_dims(x, 1L)
+
+        # get the cholesky factor of Sigma in tf orientation
+        if (self$sigma_is_cholesky) {
+          sigma_chol <- tf$linalg$matrix_transpose(sigma)
+        } else {
+          sigma_chol <- tf$linalg$cholesky(sigma)
+        }
+
+        # get the cholesky factor of the target in tf orientation
+        if (self$target_is_cholesky) {
+          x_chol <- tf$linalg$matrix_transpose(x)
+        } else {
+          x_chol <- tf$linalg$cholesky(x)
+        }
+
+        # use the density for choleskied x, with choleskied Sigma
+        distrib <- tfp$distributions$Wishart(
+          df = df,
+          scale_tril = sigma_chol,
+          input_output_cholesky = TRUE
+        )
+
+        distrib$log_prob(x_chol)
+      }
+
+      sample <- function(seed) {
+        df <- tf$squeeze(parameters$df, 1:2)
+        sigma <- parameters$sigma
+
+        # get the cholesky factor of Sigma in tf orientation
+        if (self$sigma_is_cholesky) {
+          sigma_chol <- tf$linalg$matrix_transpose(sigma)
+        } else {
+          sigma_chol <- tf$linalg$cholesky(sigma)
+        }
+
+        # use the density for choleskied x, with choleskied Sigma
+        distrib <- tfp$distributions$Wishart(
+          df = df,
+          scale_tril = sigma_chol
+        )
+
+        draws <- distrib$sample(seed = seed)
+
+        if (self$target_is_cholesky) {
+          draws <- tf_chol(draws)
+        }
+
+        draws
+      }
+
+      list(log_prob = log_prob, sample = sample)
+    }
+  )
+)
+
+lkj_correlation_distribution <- R6Class(
+  "lkj_correlation_distribution",
+  inherit = distribution_node,
+  public = list(
+
+    # set when defining the graph
+    target_is_cholesky = FALSE,
+    initialize = function(eta, dimension = 2) {
+      dimension <- check_dimension(target = dimension)
+
+      if (!inherits(eta, "greta_array")) {
+        if (!is.numeric(eta) || !length(eta) == 1 || eta <= 0) {
+          msg <- cli::format_error(
+            "{.arg eta} must be a positive scalar value, or a scalar \\
+            {.cls greta_array}"
+          )
+          stop(
+            msg,
+            call. = FALSE
+          )
+        }
+      }
+
+      # add the nodes as parents and parameters
+      eta <- as.greta_array(eta)
+
+      if (!is_scalar(eta)) {
+        msg <- cli::format_error(
+          c(
+            "{.arg eta} must be a scalar",
+            "However {.arg eta} had dimensions: \\
+            {paste0(dim(eta), collapse = ', ')}"
+          )
+        )
+        stop(
+          msg,
+          call. = FALSE
+        )
+      }
+
+      dim <- c(dimension, dimension)
+      super$initialize("lkj_correlation", dim, multivariate = TRUE)
+
+      # don't try to expand scalar eta out to match the target size
+      self$add_parameter(eta, "eta", shape_matches_output = FALSE)
+
+      # make the initial value PD
+      self$value(unknowns(dims = dim, data = diag(dimension)))
+    },
+
+    # default (cholesky factor, ignores truncation)
+    create_target = function(truncation) {
+
+      # create (correlation matrix) cholesky factor variable greta array
+      chol_greta_array <- cholesky_variable(self$dim[1], correlation = TRUE)
+
+      # reshape to a symmetric matrix (retaining cholesky representation)
+      matrix_greta_array <- chol2symm(chol_greta_array)
+
+      # return the node for the symmetric matrix
+      target_node <- get_node(matrix_greta_array)
+      target_node
+    },
+
+    # get a cholesky factor for the target if possible
+    get_tf_target_node = function() {
+      target <- self$target
+      if (has_representation(target, "cholesky")) {
+        chol <- representation(target, "cholesky")
+        target <- get_node(chol)
+        self$target_is_cholesky <- TRUE
+      }
+      target
+    },
+
+    # if the target is changed, make sure target_is_cholesky is reset to FALSE
+    # (it can be set again when the graph is defined)
+    reset_target_flags = function() {
+      self$target_is_cholesky <- FALSE
+    },
+    tf_distrib = function(parameters, dag) {
+      eta <- tf$squeeze(parameters$eta, 1:2)
+      dim <- self$dim[1]
+
+      distrib <- tfp$distributions$LKJ(
+        dimension = dim,
+        concentration = eta,
+        input_output_cholesky = self$target_is_cholesky
+      )
+
+      # tfp's lkj sampling can't detect the size of the output from eta, for
+      # some reason. But we can use map_fn to apply their simulation to each
+      # element of eta.
+      sample <- function(seed) {
+        sample_once <- function(eta) {
+          d <- tfp$distributions$LKJ(
+            dimension = dim,
+            concentration = eta,
+            input_output_cholesky = self$target_is_cholesky
+          )
+
+          d$sample(seed = seed)
+        }
+
+        tf$map_fn(sample_once, eta)
+      }
+
+      list(
+        log_prob = distrib$log_prob,
+        sample = sample
+      )
+    }
+  )
+)
+
+# module for export via .internals
+distribution_classes_module <- module(uniform_distribution,
+                                      normal_distribution,
+                                      lognormal_distribution,
+                                      bernoulli_distribution,
+                                      binomial_distribution,
+                                      beta_binomial_distribution,
+                                      negative_binomial_distribution,
+                                      zero_inflated_poisson_distribution,
+                                      zero_inflated_negative_binomial_distribution,
+                                      hypergeometric_distribution,
+                                      poisson_distribution,
+                                      gamma_distribution,
+                                      inverse_gamma_distribution,
+                                      weibull_distribution,
+                                      exponential_distribution,
+                                      pareto_distribution,
+                                      student_distribution,
+                                      laplace_distribution,
+                                      beta_distribution,
+                                      cauchy_distribution,
+                                      chi_squared_distribution,
+                                      logistic_distribution,
+                                      f_distribution,
+                                      multivariate_normal_distribution,
+                                      wishart_distribution,
+                                      lkj_correlation_distribution,
+                                      multinomial_distribution,
+                                      categorical_distribution,
+                                      dirichlet_distribution,
+                                      dirichlet_multinomial_distribution)
+
+# export constructors
+
+# nolint start
+#' @name distributions
+#' @title probability distributions
+#' @description These functions can be used to define random variables in a
+#'   greta model. They return a variable greta array that follows the specified
+#'   distribution. This variable greta array can be used to represent a
+#'   parameter with prior distribution, combined into a mixture distribution
+#'   using [mixture()], or used with [distribution()] to
+#'   define a distribution over a data greta array.
+#'
+#' @param truncation a length-two vector giving values between which to truncate
+#'   the distribution, similarly to the `lower` and `upper` arguments
+#'   to [variable()]
+#'
+#' @param min,max scalar values giving optional limits to `uniform`
+#'   variables. Like `lower` and `upper`, these must be specified as
+#'   numerics, they cannot be greta arrays (though see details for a
+#'   workaround). Unlike `lower` and `upper`, they must be finite.
+#'   `min` must always be less than `max`.
+#'
+#' @param mean,meanlog,location,mu unconstrained parameters
+#'
+#' @param
+#'   sd,sdlog,sigma,lambda,shape,rate,df,scale,shape1,shape2,alpha,beta,df1,df2,a,b,eta
+#'    positive parameters, `alpha` must be a vector for `dirichlet`
+#'   and `dirichlet_multinomial`.
+#'
+#' @param size,m,n,k positive integer parameter
+#'
+#' @param prob probability parameter (`0 < prob < 1`), must be a vector for
+#'   `multinomial` and `categorical`
+#'
+#' @param theta probability of a structural (inflated) zero
+#'   (`0 < theta < 1`) for `zero_inflated_poisson` and
+#'   `zero_inflated_negative_binomial`
+#'
+#' @param Sigma positive definite variance-covariance matrix parameter
+#'
+#' @param dim the dimensions of the greta array to be returned, either a scalar
+#'   or a vector of positive integers. See details.
+#'
+#' @param dimension the dimension of a multivariate distribution
+#'
+#' @param n_realisations the number of independent realisations of a
+#'   multivariate distribution
+#'
+#' @details The discrete probability distributions (`bernoulli`,
+#'   `binomial`, `negative_binomial`, `poisson`,
+#'   `zero_inflated_poisson`, `zero_inflated_negative_binomial`,
+#'   `multinomial`, `categorical`, `dirichlet_multinomial`) can
+#'   be used when they have fixed values (e.g. defined as a likelihood using
+#'   [distribution()]), but not as unknown variables.
+#'
+#'   For univariate distributions `dim` gives the dimensions of the greta
+#'   array to create. Each element of the greta array will be (independently)
+#'   distributed according to the distribution. `dim` can also be left at
+#'   its default of `NULL`, in which case the dimension will be detected
+#'   from the dimensions of the parameters (provided they are compatible with
+#'   one another).
+#'
+#'   For multivariate distributions (`multivariate_normal()`,
+#'   `multinomial()`, `categorical()`, `dirichlet()`, and
+#'   `dirichlet_multinomial()`) each row of the output and parameters
+#'   corresponds to an independent realisation. If a single realisation or
+#'   parameter value is specified, it must therefore be a row vector (see
+#'   example). `n_realisations` gives the number of rows/realisations, and
+#'   `dimension` gives the dimension of the distribution. I.e. a bivariate
+#'   normal distribution would be produced with `multivariate_normal(...,
+#'   dimension = 2)`. The dimension can usually be detected from the parameters.
+#'
+#'   `multinomial()` does not check that observed values sum to
+#'   `size`, and `categorical()` does not check that only one of the
+#'   observed entries is 1. It's the user's responsibility to check their data
+#'   matches the distribution!
+#'
+#'   The parameters of `uniform` must be fixed, not greta arrays. This
+#'   ensures these values can always be transformed to a continuous scale to run
+#'   the samplers efficiently. However, a hierarchical `uniform` parameter
+#'   can always be created by defining a `uniform` variable constrained
+#'   between 0 and 1, and then transforming it to the required scale. See below
+#'   for an example.
+#'
+#'   Wherever possible, the parameterisations and argument names of greta
+#'   distributions match commonly used R functions for distributions, such as
+#'   those in the `stats` or `extraDistr` packages. The following
+#'   table states the distribution function to which greta's implementation
+#'   corresponds:
+#'
+#'   \tabular{ll}{ greta \tab reference\cr `uniform` \tab
+#'   [stats::dunif]\cr `normal` \tab
+#'   [stats::dnorm]\cr `lognormal` \tab
+#'   [stats::dlnorm]\cr `bernoulli` \tab
+#'   [extraDistr::dbern]\cr `binomial` \tab
+#'   [stats::dbinom]\cr `beta_binomial` \tab
+#'   [extraDistr::dbbinom]\cr `negative_binomial`
+#'   \tab [stats::dnbinom]\cr `hypergeometric` \tab
+#'   [stats::dhyper]\cr `poisson` \tab
+#'   [stats::dpois]\cr `gamma` \tab
+#'   [stats::dgamma]\cr `inverse_gamma` \tab
+#'   [extraDistr::dinvgamma]\cr `weibull` \tab
+#'   [stats::dweibull]\cr `exponential` \tab
+#'   [stats::dexp]\cr `pareto` \tab
+#'   [extraDistr::dpareto]\cr `student` \tab
+#'   [extraDistr::dlst]\cr `laplace` \tab
+#'   [extraDistr::dlaplace]\cr `beta` \tab
+#'   [stats::dbeta]\cr `cauchy` \tab
+#'   [stats::dcauchy]\cr `chi_squared` \tab
+#'   [stats::dchisq]\cr `logistic` \tab
+#'   [stats::dlogis]\cr `f` \tab
+#'   [stats::df]\cr `multivariate_normal` \tab
+#'   [mvtnorm::dmvnorm]\cr `multinomial` \tab
+#'   [stats::dmultinom]\cr `categorical` \tab
+#'   {[stats::dmultinom] (size = 1)}\cr `dirichlet`
+#'   \tab [extraDistr::ddirichlet]\cr
+#'   `dirichlet_multinomial` \tab
+#'   [extraDistr::ddirmnom]\cr `wishart` \tab
+#'   [stats::rWishart]\cr `lkj_correlation` \tab
+#'   [rethinking::dlkjcorr](https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html)
+#'   }
+#'
+#' @examples
+#' \dontrun{
+#'
+#' # a uniform parameter constrained to be between 0 and 1
+#' phi <- uniform(min = 0, max = 1)
+#'
+#' # a length-three variable, with each element following a standard normal
+#' # distribution
+#' alpha <- normal(0, 1, dim = 3)
+#'
+#' # a length-three variable of lognormals
+#' sigma <- lognormal(0, 3, dim = 3)
+#'
+#' # a hierarchical uniform, constrained between alpha and alpha + sigma
+#' eta <- alpha + uniform(0, 1, dim = 3) * sigma
+#'
+#' # a hierarchical distribution
+#' mu <- normal(0, 1)
+#' sigma <- lognormal(0, 1)
+#' theta <- normal(mu, sigma)
+#'
+#' # a vector of 3 variables drawn from the same hierarchical distribution
+#' thetas <- normal(mu, sigma, dim = 3)
+#'
+#' # a matrix of 12 variables drawn from the same hierarchical distribution
+#' thetas <- normal(mu, sigma, dim = c(3, 4))
+#'
+#' # a multivariate normal variable, with correlation between two elements
+#' # note that the parameter must be a row vector
+#' Sig <- diag(4)
+#' Sig[3, 4] <- Sig[4, 3] <- 0.6
+#' theta <- multivariate_normal(t(rep(mu, 4)), Sig)
+#'
+#' # 10 independent replicates of that
+#' theta <- multivariate_normal(t(rep(mu, 4)), Sig, n_realisations = 10)
+#'
+#' # 10 multivariate normal replicates, each with a different mean vector,
+#' # but the same covariance matrix
+#' means <- matrix(rnorm(40), 10, 4)
+#' theta <- multivariate_normal(means, Sig, n_realisations = 10)
+#' dim(theta)
+#'
+#' # a Wishart variable with the same covariance parameter
+#' theta <- wishart(df = 5, Sigma = Sig)
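+#'
+#' # a zero-inflated Poisson likelihood over simulated counts, where theta is
+#' # the probability of a structural zero
+#' counts <- as_data(rpois(10, lambda = 2) * rbinom(10, size = 1, prob = 0.8))
+#' theta <- uniform(0, 1)
+#' lambda <- lognormal(0, 1)
+#' distribution(counts) <- zero_inflated_poisson(theta, lambda, dim = 10)
+#'
+#' # and a zero-inflated negative binomial likelihood over a second data set
+#' counts2 <- as_data(rnbinom(10, size = 5, prob = 0.3) *
+#'   rbinom(10, size = 1, prob = 0.8))
+#' distribution(counts2) <- zero_inflated_negative_binomial(
+#'   theta,
+#'   size = 5,
+#'   prob = 0.3,
+#'   dim = 10
+#' )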
+#' }
+NULL
+# nolint end
+
+#' @rdname distributions
+#' @export
+uniform <- function(min, max, dim = NULL) {
+  distrib("uniform", min, max, dim)
+}
+
+#' @rdname distributions
+#' @export
+normal <- function(mean, sd, dim = NULL, truncation = c(-Inf, Inf)) {
+  distrib("normal", mean, sd, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+lognormal <- function(meanlog, sdlog, dim = NULL, truncation = c(0, Inf)) {
+  distrib("lognormal", meanlog, sdlog, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+bernoulli <- function(prob, dim = NULL) {
+  distrib("bernoulli", prob, dim)
+}
+
+#' @rdname distributions
+#' @export
+binomial <- function(size, prob, dim = NULL) {
+  check_in_family("binomial", size)
+  distrib("binomial", size, prob, dim)
+}
+
+#' @rdname distributions
+#' @export
+beta_binomial <- function(size, alpha, beta, dim = NULL) {
+  distrib("beta_binomial", size, alpha, beta, dim)
+}
+
+#' @rdname distributions
+#' @export
+negative_binomial <- function(size, prob, dim = NULL) {
+  distrib("negative_binomial", size, prob, dim)
+}
+
+#' @rdname distributions
+#' @export
+hypergeometric <- function(m, n, k, dim = NULL) {
+  distrib("hypergeometric", m, n, k, dim)
+}
+
+#' @rdname distributions
+#' @export
+poisson <- function(lambda, dim = NULL) {
+  check_in_family("poisson", lambda)
+  distrib("poisson", lambda, dim)
+}
+
+#' @rdname distributions
+#' @export
+zero_inflated_poisson <- function(theta, lambda, dim = NULL) {
+  distrib("zero_inflated_poisson", theta, lambda, dim)
+}
+
+#' @rdname distributions
+#' @export
+zero_inflated_negative_binomial <- function(theta, size, prob, dim = NULL) {
+  distrib("zero_inflated_negative_binomial", theta, size, prob, dim)
+}
+
+#' @rdname distributions
+#' @export
+gamma <- function(shape, rate, dim = NULL, truncation = c(0, Inf)) {
+  distrib("gamma", shape, rate, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+inverse_gamma <- function(alpha, beta, dim = NULL, truncation = c(0, Inf)) {
+  distrib("inverse_gamma", alpha, beta, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+weibull <- function(shape, scale, dim = NULL, truncation = c(0, Inf)) {
+  distrib("weibull", shape, scale, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+exponential <- function(rate, dim = NULL, truncation = c(0, Inf)) {
+  distrib("exponential", rate, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+pareto <- function(a, b, dim = NULL, truncation = c(0, Inf)) {
+  distrib("pareto", a, b, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+student <- function(df, mu, sigma, dim = NULL, truncation = c(-Inf, Inf)) {
+  distrib("student", df, mu, sigma, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+laplace <- function(mu, sigma, dim = NULL, truncation = c(-Inf, Inf)) {
+  distrib("laplace", mu, sigma, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+beta <- function(shape1, shape2, dim = NULL, truncation = c(0, 1)) {
+  distrib("beta", shape1, shape2, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+cauchy <- function(location, scale, dim = NULL, truncation = c(-Inf, Inf)) {
+  distrib("cauchy", location, scale, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+chi_squared <- function(df, dim = NULL, truncation = c(0, Inf)) {
+  distrib("chi_squared", df, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+logistic <- function(location, scale, dim = NULL, truncation = c(-Inf, Inf)) {
+  distrib("logistic", location, scale, dim, truncation)
+}
+
+#' @rdname distributions
+#' @export
+f <- function(df1, df2, dim = NULL, truncation = c(0, Inf)) {
+  distrib("f", df1, df2, dim, truncation)
+}
+
+# nolint start
+#' @rdname distributions
+#' @export
+multivariate_normal <- function(mean, Sigma,
+                                n_realisations = NULL, dimension = NULL) {
+  # nolint end
+  distrib(
+    "multivariate_normal", mean, Sigma,
+    n_realisations, dimension
+  )
+}
+
+#' @rdname distributions
+#' @export
+wishart <- function(df, Sigma) { # nolint
+  distrib("wishart", df, Sigma)
+}
+
+#' @rdname distributions
+#' @export
+lkj_correlation <- function(eta, dimension = 2) {
+  distrib("lkj_correlation", eta, dimension)
+}
+
+#' @rdname distributions
+#' @export
+multinomial <- function(size, prob, n_realisations = NULL, dimension = NULL) {
+  distrib("multinomial", size, prob, n_realisations, dimension)
+}
+
+#' @rdname distributions
+#' @export
+categorical <- function(prob, n_realisations = NULL, dimension = NULL) {
+  distrib("categorical", prob, n_realisations, dimension)
+}
+
+#' @rdname distributions
+#' @export
+dirichlet <- function(alpha, n_realisations = NULL, dimension = NULL) {
+  distrib("dirichlet", alpha, n_realisations, dimension)
+}
+
+#' @rdname distributions
+#' @export
+dirichlet_multinomial <- function(size, alpha,
+                                  n_realisations = NULL, dimension = NULL) {
+  distrib(
+    "dirichlet_multinomial",
+    size, alpha, n_realisations, dimension
+  )
+}
diff --git a/man/distributions.Rd b/man/distributions.Rd
new file mode 100644
index 0000000..c133420
--- /dev/null
+++ b/man/distributions.Rd
@@ -0,0 +1,254 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/probability_distributions.R
+\name{distributions}
+\alias{distributions}
+\alias{uniform}
+\alias{normal}
+\alias{lognormal}
+\alias{bernoulli}
+\alias{binomial}
+\alias{beta_binomial}
+\alias{negative_binomial}
+\alias{hypergeometric}
+\alias{poisson}
+\alias{zero_inflated_poisson}
+\alias{zero_inflated_negative_binomial}
+\alias{gamma}
+\alias{inverse_gamma}
+\alias{weibull}
+\alias{exponential}
+\alias{pareto}
+\alias{student}
+\alias{laplace}
+\alias{beta}
+\alias{cauchy}
+\alias{chi_squared}
+\alias{logistic}
+\alias{f}
+\alias{multivariate_normal}
+\alias{wishart}
+\alias{lkj_correlation}
+\alias{multinomial}
+\alias{categorical}
+\alias{dirichlet}
+\alias{dirichlet_multinomial}
+\title{probability distributions}
+\usage{
+uniform(min, max, dim = NULL)
+
+normal(mean, sd, dim = NULL, truncation = c(-Inf, Inf))
+
+lognormal(meanlog, sdlog, dim = NULL, truncation = c(0, Inf))
+
+bernoulli(prob, dim = NULL)
+
+binomial(size, prob, dim = NULL)
+
+beta_binomial(size, alpha, beta, dim = NULL)
+
+negative_binomial(size, prob, dim = NULL)
+
+hypergeometric(m, n, k, dim = NULL)
+
+poisson(lambda, dim = NULL)
+
+zero_inflated_poisson(theta, lambda, dim = NULL)
+
+zero_inflated_negative_binomial(theta, size, prob, dim = NULL)
+
+gamma(shape, rate, dim = NULL, truncation = c(0, Inf))
+
+inverse_gamma(alpha, beta, dim = NULL, truncation = c(0, Inf))
+
+weibull(shape, scale, dim = NULL, truncation = c(0, Inf))
+
+exponential(rate, dim = NULL, truncation = c(0, Inf))
+
+pareto(a, b, dim = NULL, truncation = c(0, Inf))
+
+student(df, mu, sigma, dim = NULL, truncation = c(-Inf, Inf))
+
+laplace(mu, sigma, dim = NULL, truncation = c(-Inf, Inf))
+
+beta(shape1, shape2, dim = NULL, truncation = c(0, 1))
+
+cauchy(location, scale, dim = NULL, truncation = c(-Inf, Inf))
+
+chi_squared(df, dim = NULL, truncation = c(0, Inf))
+
+logistic(location, scale, dim = NULL, truncation = c(-Inf, Inf))
+
+f(df1, df2, dim = NULL, truncation = c(0, Inf))
+
+multivariate_normal(mean, Sigma, n_realisations = NULL, dimension = NULL)
+
+wishart(df, Sigma)
+
+lkj_correlation(eta, dimension = 2)
+
+multinomial(size, prob, n_realisations = NULL, dimension = NULL)
+
+categorical(prob, n_realisations = NULL, dimension = NULL)
+
+dirichlet(alpha, n_realisations = NULL, dimension = NULL)
+
+dirichlet_multinomial(size, alpha, n_realisations = NULL, dimension = NULL)
+}
+\arguments{
+\item{min, max}{scalar values giving optional limits to \code{uniform}
+variables. Like \code{lower} and \code{upper}, these must be specified as
+numerics, they cannot be greta arrays (though see details for a
+workaround). Unlike \code{lower} and \code{upper}, they must be finite.
+\code{min} must always be less than \code{max}.}
+
+\item{dim}{the dimensions of the greta array to be returned, either a scalar
+or a vector of positive integers. See details.}
+
+\item{mean, meanlog, location, mu}{unconstrained parameters}
+
+\item{sd, sdlog, sigma, lambda, shape, rate, df, scale, shape1, shape2, alpha, beta, df1, df2, a, b, eta}{positive parameters; \code{alpha} must be a vector for \code{dirichlet}
+and \code{dirichlet_multinomial}.}
+
+\item{truncation}{a length-two vector giving values between which to truncate
+the distribution, similarly to the \code{lower} and \code{upper} arguments
+to \code{\link[=variable]{variable()}}}
+
+\item{prob}{probability parameter (\verb{0 < prob < 1}); must be a vector for
+\code{multinomial} and \code{categorical}}
+
+\item{theta}{zero-inflation parameter (\verb{0 < theta < 1}); the mixing
+weight on the point mass at zero in \code{zero_inflated_poisson} and
+\code{zero_inflated_negative_binomial}}
+
+\item{size, m, n, k}{positive integer parameter}
+
+\item{Sigma}{positive definite variance-covariance matrix parameter}
+
+\item{n_realisations}{the number of independent realisations of a multivariate
+distribution}
+
+\item{dimension}{the dimension of a multivariate distribution}
+}
+\description{
+These functions can be used to define random variables in a
+greta model. They return a variable greta array that follows the specified
+distribution. This variable greta array can be used to represent a
+parameter with a prior distribution, combined into a mixture distribution
+using \code{\link[=mixture]{mixture()}}, or used with \code{\link[=distribution]{distribution()}} to
+define a distribution over a data greta array.
+}
+\details{
+The discrete probability distributions (\code{bernoulli},
+\code{binomial}, \code{negative_binomial}, \code{poisson},
+\code{zero_inflated_poisson}, \code{zero_inflated_negative_binomial},
+\code{multinomial}, \code{categorical}, \code{dirichlet_multinomial}) can
+be used when they have fixed values (e.g. defined as a likelihood using
+\code{\link[=distribution]{distribution()}}), but not as unknown variables.
+
+For univariate distributions \code{dim} gives the dimensions of the greta
+array to create. Each element of the greta array will be (independently)
+distributed according to the distribution. \code{dim} can also be left at
+its default of \code{NULL}, in which case the dimension will be detected
+from the dimensions of the parameters (provided they are compatible with
+one another).
+
+For multivariate distributions (\code{multivariate_normal()},
+\code{multinomial()}, \code{categorical()}, \code{dirichlet()}, and
+\code{dirichlet_multinomial()}) each row of the output and parameters
+corresponds to an independent realisation. If a single realisation or
+parameter value is specified, it must therefore be a row vector (see
+example). \code{n_realisations} gives the number of rows/realisations, and
+\code{dimension} gives the dimension of the distribution. For example, a
+bivariate normal distribution would be produced with
+\code{multivariate_normal(..., dimension = 2)}. The dimension can usually
+be detected from the parameters.
+
+\code{multinomial()} does not check that observed values sum to
+\code{size}, and \code{categorical()} does not check that only one of the
+observed entries is 1. It's the user's responsibility to check their data
+matches the distribution!
+
+The parameters of \code{uniform} must be fixed, not greta arrays. This
+ensures these values can always be transformed to a continuous scale to run
+the samplers efficiently. However, a hierarchical \code{uniform} parameter
+can always be created by defining a \code{uniform} variable constrained
+between 0 and 1, and then transforming it to the required scale. See below
+for an example.
+
+Wherever possible, the parameterisations and argument names of greta
+distributions match commonly used R functions for distributions, such as
+those in the \code{stats} or \code{extraDistr} packages. The following
+table states the distribution function to which greta's implementation
+corresponds:
+
+\tabular{ll}{ greta \tab reference\cr \code{uniform} \tab
+\link[stats:Uniform]{stats::dunif}\cr \code{normal} \tab
+\link[stats:Normal]{stats::dnorm}\cr \code{lognormal} \tab
+\link[stats:Lognormal]{stats::dlnorm}\cr \code{bernoulli} \tab
+\link[extraDistr:Bernoulli]{extraDistr::dbern}\cr \code{binomial} \tab
+\link[stats:Binomial]{stats::dbinom}\cr \code{beta_binomial} \tab
+\link[extraDistr:BetaBinom]{extraDistr::dbbinom}\cr \code{negative_binomial}
+\tab \link[stats:NegBinomial]{stats::dnbinom}\cr \code{hypergeometric} \tab
+\link[stats:Hypergeometric]{stats::dhyper}\cr \code{poisson} \tab
+\link[stats:Poisson]{stats::dpois}\cr \code{zero_inflated_poisson} \tab
+\link[extraDistr:ZIP]{extraDistr::dzip}\cr
+\code{zero_inflated_negative_binomial} \tab
+\link[extraDistr:ZINB]{extraDistr::dzinb}\cr \code{gamma} \tab
+\link[stats:GammaDist]{stats::dgamma}\cr \code{inverse_gamma} \tab
+\link[extraDistr:InvGamma]{extraDistr::dinvgamma}\cr \code{weibull} \tab
+\link[stats:Weibull]{stats::dweibull}\cr \code{exponential} \tab
+\link[stats:Exponential]{stats::dexp}\cr \code{pareto} \tab
+\link[extraDistr:Pareto]{extraDistr::dpareto}\cr \code{student} \tab
+\link[extraDistr:LocationScaleT]{extraDistr::dlst}\cr \code{laplace} \tab
+\link[extraDistr:Laplace]{extraDistr::dlaplace}\cr \code{beta} \tab
+\link[stats:Beta]{stats::dbeta}\cr \code{cauchy} \tab
+\link[stats:Cauchy]{stats::dcauchy}\cr \code{chi_squared} \tab
+\link[stats:Chisquare]{stats::dchisq}\cr \code{logistic} \tab
+\link[stats:Logistic]{stats::dlogis}\cr \code{f} \tab
+\link[stats:Fdist]{stats::df}\cr \code{multivariate_normal} \tab
+\link[mvtnorm:Mvnorm]{mvtnorm::dmvnorm}\cr \code{multinomial} \tab
+\link[stats:Multinom]{stats::dmultinom}\cr \code{categorical} \tab
+{\link[stats:Multinom]{stats::dmultinom} (size = 1)}\cr \code{dirichlet}
+\tab \link[extraDistr:Dirichlet]{extraDistr::ddirichlet}\cr
+\code{dirichlet_multinomial} \tab
+\link[extraDistr:DirMnom]{extraDistr::ddirmnom}\cr \code{wishart} \tab
+\link[stats:rWishart]{stats::rWishart}\cr \code{lkj_correlation} \tab
+\href{https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html}{rethinking::dlkjcorr}
+}
+}
+\examples{
+\dontrun{
+
+# a uniform parameter constrained to be between 0 and 1
+phi <- uniform(min = 0, max = 1)
+
+# a length-three variable, with each element following a standard normal
+# distribution
+alpha <- normal(0, 1, dim = 3)
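+
+# a sketch of the dimension detection described in the details section:
+# when dim is NULL, dimensions come from the parameters, so a length-three
+# mean vector gives a length-three greta array
+beta_vec <- normal(c(-1, 0, 1), 1)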
+
+# a length-three variable of lognormals
+sigma <- lognormal(0, 3, dim = 3)
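+
+# an illustrative sketch of the truncation argument: a normal variable
+# truncated to positive values
+tau <- normal(0, 1, truncation = c(0, Inf))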
+
+# a hierarchical uniform, constrained between alpha and alpha + sigma
+eta <- alpha + uniform(0, 1, dim = 3) * sigma
+
+# a hierarchical distribution
+mu <- normal(0, 1)
+sigma <- lognormal(0, 1)
+theta <- normal(mu, sigma)
+
+# a vector of 3 variables drawn from the same hierarchical distribution
+thetas <- normal(mu, sigma, dim = 3)
+
+# a matrix of 12 variables drawn from the same hierarchical distribution
+thetas <- normal(mu, sigma, dim = c(3, 4))
+
+# a multivariate normal variable, with correlation between two elements
+# note that the parameter must be a row vector
+Sig <- diag(4)
+Sig[3, 4] <- Sig[4, 3] <- 0.6
+theta <- multivariate_normal(t(rep(mu, 4)), Sig)
+
+# 10 independent replicates of that
+theta <- multivariate_normal(t(rep(mu, 4)), Sig, n_realisations = 10)
+
+# 10 multivariate normal replicates, each with a different mean vector,
+# but the same covariance matrix
+means <- matrix(rnorm(40), 10, 4)
+theta <- multivariate_normal(means, Sig, n_realisations = 10)
+dim(theta)
+
+# a Wishart variable with the same covariance parameter
+theta <- wishart(df = 5, Sigma = Sig)
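+
+# (illustrative sketch) a discrete distribution defining a likelihood over
+# fixed data; discrete distributions cannot be unknown variables
+y <- as_data(rpois(10, 3))
+lambda <- lognormal(0, 1)
+distribution(y) <- poisson(lambda)
+
+# (illustrative sketch) a zero-inflated poisson likelihood over counts;
+# theta_zi is the mixing weight on the point mass at zero
+y2 <- as_data(rpois(10, 3))
+theta_zi <- uniform(0, 1)
+distribution(y2) <- zero_inflated_poisson(theta_zi, lambda)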
+}
+}
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index bd2b00b..00860e3 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -16,13 +16,13 @@ grab <- function(x, dag = NULL) {
   if (inherits(x, "node")) {
     x <- as.greta_array(x)
   }
-  
+
   if (inherits(x, "greta_array")) {
     node <- get_node(x)
     dag <- dag_class$new(list(x))
     dag$define_tf()
   }
-  
+
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
   out <- dag$tf_sess_run(dag$tf_name(node), as_text = TRUE)
@@ -44,11 +44,11 @@ set_distribution <- function(dist, data) {
 get_density <- function(distrib, data) {
   x <- as_data(data)
   distribution(x) <- distrib
-  
+
   # create dag and define the density
   dag <- dag_class$new(list(x))
   get_node(x)$distribution$define_tf(dag)
-  
+
   # get the log density as a vector
   tensor_name <- dag$tf_name(get_node(distrib)$distribution)
   tensor <- get(tensor_name, envir = dag$tf_environment)
@@ -65,7 +65,7 @@ compare_distribution <- function(greta_fun, r_fun, parameters, x,
   # both of these functions must take the same parameters in the same order
   # 'parameters' is an optionally named list of numeric parameter values
   # x is the vector of values at which to evaluate the log density
-  
+
   # define greta distribution, with fixed values
   greta_log_density <- greta_density(
     greta_fun, parameters, x,
@@ -73,7 +73,7 @@ compare_distribution <- function(greta_fun, r_fun, parameters, x,
   )
   # get R version
   r_log_density <- log(do.call(r_fun, c(list(x), parameters)))
-  
+
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
@@ -85,45 +85,45 @@ greta_density <- function(fun, parameters, x,
   if (is.null(dim)) {
     dim <- NROW(x)
   }
-  
+
   # add the output dimension to the arguments list
   dim_list <- list(dim = dim)
-  
+
   # if it's a multivariate distribution name it n_realisations
   if (multivariate) {
     names(dim_list) <- "n_realisations"
   }
-  
+
   # don't add it for wishart & lkj, which don't have multiple realisations
   is_wishart <- identical(names(parameters), c("df", "Sigma"))
   is_lkj <- identical(names(parameters), c("eta", "dimension"))
   if (is_wishart | is_lkj) {
     dim_list <- list()
   }
-  
+
   parameters <- c(parameters, dim_list)
-  
+
   # evaluate greta distribution
   dist <- do.call(fun, parameters)
   distrib_node <- get_node(dist)$distribution
-  
+
   # set density
   x_ <- as.greta_array(x)
   distrib_node$remove_target()
   distrib_node$add_target(get_node(x_))
-  
+
   # create dag
   dag <- dag_class$new(list(x_))
   dag$define_tf()
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
-  
+
   # get the log density as a vector
   dag$on_graph(
     result <- dag$evaluate_density(distrib_node, get_node(x_))
   )
   assign("test_density", result, dag$tf_environment)
-  
+
   density <- dag$tf_sess_run(test_density)
   as.vector(density)
 }
@@ -140,21 +140,21 @@ with_greta <- function(call, swap = c("x"), swap_scope = 1) {
     ")"
   )
   swap_list <- eval(parse(text = swap_text),
-                    envir = parent.frame(n = swap_scope)
+    envir = parent.frame(n = swap_scope)
   )
-  
+
   greta_result <- with(
     swap_list,
     eval(call)
   )
   result <- grab(greta_result)
-  
+
   # account for the fact that greta outputs are 1D arrays; convert them back to
   # R vectors
   if (is.array(result) && length(dim(result)) == 2 && dim(result)[2] == 1) {
     result <- as.vector(result)
   }
-  
+
   result
 }
 
@@ -163,13 +163,13 @@ with_greta <- function(call, swap = c("x"), swap_scope = 1) {
 # e.g. check_expr(a[1:3], swap = 'a')
 check_expr <- function(expr, swap = c("x"), tolerance = 1e-4) {
   call <- substitute(expr)
-  
+
   r_out <- eval(expr)
   greta_out <- with_greta(call,
-                          swap = swap,
-                          swap_scope = 2
+    swap = swap,
+    swap_scope = 2
   )
-  
+
   compare_op(r_out, greta_out, tolerance)
 }
 
@@ -190,9 +190,9 @@ gen_opfun <- function(n, ops) {
   for (i in seq_len(n)) {
     string <- add_op_string(string, ops = ops)
   }
-  
+
   fun_string <- sprintf("function(a, b) {%s}", string)
-  
+
   eval(parse(text = fun_string))
 }
 
@@ -205,7 +205,7 @@ sample_distribution <- function(greta_array, n = 10,
   draws <- mcmc(m, n_samples = n, warmup = warmup, verbose = FALSE)
   samples <- as.matrix(draws)
   vectorised <- length(lower) > 1 | length(upper) > 1
-  
+
   if (vectorised) {
     above_lower <- sweep(samples, 2, lower, `>=`)
     below_upper <- sweep(samples, 2, upper, `<=`)
@@ -213,7 +213,7 @@ sample_distribution <- function(greta_array, n = 10,
     above_lower <- samples >= lower
     below_upper <- samples <= upper
   }
-  
+
   expect_true(all(above_lower & below_upper))
 }
 
@@ -227,7 +227,7 @@ compare_truncated_distribution <- function(greta_fun,
   # is a greta array created from a distribution and a constrained variable
   # greta array. 'r_fun' is an r function returning the log density for the same
   # truncated distribution, taking x as its only argument.
-  
+
   x <- do.call(
     truncdist::rtrunc,
     c(
@@ -238,18 +238,18 @@ compare_truncated_distribution <- function(greta_fun,
       parameters
     )
   )
-  
+
   # create truncated R function and evaluate it
   r_fun <- truncfun(which, parameters, truncation)
   r_log_density <- log(r_fun(x))
-  
+
   greta_log_density <- greta_density(
     fun = greta_fun,
     parameters = c(parameters, list(truncation = truncation)),
     x = x,
     dim = 1
   )
-  
+
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
@@ -263,7 +263,7 @@ truncfun <- function(which = "norm", parameters, truncation) {
     b = truncation[2],
     parameters
   )
-  
+
   function(x) {
     arg_list <- c(x = list(x), args)
     do.call(truncdist::dtrunc, arg_list)
@@ -317,7 +317,7 @@ get_output <- function(expr) {
 # mock up mcmc progress bar output for neurotic testing
 mock_mcmc <- function(n_samples = 1010) {
   pb <- create_progress_bar("sampling", c(0, n_samples),
-                            pb_update = 10, width = 50
+    pb_update = 10, width = 50
   )
   iterate_progress_bar(pb, n_samples, rejects = 10, chains = 1)
 }
@@ -328,7 +328,7 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
   k <- dimension
   stopifnot(is.numeric(k), k >= 2, k == as.integer(k))
   stopifnot(eta > 0)
-  
+
   f <- function() {
     alpha <- eta + (k - 2) / 2
     r12 <- 2 * stats::rbeta(1, alpha, alpha) - 1
@@ -336,7 +336,7 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
     r[1, 1] <- 1
     r[1, 2] <- r12
     r[2, 2] <- sqrt(1 - r12^2)
-    
+
     if (k > 2) {
       for (m in 2:(k - 1)) {
         alpha <- alpha - 0.5
@@ -347,18 +347,18 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
         r[m + 1, m + 1] <- sqrt(1 - y)
       }
     }
-    
+
     crossprod(r)
   }
-  
+
   r <- replicate(n, f())
-  
+
   if (dim(r)[3] == 1) {
     r <- r[, , 1]
   } else {
     r <- aperm(r, c(3, 1, 2))
   }
-  
+
   r
 }
 
@@ -502,9 +502,9 @@ rmixmvnorm <- function(n, ...) {
   weights <- args[[which(is_weights)]]
   args_list <- lapply(params_list, function(par) c(n, par))
   sims <- lapply(args_list, function(par) do.call(rmvnorm, par))
-  
+
   components <- sample.int(length(sims), n, prob = weights, replace = TRUE)
-  
+
   # loop through the n observations, pulling out the corresponding slice
   draws_out <- array(NA, dim(sims[[1]]))
   for (i in seq_len(n)) {
@@ -534,19 +534,19 @@ compare_iid_samples <- function(greta_fun,
                                 nsim = 200,
                                 p_value_threshold = 0.001) {
   greta_array <- do.call(greta_fun, parameters)
-  
+
   # get information about distribution
   distribution <- get_node(greta_array)$distribution
   multivariate <- distribution$multivariate
   discrete <- distribution$discrete
   name <- distribution$distribution_name
-  
+
   greta_samples <- calculate(greta_array, nsim = nsim)[[1]]
   r_samples <- do.call(r_fun, c(n = nsim, parameters))
-  
+
   # reshape to matrix or vector
   if (multivariate) {
-    
+
     # if it's a symmetric matrix, take only a triangle and flatten it
     if (name %in% c("wishart", "lkj_correlation")) {
       include_diag <- name == "wishart"
@@ -563,14 +563,14 @@ compare_iid_samples <- function(greta_fun,
   } else {
     greta_samples <- as.vector(greta_samples)
   }
-  
+
   # find a vaguely appropriate test
   if (discrete) {
     test <- ifelse(multivariate, combined_chisq_test, stats::chisq.test)
   } else {
     test <- ifelse(multivariate, cramer::cramer.test, stats::ks.test)
   }
-  
+
   # compare the two sets of samples with the chosen test
   suppressWarnings(test_result <- test(greta_samples, r_samples))
   testthat::expect_gte(test_result$p.value, p_value_threshold)
@@ -593,10 +593,10 @@ check_geweke <- function(sampler, model, data,
                          p_theta, p_x_bar_theta,
                          niter = 2000, warmup = 1000,
                          title = "Geweke test") {
-  
+
   # sample independently
   target_theta <- p_theta(niter)
-  
+
   # sample with Markov chain
   greta_theta <- p_theta_greta(
     niter = niter,
@@ -607,14 +607,14 @@ check_geweke <- function(sampler, model, data,
     sampler = sampler,
     warmup = warmup
   )
-  
+
   # visualise correspondence
   quants <- (1:99) / 100
   q1 <- stats::quantile(target_theta, quants)
   q2 <- stats::quantile(greta_theta, quants)
   plot(q2, q1, main = title)
   graphics::abline(0, 1)
-  
+
   # do a formal hypothesis test
   suppressWarnings(stat <- stats::ks.test(target_theta, greta_theta))
   testthat::expect_gte(stat$p.value, 0.005)
@@ -627,44 +627,44 @@ p_theta_greta <- function(niter, model, data,
                           p_theta, p_x_bar_theta,
                           sampler = hmc(),
                           warmup = 1000) {
-  
+
   # set up and initialize trace
   theta <- rep(NA, niter)
   theta[1] <- p_theta(1)
-  
+
   # set up and tune sampler
   draws <- mcmc(model,
-                warmup = warmup,
-                n_samples = 1,
-                chains = 1,
-                sampler = sampler,
-                verbose = FALSE
+    warmup = warmup,
+    n_samples = 1,
+    chains = 1,
+    sampler = sampler,
+    verbose = FALSE
   )
-  
+
   # now loop through, sampling and updating x and returning theta
   for (i in 2:niter) {
-    
+
     # sample x given theta
     x <- p_x_bar_theta(theta[i - 1])
-    
+
     # put x in the data list
     dag <- model$dag
     target_name <- dag$tf_name(get_node(data))
     x_array <- array(x, dim = c(1, dim(data)))
     dag$tf_environment$data_list[[target_name]] <- x_array
-    
+
     # put theta in the free state
     sampler <- attr(draws, "model_info")$samplers[[1]]
     sampler$free_state <- as.matrix(theta[i - 1])
-    
+
     draws <- extra_samples(draws,
-                           n_samples = 1,
-                           verbose = FALSE
+      n_samples = 1,
+      verbose = FALSE
     )
-    
+
     theta[i] <- tail(as.numeric(draws[[1]]), 1)
   }
-  
+
   theta
 }
 
@@ -702,24 +702,24 @@ get_enough_draws <- function(model,
                              one_by_one = FALSE) {
   start_time <- Sys.time()
   draws <- mcmc(model,
-                sampler = sampler,
-                verbose = verbose,
-                one_by_one = one_by_one
+    sampler = sampler,
+    verbose = verbose,
+    one_by_one = one_by_one
   )
-  
+
   while (not_finished(draws, n_effective) &
-         not_timed_out(start_time, time_limit)) {
+    not_timed_out(start_time, time_limit)) {
     n_samples <- new_samples(draws, n_effective)
     draws <- extra_samples(draws, n_samples,
-                           verbose = verbose,
-                           one_by_one = one_by_one
+      verbose = verbose,
+      one_by_one = one_by_one
     )
   }
-  
+
   if (not_finished(draws, n_effective)) {
     stop("could not draws enough effective samples within the time limit")
   }
-  
+
   draws
 }
 
@@ -728,22 +728,22 @@ mcse <- function(draws) {
   n <- nrow(draws)
   b <- floor(sqrt(n))
   a <- floor(n / b)
-  
+
   group <- function(k) {
     idx <- ((k - 1) * b + 1):(k * b)
     colMeans(draws[idx, , drop = FALSE])
   }
-  
+
   bm <- vapply(
     seq_len(a),
     group,
     draws[1, ]
   )
-  
+
   if (is.null(dim(bm))) {
     bm <- t(bm)
   }
-  
+
   mu_hat <- as.matrix(colMeans(draws))
   ss <- sweep(t(bm), 2, mu_hat, "-")^2
   var_hat <- b * colSums(ss) / (a - 1)
@@ -761,19 +761,19 @@ scaled_error <- function(draws, expectation) {
 # given a sampler (e.g. hmc()) and minimum number of effective samples, ensure
 # that the sampler can draw correct samples from a bivariate normal distribution
 check_mvn_samples <- function(sampler, n_effective = 3000) {
-  
+
   # get multivariate normal samples
   mu <- as_data(t(rnorm(2, 0, 5)))
   sigma <- stats::rWishart(1, 3, diag(2))[, , 1]
   x <- multivariate_normal(mu, sigma)
   m <- model(x, precision = "single")
-  
+
   draws <- get_enough_draws(m,
-                            sampler = sampler,
-                            n_effective = n_effective,
-                            verbose = FALSE
+    sampler = sampler,
+    n_effective = n_effective,
+    verbose = FALSE
   )
-  
+
   # get MCMC samples for statistics of the samples (value, variance and
   # correlation of error wrt mean)
   err <- x - mu
@@ -781,7 +781,7 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
   corr <- prod(err) / prod(sqrt(diag(sigma)))
   err_var_corr <- c(err, var, corr)
   stat_draws <- calculate(err_var_corr, values = draws)
-  
+
   # get true values of these - on average the error should be 0, and the
   # variance and correlation of the errors should encoded in Sigma
   stat_truth <- c(
@@ -789,7 +789,7 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
     diag(sigma),
     cov2cor(sigma)[1, 2]
   )
-  
+
   # get absolute errors between posterior means and true values, and scale them
   # by time-series Monte Carlo standard errors (the expected amount of
   # uncertainty in the MCMC estimate), to give the number of standard errors
@@ -811,27 +811,47 @@ check_samples <- function(x,
                           one_by_one = FALSE) {
   m <- model(x, precision = "single")
   draws <- get_enough_draws(m,
-                            sampler = sampler,
-                            n_effective = n_effective,
-                            verbose = FALSE,
-                            one_by_one = one_by_one
+    sampler = sampler,
+    n_effective = n_effective,
+    verbose = FALSE,
+    one_by_one = one_by_one
   )
-  
+
   neff <- coda::effectiveSize(draws)
   iid_samples <- iid_function(neff)
   mcmc_samples <- as.matrix(draws)
-  
+
   # plot
   if (is.null(title)) {
     distrib <- get_node(x)$distribution$distribution_name
     sampler_name <- class(sampler)[1]
     title <- paste(distrib, "with", sampler_name)
   }
-  
+
   stats::qqplot(mcmc_samples, iid_samples, main = title)
   graphics::abline(0, 1)
-  
+
   # do a formal hypothesis test
   suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
   testthat::expect_gte(stat$p.value, 0.01)
 }
+
+# zero-inflated poisson density, following the rethinking package: a point
+# mass at zero mixed with a poisson, with mixing weight theta on the zero
+dzipois <- function(x, theta, lambda, log = FALSE) {
+  ll <- ifelse(
+    x == 0,
+    theta + (1 - theta) * exp(-lambda),
+    (1 - theta) * dpois(x, lambda)
+  )
+  if (log) {
+    return(log(ll))
+  }
+  ll
+}
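+
+# sanity check with illustrative parameter values: the zero-inflated poisson
+# density should sum to one over its support
+stopifnot(abs(sum(dzipois(0:500, theta = 0.2, lambda = 2)) - 1) < 1e-8)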
+
+
+# zero-inflated negative binomial density: a point mass at zero mixed with
+# a negative binomial, with mixing weight theta on the zero
+dzinb <- function(x, theta, size, prob, log = FALSE) {
+  ll <- ifelse(
+    x == 0,
+    theta + (1 - theta) * dnbinom(0, size = size, prob = prob),
+    (1 - theta) * dnbinom(x, size = size, prob = prob)
+  )
+  if (log) {
+    return(log(ll))
+  }
+  ll
+}
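+
+# the same sanity check for the zero-inflated negative binomial
+stopifnot(abs(sum(dzinb(0:500, theta = 0.2, size = 10, prob = 0.5)) - 1) < 1e-8)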
+
diff --git a/tests/testthat/test_distributions.R b/tests/testthat/test_distributions.R
new file mode 100644
index 0000000..23b3c71
--- /dev/null
+++ b/tests/testthat/test_distributions.R
@@ -0,0 +1,1178 @@
+test_that("normal distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::normal,
+    stats::dnorm,
+    parameters = list(mean = -2, sd = 3),
+    x = rnorm(100, -2, 3)
+  )
+})
+
+
+test_that("multidimensional normal distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::normal,
+    stats::dnorm,
+    parameters = list(mean = -2, sd = 3),
+    x = array(rnorm(100, -2, 3),
+      dim = c(10, 2, 5)
+    ),
+    dim = c(10, 2, 5)
+  )
+})
+
+test_that("uniform distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::uniform,
+    stats::dunif,
+    parameters = list(min = -2.1, max = -1.2),
+    x = runif(100, -2.1, -1.2)
+  )
+})
+
+test_that("lognormal distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::lognormal,
+    stats::dlnorm,
+    parameters = list(meanlog = 1, sdlog = 3),
+    x = rlnorm(100, 1, 3)
+  )
+})
+
+test_that("bernoulli distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::bernoulli,
+    extraDistr::dbern,
+    parameters = list(prob = 0.3),
+    x = rbinom(100, 1, 0.3)
+  )
+})
+
+test_that("binomial distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::binomial,
+    stats::dbinom,
+    parameters = list(size = 10, prob = 0.8),
+    x = rbinom(100, 10, 0.8)
+  )
+})
+
+test_that("beta-binomial distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::beta_binomial,
+    extraDistr::dbbinom,
+    parameters = list(
+      size = 10,
+      alpha = 0.8,
+      beta = 1.2
+    ),
+    x = extraDistr::rbbinom(100, 10, 0.8, 1.2)
+  )
+})
+
+test_that("negative binomial distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::negative_binomial,
+    stats::dnbinom,
+    parameters = list(size = 3.3, prob = 0.2),
+    x = rnbinom(100, 3.3, 0.2)
+  )
+})
+
+test_that("zero inflated poisson distribution has correct density", {
+
+  skip_if_not(check_tf_version())
+  source("helpers.R")
+
+  compare_distribution(zero_inflated_poisson,
+                       extraDistr::dzip,
+                       parameters = list(theta = 0.2, lambda = 2, pi = 0.2),
+                       x = extraDistr::rpois(100, 2, 0.2))
+
+})
+
+test_that("zero inflated negative binomial distribution has correct density", {
+
+  skip_if_not(check_tf_version())
+  source("helpers.R")
+
+  compare_distribution(zero_inflated_negative_binomial,
+                       extraDistr::dzinb,
+                       parameters = list(theta = 2, size = 10, prob = 0.1, pi = 0.2),
+                       x = extraDistr::rzinb(100, 10, 0.1, 0.2))
+
+})
+
+
+test_that("hypergeometric distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::hypergeometric,
+    stats::dhyper,
+    parameters = list(m = 11, n = 8, k = 5),
+    x = rhyper(100, 11, 8, 5)
+  )
+})
+
+test_that("poisson distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::poisson,
+    stats::dpois,
+    parameters = list(lambda = 17.2),
+    x = rpois(100, 17.2)
+  )
+})
+
+test_that("gamma distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::gamma,
+    stats::dgamma,
+    parameters = list(shape = 1.2, rate = 2.3),
+    x = rgamma(100, 1.2, 2.3)
+  )
+})
+
+
+test_that("inverse gamma distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::inverse_gamma,
+    extraDistr::dinvgamma,
+    parameters = list(alpha = 1.2, beta = 0.9),
+    x = extraDistr::rinvgamma(100, 1.2, 0.9)
+  )
+})
+
+test_that("weibull distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::weibull,
+    dweibull,
+    parameters = list(
+      shape = 1.2,
+      scale = 0.9
+    ),
+    x = rweibull(100, 1.2, 0.9)
+  )
+})
+
+test_that("exponential distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::exponential,
+    stats::dexp,
+    parameters = list(rate = 1.9),
+    x = rexp(100, 1.9)
+  )
+})
+
+test_that("pareto distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::pareto,
+    extraDistr::dpareto,
+    parameters = list(a = 1.9, b = 2.3),
+    x = extraDistr::rpareto(100, 1.9, 2.3)
+  )
+})
+
+test_that("student distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+  dstudent <- extraDistr::dlst
+
+  compare_distribution(
+    greta::student,
+    dstudent,
+    parameters = list(
+      df = 3,
+      mu = -0.9,
+      sigma = 2
+    ),
+    x = rnorm(100, -0.9, 2)
+  )
+})
+
+test_that("laplace distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::laplace,
+    extraDistr::dlaplace,
+    parameters = list(mu = -0.9, sigma = 2),
+    x = extraDistr::rlaplace(100, -0.9, 2)
+  )
+})
+
+test_that("beta distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::beta,
+    stats::dbeta,
+    parameters = list(
+      shape1 = 2.3,
+      shape2 = 3.4
+    ),
+    x = rbeta(100, 2.3, 3.4)
+  )
+})
+
+test_that("cauchy distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::cauchy,
+    stats::dcauchy,
+    parameters = list(
+      location = -1.3,
+      scale = 3.4
+    ),
+    x = rcauchy(100, -1.3, 3.4)
+  )
+})
+
+test_that("logistic distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::logistic,
+    stats::dlogis,
+    parameters = list(
+      location = -1.3,
+      scale = 2.1
+    ),
+    x = rlogis(100, -1.3, 2.1)
+  )
+})
+
+test_that("f distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::f,
+    df,
+    parameters = list(df1 = 5.9, df2 = 2),
+    x = rf(100, 5.9, 2)
+  )
+})
+
+test_that("chi squared distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  compare_distribution(greta::chi_squared,
+    stats::dchisq,
+    parameters = list(df = 9.3),
+    x = rchisq(100, 9.3)
+  )
+})
+
+test_that("multivariate normal distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  m <- 5
+  mn <- t(rnorm(m))
+  sig <- rWishart(1, m + 1, diag(m))[, , 1]
+
+  # function converting Sigma to sigma
+  dmvnorm2 <- function(x, mean, Sigma, log = FALSE) { # nolint
+    mvtnorm::dmvnorm(x = x, mean = mean, sigma = Sigma, log = log)
+  }
+
+  compare_distribution(greta::multivariate_normal,
+    dmvnorm2,
+    parameters = list(mean = mn, Sigma = sig),
+    x = mvtnorm::rmvnorm(100, mn, sig),
+    multivariate = TRUE
+  )
+})
+
+test_that("Wishart distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  m <- 5
+  df <- m + 1
+  sig <- rWishart(1, df, diag(m))[, , 1]
+
+  # wrapper for argument names
+  dwishart <- function(x, df, Sigma, log = FALSE) { # nolint
+    ans <- MCMCpack::dwish(W = x, v = df, S = Sigma)
+    if (log) {
+      ans <- log(ans)
+    }
+    ans
+  }
+
+  # no vectorised wishart, so loop through all of these
+  replicate(
+    10,
+    compare_distribution(greta::wishart,
+      dwishart,
+      parameters = list(
+        df = df,
+        Sigma = sig
+      ),
+      x = rWishart(1, df, sig)[, , 1],
+      multivariate = TRUE
+    )
+  )
+})
+
+test_that("lkj distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  m <- 5
+  eta <- 3
+
+  # normalising component of lkj (depends only on eta and dimension)
+  lkj_log_normalising <- function(eta, n) {
+    log_pi <- log(pi)
+    ans <- 0
+    for (k in 1:(n - 1)) {
+      ans <- ans + log_pi * (k / 2)
+      ans <- ans + lgamma(eta + (n - 1 - k) / 2)
+      ans <- ans - lgamma(eta + (n - 1) / 2)
+    }
+    ans
+  }
+
+  # lkj density
+  dlkj_correlation <- function(x, eta, log = FALSE, dimension = NULL) {
+    res <- (eta - 1) * log(det(x)) - lkj_log_normalising(eta, ncol(x))
+    if (!log) {
+      res <- exp(res)
+    }
+    res
+  }
+
+  # no vectorised lkj, so loop through all of these
+  replicate(
+    10,
+    compare_distribution(greta::lkj_correlation,
+      dlkj_correlation,
+      parameters = list(eta = eta, dimension = m),
+      x = rlkjcorr(1, eta = 1, dimension = m),
+      multivariate = TRUE
+    )
+  )
+})
+
+test_that("multinomial distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  m <- 5
+  prob <- t(runif(m))
+  size <- 5
+
+  # vectorise R's density function
+  dmultinom_vec <- function(x, size, prob) {
+    apply(x, 1, stats::dmultinom, size = size, prob = prob)
+  }
+
+  compare_distribution(greta::multinomial,
+    dmultinom_vec,
+    parameters = list(
+      size = size,
+      prob = prob
+    ),
+    x = t(rmultinom(100, size, prob)),
+    multivariate = TRUE
+  )
+})
+
+test_that("categorical distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  m <- 5
+  prob <- t(runif(m))
+
+  # vectorise R's density function
+  dcategorical_vec <- function(x, prob) {
+    apply(x, 1, stats::dmultinom, size = 1, prob = prob)
+  }
+
+  compare_distribution(greta::categorical,
+    dcategorical_vec,
+    parameters = list(prob = prob),
+    x = t(rmultinom(100, 1, prob)),
+    multivariate = TRUE
+  )
+})
+
+test_that("dirichlet distribution has correct density", {
+  skip_if_not(check_tf_version())
+  # parameters to test
+  m <- 5
+  alpha <- t(runif(m))
+
+  compare_distribution(
+    greta_fun = greta::dirichlet,
+    r_fun = extraDistr::ddirichlet,
+    parameters = list(alpha = alpha),
+    x = extraDistr::rdirichlet(100, alpha),
+    multivariate = TRUE
+  )
+})
+
+test_that("dirichlet-multinomial distribution has correct density", {
+  skip_if_not(check_tf_version())
+
+
+  # parameters to test
+  size <- 10
+  m <- 5
+  alpha <- t(runif(m))
+
+  compare_distribution(greta::dirichlet_multinomial,
+    extraDistr::ddirmnom,
+    parameters = list(
+      size = size,
+      alpha = alpha
+    ),
+    x = extraDistr::rdirmnom(
+      100,
+      size,
+      alpha
+    ),
+    multivariate = TRUE
+  )
+})
+
+test_that("scalar-valued distributions can be defined in models", {
+  skip_if_not(check_tf_version())
+
+
+  x <- randn(5)
+  y <- round(randu(5))
+  p <- iprobit(normal(0, 1))
+
+  # variable (need to define a likelihood)
+  a <- variable()
+  distribution(x) <- normal(a, 1)
+  expect_ok(model(a))
+
+  # univariate discrete distributions
+  distribution(y) <- bernoulli(p)
+  expect_ok(model(p))
+
+  distribution(y) <- binomial(1, p)
+  expect_ok(model(p))
+
+  distribution(y) <- beta_binomial(1, p, 0.2)
+  expect_ok(model(p))
+
+  distribution(y) <- negative_binomial(1, p)
+  expect_ok(model(p))
+
+  distribution(y) <- hypergeometric(5, 5, p)
+  expect_ok(model(p))
+
+  distribution(y) <- poisson(p)
+  expect_ok(model(p))
+
+  # multivariate discrete distributions
+  y <- extraDistr::rmnom(1, size = 4, prob = runif(3))
+  p <- iprobit(normal(0, 1, dim = 3))
+  distribution(y) <- multinomial(4, t(p))
+  expect_ok(model(p))
+
+  y <- extraDistr::rmnom(1, size = 1, prob = runif(3))
+  p <- iprobit(normal(0, 1, dim = 3))
+  distribution(y) <- categorical(t(p))
+  expect_ok(model(p))
+
+  y <- extraDistr::rmnom(1, size = 4, prob = runif(3))
+  alpha <- lognormal(0, 1, dim = 3)
+  distribution(y) <- dirichlet_multinomial(4, t(alpha))
+  expect_ok(model(alpha))
+
+  # univariate continuous distributions
+  expect_ok(model(normal(-2, 3)))
+  expect_ok(model(student(5.6, -2, 2.3)))
+  expect_ok(model(laplace(-1.2, 1.1)))
+  expect_ok(model(cauchy(-1.2, 1.1)))
+  expect_ok(model(logistic(-1.2, 1.1)))
+
+  expect_ok(model(lognormal(1.2, 0.2)))
+  expect_ok(model(gamma(0.9, 1.3)))
+  expect_ok(model(exponential(6.3)))
+  expect_ok(model(beta(6.3, 5.9)))
+  expect_ok(model(inverse_gamma(0.9, 1.3)))
+  expect_ok(model(weibull(2, 1.1)))
+  expect_ok(model(pareto(2.4, 1.5)))
+  expect_ok(model(chi_squared(4.3)))
+  expect_ok(model(f(24.3, 2.4)))
+
+  expect_ok(model(uniform(-13, 2.4)))
+
+  # multivariate continuous distributions
+  sig <- rWishart(1, 4, diag(3))[, , 1]
+
+  expect_ok(model(multivariate_normal(t(rnorm(3)), sig)))
+  expect_ok(model(wishart(4, sig)))
+  expect_ok(model(lkj_correlation(5, dimension = 3)))
+  expect_ok(model(dirichlet(t(runif(3)))))
+})
+
+test_that("array-valued distributions can be defined in models", {
+  skip_if_not(check_tf_version())
+
+
+  dim <- c(5, 2)
+  x <- randn(5, 2)
+  y <- round(randu(5, 2))
+
+  # variable (need to define a likelihood)
+  a <- variable(dim = dim)
+  distribution(x) <- normal(a, 1)
+  expect_ok(model(a))
+
+  # univariate discrete distributions
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- bernoulli(p)
+  expect_ok(model(p))
+
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- binomial(1, p)
+  expect_ok(model(p))
+
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- beta_binomial(1, p, 0.2)
+  expect_ok(model(p))
+
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- negative_binomial(1, p)
+  expect_ok(model(p))
+
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- hypergeometric(10, 5, p)
+  expect_ok(model(p))
+
+  p <- iprobit(normal(0, 1, dim = dim))
+  distribution(y) <- poisson(p)
+  expect_ok(model(p))
+
+  # multivariate discrete distributions
+  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
+  p <- iprobit(normal(0, 1, dim = 3))
+  distribution(y) <- multinomial(4, t(p), n_realisations = 5)
+  expect_ok(model(p))
+
+  y <- extraDistr::rmnom(5, size = 1, prob = runif(3))
+  p <- iprobit(normal(0, 1, dim = 3))
+  distribution(y) <- categorical(t(p), n_realisations = 5)
+  expect_ok(model(p))
+
+  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
+  alpha <- lognormal(0, 1, dim = 3)
+  distribution(y) <- dirichlet_multinomial(4, t(alpha), n_realisations = 5)
+  expect_ok(model(alpha))
+
+  # univariate continuous distributions
+  expect_ok(model(normal(-2, 3, dim = dim)))
+  expect_ok(model(student(5.6, -2, 2.3, dim = dim)))
+  expect_ok(model(laplace(-1.2, 1.1, dim = dim)))
+  expect_ok(model(cauchy(-1.2, 1.1, dim = dim)))
+  expect_ok(model(logistic(-1.2, 1.1, dim = dim)))
+
+  expect_ok(model(lognormal(1.2, 0.2, dim = dim)))
+  expect_ok(model(gamma(0.9, 1.3, dim = dim)))
+  expect_ok(model(exponential(6.3, dim = dim)))
+  expect_ok(model(beta(6.3, 5.9, dim = dim)))
+  expect_ok(model(uniform(-13, 2.4, dim = dim)))
+  expect_ok(model(inverse_gamma(0.9, 1.3, dim = dim)))
+  expect_ok(model(weibull(2, 1.1, dim = dim)))
+  expect_ok(model(pareto(2.4, 1.5, dim = dim)))
+  expect_ok(model(chi_squared(4.3, dim = dim)))
+  expect_ok(model(f(24.3, 2.4, dim = dim)))
+
+  # multivariate continuous distributions
+  sig <- rWishart(1, 4, diag(3))[, , 1]
+  expect_ok(
+    model(multivariate_normal(t(rnorm(3)), sig, n_realisations = dim[1]))
+  )
+  expect_ok(model(dirichlet(t(runif(3)), n_realisations = dim[1])))
+  expect_ok(model(wishart(4, sig)))
+  expect_ok(model(lkj_correlation(3, dimension = dim[1])))
+})
+
+test_that("distributions can be sampled from by MCMC", {
+  skip_if_not(check_tf_version())
+
+
+  x <- randn(100)
+  y <- round(randu(100))
+
+  # variable (with a density)
+  a <- variable()
+  distribution(x) <- normal(a, 1)
+  sample_distribution(a)
+
+  b <- variable(lower = -1)
+  distribution(x) <- normal(b, 1)
+  sample_distribution(b)
+
+  c <- variable(upper = -2)
+  distribution(x) <- normal(c, 1)
+  sample_distribution(c)
+
+  d <- variable(lower = 1.2, upper = 1.3)
+  distribution(x) <- normal(d, 1)
+  sample_distribution(d)
+
+  # univariate discrete
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- bernoulli(p)
+  sample_distribution(p)
+
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- binomial(1, p)
+  sample_distribution(p)
+
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- negative_binomial(1, p)
+  sample_distribution(p)
+
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- hypergeometric(10, 5, p)
+  sample_distribution(p)
+
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- poisson(p)
+  sample_distribution(p)
+
+  p <- iprobit(normal(0, 1, dim = 100))
+  distribution(y) <- beta_binomial(1, p, 0.3)
+  sample_distribution(p)
+
+  # multivariate discrete
+  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
+  p <- uniform(0, 1, dim = 3)
+  distribution(y) <- multinomial(4, t(p), n_realisations = 5)
+  sample_distribution(p)
+
+  y <- extraDistr::rmnom(5, size = 1, prob = runif(3))
+  p <- iprobit(normal(0, 1, dim = 3))
+  distribution(y) <- categorical(t(p), n_realisations = 5)
+  sample_distribution(p)
+
+  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
+  alpha <- lognormal(0, 1, dim = 3)
+  distribution(y) <- dirichlet_multinomial(4, t(alpha), n_realisations = 5)
+  sample_distribution(alpha)
+
+  # univariate continuous
+  sample_distribution(normal(-2, 3))
+  sample_distribution(student(5.6, -2, 2.3))
+  sample_distribution(laplace(-1.2, 1.1))
+  sample_distribution(cauchy(-1.2, 1.1))
+  sample_distribution(logistic(-1.2, 1.1))
+
+  sample_distribution(lognormal(1.2, 0.2), lower = 0)
+  sample_distribution(gamma(0.9, 1.3), lower = 0)
+  sample_distribution(exponential(6.3), lower = 0)
+  sample_distribution(beta(6.3, 5.9), lower = 0, upper = 1)
+  sample_distribution(inverse_gamma(0.9, 1.3), lower = 0)
+  sample_distribution(weibull(2, 1.1), lower = 0)
+  sample_distribution(pareto(2.4, 0.1), lower = 0.1)
+  sample_distribution(chi_squared(4.3), lower = 0)
+  sample_distribution(f(24.3, 2.4), lower = 0)
+
+  sample_distribution(uniform(-13, 2.4), lower = -13, upper = 2.4)
+
+  # multivariate continuous
+  sig <- rWishart(1, 4, diag(3))[, , 1]
+  sample_distribution(multivariate_normal(t(rnorm(3)), sig))
+  sample_distribution(wishart(10L, Sigma = diag(2)), warmup = 0)
+  sample_distribution(lkj_correlation(4, dimension = 3))
+  sample_distribution(dirichlet(t(runif(3))))
+})
+
+test_that("uniform distribution errors informatively", {
+  skip_if_not(check_tf_version())
+  skip_on_ci()
+
+
+  # bad types
+  expect_snapshot_error(
+    uniform(min = 0, max = NA)
+  )
+
+  expect_snapshot_error(
+    uniform(min = 0, max = head)
+  )
+
+  expect_snapshot_error(
+    uniform(min = 1:3, max = 5)
+  )
+
+  # good types, bad values
+  expect_snapshot_error(
+    uniform(min = -Inf, max = Inf)
+  )
+
+  # lower not below upper
+  expect_snapshot_error(
+    uniform(min = 1, max = 1)
+  )
+
+})
+
+test_that("poisson() and binomial() error informatively in glm", {
+  skip_on_ci()
+  skip_if_not(check_tf_version())
+
+  # if passed as an object
+  expect_snapshot_error(
+    glm(1 ~ 1, family = poisson)
+  )
+
+  expect_snapshot_error(
+    glm(1 ~ 1, family = binomial)
+  )
+
+  # if executed alone
+  expect_snapshot_error(
+    glm(1 ~ 1, family = poisson())
+  )
+
+  # if given a link
+  expect_snapshot_error(
+    glm(1 ~ 1, family = poisson("sqrt"))
+  )
+})
+
+test_that("wishart distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  a <- randn(3, 3)
+  b <- randn(3, 3, 3)
+  c <- randn(3, 2)
+
+  expect_true(inherits(
+    wishart(3, a),
+    "greta_array"
+  ))
+
+  expect_snapshot_error(
+    wishart(3, b)
+  )
+
+  expect_snapshot_error(
+    wishart(3, c)
+  )
+
+})
+
+
+test_that("lkj_correlation distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  dim <- 3
+
+  expect_true(inherits(
+    lkj_correlation(3, dim),
+    "greta_array"
+  ))
+
+  expect_snapshot_error(
+    lkj_correlation(-1, dim)
+  )
+
+  expect_snapshot_error(
+    lkj_correlation(c(3, 3), dim)
+  )
+
+  expect_snapshot_error(
+    lkj_correlation(uniform(0, 1, dim = 2), dim)
+  )
+
+  expect_snapshot_error(
+    lkj_correlation(4, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    lkj_correlation(4, dim = c(3, 3))
+  )
+
+  expect_snapshot_error(
+    lkj_correlation(4, dim = NA)
+  )
+})
+
+test_that("multivariate_normal distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  m_a <- randn(1, 3)
+  m_b <- randn(2, 3)
+  m_c <- randn(3)
+  m_d <- randn(3, 1)
+
+  a <- randn(3, 3)
+  b <- randn(3, 3, 3)
+  c <- randn(3, 2)
+  d <- randn(4, 4)
+
+  # good means
+  expect_true(inherits(
+    multivariate_normal(m_a, a),
+    "greta_array"
+  ))
+
+  expect_true(inherits(
+    multivariate_normal(m_b, a),
+    "greta_array"
+  ))
+
+  # bad means
+  expect_snapshot_error(
+    multivariate_normal(m_c, a)
+  )
+
+  expect_snapshot_error(
+    multivariate_normal(m_d, a)
+  )
+
+  # good sigmas
+  expect_true(inherits(
+    multivariate_normal(m_a, a),
+    "greta_array"
+  ))
+
+  # bad sigmas
+  expect_snapshot_error(
+    multivariate_normal(m_a, b)
+  )
+
+  expect_snapshot_error(
+    multivariate_normal(m_a, c)
+  )
+
+  # mismatched parameters
+  expect_snapshot_error(
+    multivariate_normal(m_a, d)
+  )
+
+  # scalars
+  expect_snapshot_error(
+    multivariate_normal(0, 1)
+  )
+
+  # bad n_realisations
+  expect_snapshot_error(
+    multivariate_normal(m_a, a, n_realisations = -1)
+  )
+
+  expect_snapshot_error(
+    multivariate_normal(m_a, a, n_realisations = c(1, 3))
+  )
+
+  # bad dimension
+  expect_snapshot_error(
+    multivariate_normal(m_a, a, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    multivariate_normal(m_a, a, dimension = c(1, 3))
+  )
+})
+
+test_that("multinomial distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  p_a <- randu(1, 3)
+  p_b <- randu(2, 3)
+
+  # same size & probs
+  expect_true(inherits(
+    multinomial(size = 10, p_a),
+    "greta_array"
+  ))
+
+  expect_true(inherits(
+    multinomial(size = 1:2, p_b),
+    "greta_array"
+  ))
+
+  # n_realisations from prob
+  expect_true(inherits(
+    multinomial(10, p_b),
+    "greta_array"
+  ))
+
+  # n_realisations from size
+  expect_true(inherits(
+    multinomial(c(1, 2), p_a),
+    "greta_array"
+  ))
+
+  # scalars
+  expect_snapshot_error(
+    multinomial(c(1), 1)
+  )
+
+  # bad n_realisations
+  expect_snapshot_error(
+    multinomial(10, p_a, n_realisations = -1)
+  )
+
+  expect_snapshot_error(
+    multinomial(10, p_a, n_realisations = c(1, 3))
+  )
+
+  # bad dimension
+  expect_snapshot_error(
+    multinomial(10, p_a, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    multinomial(10, p_a, dimension = c(1, 3))
+  )
+})
+
+test_that("categorical distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  p_a <- randu(1, 3)
+  p_b <- randu(2, 3)
+
+  # good probs
+  expect_true(inherits(
+    categorical(p_a),
+    "greta_array"
+  ))
+
+  expect_true(inherits(
+    categorical(p_b),
+    "greta_array"
+  ))
+
+  # scalars
+  expect_snapshot_error(
+    categorical(1)
+  )
+
+  # bad n_realisations
+  expect_snapshot_error(
+    categorical(p_a, n_realisations = -1)
+  )
+
+  expect_snapshot_error(
+    categorical(p_a, n_realisations = c(1, 3))
+  )
+
+  # bad dimension
+  expect_snapshot_error(
+    categorical(p_a, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    categorical(p_a, dimension = c(1, 3))
+  )
+})
+
+test_that("dirichlet distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  alpha_a <- randu(1, 3)
+  alpha_b <- randu(2, 3)
+
+  # good alpha
+  expect_true(inherits(
+    dirichlet(alpha_a),
+    "greta_array"
+  ))
+
+
+  expect_true(inherits(
+    dirichlet(alpha_b),
+    "greta_array"
+  ))
+
+  # scalars
+  expect_snapshot_error(
+    dirichlet(1)
+  )
+
+  # bad n_realisations
+  expect_snapshot_error(
+    dirichlet(alpha_a, n_realisations = -1)
+  )
+
+  expect_snapshot_error(
+    dirichlet(alpha_a, n_realisations = c(1, 3))
+  )
+
+  # bad dimension
+  expect_snapshot_error(
+    dirichlet(alpha_a, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    dirichlet(alpha_a, dimension = c(1, 3))
+  )
+})
+
+
+test_that("dirichlet values sum to one", {
+  skip_if_not(check_tf_version())
+
+
+  alpha <- uniform(0, 10, dim = c(1, 5))
+  x <- dirichlet(alpha)
+  m <- model(x)
+  draws <- mcmc(m, n_samples = 100, warmup = 100, verbose = FALSE)
+
+  sums <- rowSums(as.matrix(draws))
+  compare_op(sums, 1)
+})
+
+test_that("dirichlet-multinomial distribution errors informatively", {
+  skip_if_not(check_tf_version())
+
+
+  alpha_a <- randu(1, 3)
+  alpha_b <- randu(2, 3)
+
+
+  # same size & probs
+  expect_true(inherits(
+    dirichlet_multinomial(size = 10, alpha_a),
+    "greta_array"
+  ))
+
+  expect_true(inherits(
+    dirichlet_multinomial(size = 1:2, alpha_b),
+    "greta_array"
+  ))
+
+  # n_realisations from alpha
+  expect_true(inherits(
+    dirichlet_multinomial(10, alpha_b),
+    "greta_array"
+  ))
+
+  # n_realisations from size
+  expect_true(inherits(
+    dirichlet_multinomial(c(1, 2), alpha_a),
+    "greta_array"
+  ))
+
+  # scalars
+  expect_snapshot_error(
+    dirichlet_multinomial(c(1), 1)
+  )
+
+  # bad n_realisations
+  expect_snapshot_error(
+    dirichlet_multinomial(10, alpha_a, n_realisations = -1)
+  )
+
+  expect_snapshot_error(
+    dirichlet_multinomial(10, alpha_a, n_realisations = c(1, 3))
+  )
+
+  # bad dimension
+  expect_snapshot_error(
+    dirichlet_multinomial(10, alpha_a, dimension = -1)
+  )
+
+  expect_snapshot_error(
+    dirichlet_multinomial(10, alpha_a, dimension = c(1, 3))
+  )
+})
+
+test_that("Wishart can use a choleskied Sigma", {
+  skip_if_not(check_tf_version())
+
+  sig <- lkj_correlation(2, dimension = 2)
+  w <- wishart(5, sig)
+  m <- model(w, precision = "double")
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+})
+
+test_that("multivariate distribs with matrix params can be sampled from", {
+  skip_if_not(check_tf_version())
+
+  n <- 10
+  k <- 3
+
+  # multivariate normal
+  x <- randn(n, k)
+  mu <- normal(0, 1, dim = c(n, k))
+  distribution(x) <- multivariate_normal(mu, diag(k))
+  m <- model(mu)
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+
+  # multinomial
+  size <- 5
+  x <- t(rmultinom(n, size, runif(k)))
+  p <- uniform(0, 1, dim = c(n, k))
+  distribution(x) <- multinomial(size, p)
+  m <- model(p)
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+
+  # categorical
+  x <- t(rmultinom(n, 1, runif(k)))
+  p <- uniform(0, 1, dim = c(n, k))
+  distribution(x) <- categorical(p)
+  m <- model(p)
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+
+  # dirichlet
+  x <- randu(n, k)
+  x <- sweep(x, 1, rowSums(x), "/")
+  a <- normal(0, 1, dim = c(n, k))
+  distribution(x) <- dirichlet(a)
+  m <- model(a)
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+
+  # dirichlet multinomial
+  size <- 5
+  x <- t(rmultinom(n, size, runif(k)))
+  a <- normal(0, 1, dim = c(n, k))
+  distribution(x) <- dirichlet_multinomial(size, a)
+  m <- model(a)
+  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
+})

From 365fd84180dc43fb2a8af0502222b759145350e1 Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Fri, 29 Jul 2022 16:32:00 +0800
Subject: [PATCH 02/19] removing re-exported parts of greta

---
 R/package.R                         |    1 +
 R/probability_distributions.R       | 1743 ---------------------------
 R/zero_inflateds.R                  |  104 ++
 tests/testthat/helpers.R            |   17 +
 tests/testthat/test_distributions.R | 1178 ------------------
 tests/testthat/test_zip_zinb.R      |   23 +
 6 files changed, 145 insertions(+), 2921 deletions(-)
 delete mode 100644 R/probability_distributions.R
 create mode 100644 R/zero_inflateds.R
 delete mode 100644 tests/testthat/test_distributions.R
 create mode 100644 tests/testthat/test_zip_zinb.R

diff --git a/R/package.R b/R/package.R
index fdf1251..eb50f27 100644
--- a/R/package.R
+++ b/R/package.R
@@ -7,6 +7,7 @@
 #' 
 #' @importFrom tensorflow tf
 #' @importFrom greta .internals
+#' @importFrom R6 R6Class
 #' 
 #' @examples
 #' 
diff --git a/R/probability_distributions.R b/R/probability_distributions.R
deleted file mode 100644
index 333530f..0000000
--- a/R/probability_distributions.R
+++ /dev/null
@@ -1,1743 +0,0 @@
-uniform_distribution <- R6Class(
-  "uniform_distribution",
-  inherit = distribution_node,
-  public = list(
-    min = NA,
-    max = NA,
-    initialize = function(min, max, dim) {
-      if (inherits(min, "greta_array") | inherits(max, "greta_array")) {
-        msg <- cli::format_error(
-          "{.arg min} and {.arg max} must be fixed, they cannot be another \\
-          greta array"
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      good_types <- is.numeric(min) && length(min) == 1 &
-        is.numeric(max) && length(max) == 1
-
-      if (!good_types) {
-        msg <- cli::format_error(
-          c(
-            "{.arg min} and {.arg max} must be numeric vectors of length 1",
-            "They have class and length:",
-            "{.arg min}: {class(min)}, {length(min)}",
-            "{.arg max}: {class(max)}, {length(max)}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      if (!is.finite(min) | !is.finite(max)) {
-        msg <- cli::format_error(
-          c(
-            "{.arg min} and {.arg max} must finite scalars",
-            "Their values are:",
-            "{.arg min}: {min}",
-            "{.arg max}: {max}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      if (min >= max) {
-        msg <- cli::format_error(
-          c(
-            "{.arg max} must be greater than {.arg min}",
-            "Their values are:",
-            "{.arg min}: {min}",
-            "{.arg max}: {max}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      # store min and max as numeric scalars (needed in create_target, done in
-      # initialisation)
-      self$min <- min
-      self$max <- max
-      self$bounds <- c(min, max)
-
-      # initialize the rest
-      super$initialize("uniform", dim)
-
-      # add them as parents and greta arrays
-      min <- as.greta_array(min)
-      max <- as.greta_array(max)
-      self$add_parameter(min, "min")
-      self$add_parameter(max, "max")
-    },
-
-    # default value (ignore any truncation arguments)
-    create_target = function(...) {
-      vble(
-        truncation = c(self$min, self$max),
-        dim = self$dim
-      )
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Uniform(
-        low = parameters$min,
-        high = parameters$max
-      )
-    }
-  )
-)
-
-normal_distribution <- R6Class(
-  "normal_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(mean, sd, dim, truncation) {
-      mean <- as.greta_array(mean)
-      sd <- as.greta_array(sd)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(mean, sd, target_dim = dim)
-      super$initialize("normal", dim, truncation)
-      self$add_parameter(mean, "mean")
-      self$add_parameter(sd, "sd")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Normal(
-        loc = parameters$mean,
-        scale = parameters$sd
-      )
-    }
-  )
-)
-
-lognormal_distribution <- R6Class(
-  "lognormal_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(meanlog, sdlog, dim, truncation) {
-      meanlog <- as.greta_array(meanlog)
-      sdlog <- as.greta_array(sdlog)
-
-      dim <- check_dims(meanlog, sdlog, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("lognormal", dim, truncation)
-      self$add_parameter(meanlog, "meanlog")
-      self$add_parameter(sdlog, "sdlog")
-    },
-
-    # nolint start
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$LogNormal(
-        loc = parameters$meanlog,
-        scale = parameters$sdlog
-      )
-    }
-    # nolint end
-  )
-)
-
-bernoulli_distribution <- R6Class(
-  "bernoulli_distribution",
-  inherit = distribution_node,
-  public = list(
-    prob_is_logit = FALSE,
-    prob_is_probit = FALSE,
-    initialize = function(prob, dim) {
-      prob <- as.greta_array(prob)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(prob, target_dim = dim)
-      super$initialize("bernoulli", dim, discrete = TRUE)
-
-      if (has_representation(prob, "logit")) {
-        prob <- representation(prob, "logit")
-        self$prob_is_logit <- TRUE
-      } else if (has_representation(prob, "probit")) {
-        prob <- representation(prob, "probit")
-        self$prob_is_probit <- TRUE
-      }
-
-      self$add_parameter(prob, "prob")
-    },
-    tf_distrib = function(parameters, dag) {
-      if (self$prob_is_logit) {
-        tfp$distributions$Bernoulli(logits = parameters$prob)
-      } else if (self$prob_is_probit) {
-
-        # in the probit case, get the log probability of success and compute the
-        # log prob directly
-        probit <- parameters$prob
-        d <- tfp$distributions$Normal(fl(0), fl(1))
-        lprob <- d$log_cdf(probit)
-        lprobnot <- d$log_cdf(-probit)
-
-        log_prob <- function(x) {
-          x * lprob + (fl(1) - x) * lprobnot
-        }
-
-        list(log_prob = log_prob)
-      } else {
-        tfp$distributions$Bernoulli(probs = parameters$prob)
-      }
-    }
-  )
-)
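-
-# illustrative sketch (editor's addition, not package code): with eta the
-# probit-scale parameter, the probit branch above is equivalent to this
-# plain-R log-mass (dbern_probit_sketch is a hypothetical name):
-dbern_probit_sketch <- function(x, eta) {
-  # x * log(pnorm(eta)) + (1 - x) * log(pnorm(-eta))
-  x * pnorm(eta, log.p = TRUE) + (1 - x) * pnorm(-eta, log.p = TRUE)
-}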
-
-binomial_distribution <- R6Class(
-  "binomial_distribution",
-  inherit = distribution_node,
-  public = list(
-    prob_is_logit = FALSE,
-    prob_is_probit = FALSE,
-    initialize = function(size, prob, dim) {
-      size <- as.greta_array(size)
-      prob <- as.greta_array(prob)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(size, prob, target_dim = dim)
-      super$initialize("binomial", dim, discrete = TRUE)
-
-      if (has_representation(prob, "logit")) {
-        prob <- representation(prob, "logit")
-        self$prob_is_logit <- TRUE
-      } else if (has_representation(prob, "probit")) {
-        prob <- representation(prob, "probit")
-        self$prob_is_probit <- TRUE
-      }
-
-      self$add_parameter(prob, "prob")
-      self$add_parameter(size, "size")
-    },
-    tf_distrib = function(parameters, dag) {
-      if (self$prob_is_logit) {
-        tfp$distributions$Binomial(
-          total_count = parameters$size,
-          logits = parameters$prob
-        )
-      } else if (self$prob_is_probit) {
-
-        # in the probit case, get the log probability of success and compute the
-        # log prob directly
-        size <- parameters$size
-        probit <- parameters$prob
-        d <- tfp$distributions$Normal(fl(0), fl(1))
-        lprob <- d$log_cdf(probit)
-        lprobnot <- d$log_cdf(-probit)
-
-        log_prob <- function(x) {
-          log_choose <- tf$math$lgamma(size + fl(1)) -
-            tf$math$lgamma(x + fl(1)) -
-            tf$math$lgamma(size - x + fl(1))
-          log_choose + x * lprob + (size - x) * lprobnot
-        }
-
-        list(log_prob = log_prob)
-      } else {
-        tfp$distributions$Binomial(
-          total_count = parameters$size,
-          probs = parameters$prob
-        )
-      }
-    }
-  )
-)
-
-beta_binomial_distribution <- R6Class(
-  "beta_binomial_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(size, alpha, beta, dim) {
-      size <- as.greta_array(size)
-      alpha <- as.greta_array(alpha)
-      beta <- as.greta_array(beta)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(size, alpha, beta, target_dim = dim)
-      super$initialize("beta_binomial", dim, discrete = TRUE)
-      self$add_parameter(size, "size")
-      self$add_parameter(alpha, "alpha")
-      self$add_parameter(beta, "beta")
-    },
-    tf_distrib = function(parameters, dag) {
-      size <- parameters$size
-      alpha <- parameters$alpha
-      beta <- parameters$beta
-
-      log_prob <- function(x) {
-        tf_lchoose(size, x) +
-          tf_lbeta(x + alpha, size - x + beta) -
-          tf_lbeta(alpha, beta)
-      }
-
-      # generate a beta, then a binomial
-      sample <- function(seed) {
-        beta <- tfp$distributions$Beta(
-          concentration1 = alpha,
-          concentration0 = beta
-        )
-        probs <- beta$sample(seed = seed)
-        binomial <- tfp$distributions$Binomial(
-          total_count = size,
-          probs = probs
-        )
-        binomial$sample(seed = seed)
-      }
-
-      list(log_prob = log_prob, sample = sample)
-    }
-  )
-)
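-
-# illustrative sketch (editor's addition, not package code): the log_prob
-# above matches the beta-binomial mass (cf. extraDistr::dbbinom); in plain R:
-dbbinom_sketch <- function(x, size, alpha, beta) {
-  exp(lchoose(size, x) + lbeta(x + alpha, size - x + beta) - lbeta(alpha, beta))
-}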
-
-poisson_distribution <- R6Class(
-  "poisson_distribution",
-  inherit = distribution_node,
-  public = list(
-    lambda_is_log = FALSE,
-    initialize = function(lambda, dim) {
-      lambda <- as.greta_array(lambda)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(lambda, target_dim = dim)
-      super$initialize("poisson", dim, discrete = TRUE)
-
-      if (has_representation(lambda, "log")) {
-        lambda <- representation(lambda, "log")
-        self$lambda_is_log <- TRUE
-      }
-      self$add_parameter(lambda, "lambda")
-    },
-    tf_distrib = function(parameters, dag) {
-      if (self$lambda_is_log) {
-        log_lambda <- parameters$lambda
-      } else {
-        log_lambda <- tf$math$log(parameters$lambda)
-      }
-
-      tfp$distributions$Poisson(log_rate = log_lambda)
-    }
-  )
-)
-
-negative_binomial_distribution <- R6Class(
-  "negative_binomial_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(size, prob, dim) {
-      size <- as.greta_array(size)
-      prob <- as.greta_array(prob)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(size, prob, target_dim = dim)
-      super$initialize("negative_binomial", dim, discrete = TRUE)
-      self$add_parameter(size, "size")
-      self$add_parameter(prob, "prob")
-    },
-
-    # nolint start
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$NegativeBinomial(
-        total_count = parameters$size,
-        probs = fl(1) - parameters$prob
-      )
-    }
-    # nolint end
-  )
-)
-
-zero_inflated_poisson_distribution <- R6Class(
-  "zero_inflated_poisson_distribution",
-  inherit = greta::.internals$nodes$node_classes$distribution_node,
-  public = list(
-    initialize = function(theta, lambda, dim) {
-      theta <- as.greta_array(theta)
-      lambda <- as.greta_array(lambda)
-      # add the nodes as parents and parameters
-      dim <- check_dims(theta, lambda, target_dim = dim)
-      super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
-      self$add_parameter(lambda, "lambda")
-    },
-
-    tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
-      lambda <- parameters$lambda
-      log_prob <- function(x) {
-
-        # for non-negative integer x, relu(1 - x) is the indicator x == 0, so
-        # this is theta * I(x == 0) + (1 - theta) * dpois(x, lambda)
-        pois_prob <- tf$pow(lambda, x) * tf$exp(-lambda) /
-          tf$exp(tf$math$lgamma(x + fl(1)))
-        tf$math$log(
-          theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * pois_prob
-        )
-      }
-
-      sample <- function(seed) {
-
-        # draw an indicator of structural zeros, then use it to mask a
-        # poisson draw
-        binom <- tfp$distributions$Binomial(total_count = fl(1), probs = theta)
-        pois <- tfp$distributions$Poisson(rate = lambda)
-
-        zi <- binom$sample(seed = seed)
-        counts <- pois$sample(seed = seed)
-
-        (fl(1) - zi) * counts
-      }
-
-      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
-    },
-
-    tf_cdf_function = NULL,
-    tf_log_cdf_function = NULL
-  )
-)
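-
-# illustrative sketch (editor's addition, not package code): the class above
-# implements the standard zero-inflated poisson mass, as in extraDistr::dzip
-# with pi = theta:
-dzip_sketch <- function(x, theta, lambda) {
-  theta * (x == 0) + (1 - theta) * dpois(x, lambda)
-}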
-
-
-zero_inflated_negative_binomial_distribution <- R6Class(
-  "zero_inflated_negative_binomial_distribution",
-  inherit = greta::.internals$nodes$node_classes$distribution_node,
-  public = list(
-    initialize = function(theta, size, prob, dim) {
-      theta <- as.greta_array(theta)
-      size <- as.greta_array(size)
-      prob <- as.greta_array(prob)
-      # add the nodes as parents and parameters
-      dim <- check_dims(theta, size, prob, target_dim = dim)
-      super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
-      self$add_parameter(size, "size")
-      self$add_parameter(prob, "prob")
-    },
-
-    tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
-      size <- parameters$size
-      p <- parameters$prob # probability of success
-      q <- fl(1) - parameters$prob # probability of failure
-      log_prob <- function(x) {
-
-        # negative binomial mass with the binomial coefficient written via
-        # lgamma: choose(x + size - 1, x) * p^size * q^x
-        nb_prob <- tf$pow(p, size) * tf$pow(q, x) *
-          tf$exp(
-            tf$math$lgamma(x + size) -
-              tf$math$lgamma(size) -
-              tf$math$lgamma(x + fl(1))
-          )
-        tf$math$log(
-          theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * nb_prob
-        )
-      }
-
-      sample <- function(seed) {
-
-        # draw an indicator of structural zeros, then use it to mask a
-        # negative binomial draw
-        binom <- tfp$distributions$Binomial(total_count = fl(1), probs = theta)
-        # probs here is the complement of the 'stats' success probability,
-        # matching TFP's parameterisation of the negative binomial
-        negbin <- tfp$distributions$NegativeBinomial(
-          total_count = size,
-          probs = q
-        )
-
-        zi <- binom$sample(seed = seed)
-        counts <- negbin$sample(seed = seed)
-
-        (fl(1) - zi) * counts
-      }
-
-      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
-    },
-
-    tf_cdf_function = NULL,
-    tf_log_cdf_function = NULL
-  )
-)
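-
-# illustrative sketch (editor's addition, not package code): the class above
-# implements the zero-inflated negative binomial mass, as in extraDistr::dzinb
-# with pi = theta and prob the stats::dnbinom success probability:
-dzinb_sketch <- function(x, theta, size, prob) {
-  theta * (x == 0) + (1 - theta) * dnbinom(x, size = size, prob = prob)
-}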
-
-
-
-hypergeometric_distribution <- R6Class(
-  "hypergeometric_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(m, n, k, dim) {
-      m <- as.greta_array(m)
-      n <- as.greta_array(n)
-      k <- as.greta_array(k)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(m, n, k, target_dim = dim)
-      super$initialize("hypergeometric", dim, discrete = TRUE)
-      self$add_parameter(m, "m")
-      self$add_parameter(n, "n")
-      self$add_parameter(k, "k")
-    },
-    tf_distrib = function(parameters, dag) {
-      m <- parameters$m
-      n <- parameters$n
-      k <- parameters$k
-
-      log_prob <- function(x) {
-        tf_lchoose(m, x) +
-          tf_lchoose(n, k - x) -
-          tf_lchoose(m + n, k)
-      }
-
-      list(log_prob = log_prob)
-    }
-  )
-)
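-
-# illustrative sketch (editor's addition, not package code): the log_prob
-# above is stats::dhyper on the log scale:
-dhyper_sketch <- function(x, m, n, k) {
-  exp(lchoose(m, x) + lchoose(n, k - x) - lchoose(m + n, k))
-}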
-
-gamma_distribution <- R6Class(
-  "gamma_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(shape, rate, dim, truncation) {
-      shape <- as.greta_array(shape)
-      rate <- as.greta_array(rate)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(shape, rate, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("gamma", dim, truncation)
-      self$add_parameter(shape, "shape")
-      self$add_parameter(rate, "rate")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Gamma(
-        concentration = parameters$shape,
-        rate = parameters$rate
-      )
-    }
-  )
-)
-
-inverse_gamma_distribution <- R6Class(
-  "inverse_gamma_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(alpha, beta, dim, truncation) {
-      alpha <- as.greta_array(alpha)
-      beta <- as.greta_array(beta)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(alpha, beta, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("inverse_gamma", dim, truncation)
-      self$add_parameter(alpha, "alpha")
-      self$add_parameter(beta, "beta")
-    },
-
-    # nolint start
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$InverseGamma(
-        concentration = parameters$alpha,
-        rate = parameters$beta
-      )
-    }
-    # nolint end
-  )
-)
-
-weibull_distribution <- R6Class(
-  "weibull_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(shape, scale, dim, truncation) {
-      shape <- as.greta_array(shape)
-      scale <- as.greta_array(scale)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(shape, scale, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("weibull", dim, truncation)
-      self$add_parameter(shape, "shape")
-      self$add_parameter(scale, "scale")
-    },
-    tf_distrib = function(parameters, dag) {
-      a <- parameters$shape
-      b <- parameters$scale
-
-      # use the TFP Weibull CDF bijector
-      bijector <- tfp$bijectors$Weibull(scale = b, concentration = a)
-
-      log_prob <- function(x) {
-        log(a) - log(b) + (a - fl(1)) * (log(x) - log(b)) - (x / b)^a
-      }
-
-      cdf <- function(x) {
-        bijector$forward(x)
-      }
-
-      log_cdf <- function(x) {
-        log(cdf(x))
-      }
-
-      quantile <- function(x) {
-        bijector$inverse(x)
-      }
-
-      sample <- function(seed) {
-
-        # sample by pushing standard uniforms through the inverse cdf
-        u <- tf_randu(self$dim, dag)
-        quantile(u)
-      }
-
-      list(
-        log_prob = log_prob,
-        cdf = cdf,
-        log_cdf = log_cdf,
-        quantile = quantile,
-        sample = sample
-      )
-    }
-  )
-)
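-
-# illustrative sketch (editor's addition, not package code): the sampler above
-# is inverse-CDF sampling; for the weibull the quantile function has this
-# closed form:
-rweibull_invcdf_sketch <- function(n, shape, scale) {
-  u <- runif(n)
-  scale * (-log(1 - u))^(1 / shape)
-}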
-
-exponential_distribution <- R6Class(
-  "exponential_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(rate, dim, truncation) {
-      rate <- as.greta_array(rate)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(rate, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("exponential", dim, truncation)
-      self$add_parameter(rate, "rate")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Exponential(rate = parameters$rate)
-    }
-  )
-)
-
-pareto_distribution <- R6Class(
-  "pareto_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(a, b, dim, truncation) {
-      a <- as.greta_array(a)
-      b <- as.greta_array(b)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(a, b, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("pareto", dim, truncation)
-      self$add_parameter(a, "a")
-      self$add_parameter(b, "b")
-    },
-    tf_distrib = function(parameters, dag) {
-
-      # a is shape, b is scale
-      tfp$distributions$Pareto(
-        concentration = parameters$a,
-        scale = parameters$b
-      )
-    }
-  )
-)
-
-student_distribution <- R6Class(
-  "student_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(df, mu, sigma, dim, truncation) {
-      df <- as.greta_array(df)
-      mu <- as.greta_array(mu)
-      sigma <- as.greta_array(sigma)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(df, mu, sigma, target_dim = dim)
-      super$initialize("student", dim, truncation)
-      self$add_parameter(df, "df")
-      self$add_parameter(mu, "mu")
-      self$add_parameter(sigma, "sigma")
-    },
-
-    # nolint start
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$StudentT(
-        df = parameters$df,
-        loc = parameters$mu,
-        scale = parameters$sigma
-      )
-    }
-    # nolint end
-  )
-)
-
-laplace_distribution <- R6Class(
-  "laplace_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(mu, sigma, dim, truncation) {
-      mu <- as.greta_array(mu)
-      sigma <- as.greta_array(sigma)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(mu, sigma, target_dim = dim)
-      super$initialize("laplace", dim, truncation)
-      self$add_parameter(mu, "mu")
-      self$add_parameter(sigma, "sigma")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Laplace(
-        loc = parameters$mu,
-        scale = parameters$sigma
-      )
-    }
-  )
-)
-
-beta_distribution <- R6Class(
-  "beta_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(shape1, shape2, dim, truncation) {
-      shape1 <- as.greta_array(shape1)
-      shape2 <- as.greta_array(shape2)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(shape1, shape2, target_dim = dim)
-      check_unit(truncation)
-      self$bounds <- c(0, 1)
-      super$initialize("beta", dim, truncation)
-      self$add_parameter(shape1, "shape1")
-      self$add_parameter(shape2, "shape2")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Beta(
-        concentration1 = parameters$shape1,
-        concentration0 = parameters$shape2
-      )
-    }
-  )
-)
-
-cauchy_distribution <- R6Class(
-  "cauchy_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(location, scale, dim, truncation) {
-      location <- as.greta_array(location)
-      scale <- as.greta_array(scale)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(location, scale, target_dim = dim)
-      super$initialize("cauchy", dim, truncation)
-      self$add_parameter(location, "location")
-      self$add_parameter(scale, "scale")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Cauchy(
-        loc = parameters$location,
-        scale = parameters$scale
-      )
-    }
-  )
-)
-
-chi_squared_distribution <- R6Class(
-  "chi_squared_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(df, dim, truncation) {
-      df <- as.greta_array(df)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(df, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("chi_squared", dim, truncation)
-      self$add_parameter(df, "df")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Chi2(df = parameters$df)
-    }
-  )
-)
-
-logistic_distribution <- R6Class(
-  "logistic_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(location, scale, dim, truncation) {
-      location <- as.greta_array(location)
-      scale <- as.greta_array(scale)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(location, scale, target_dim = dim)
-      super$initialize("logistic", dim, truncation)
-      self$add_parameter(location, "location")
-      self$add_parameter(scale, "scale")
-    },
-    tf_distrib = function(parameters, dag) {
-      tfp$distributions$Logistic(
-        loc = parameters$location,
-        scale = parameters$scale
-      )
-    }
-  )
-)
-
-f_distribution <- R6Class(
-  "f_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(df1, df2, dim, truncation) {
-      df1 <- as.greta_array(df1)
-      df2 <- as.greta_array(df2)
-
-      # add the nodes as parents and parameters
-      dim <- check_dims(df1, df2, target_dim = dim)
-      check_positive(truncation)
-      self$bounds <- c(0, Inf)
-      super$initialize("f", dim, truncation)
-      self$add_parameter(df1, "df1")
-      self$add_parameter(df2, "df2")
-    },
-    tf_distrib = function(parameters, dag) {
-      df1 <- parameters$df1
-      df2 <- parameters$df2
-
-      tf_lbeta <- function(a, b) {
-        tf$math$lgamma(a) + tf$math$lgamma(b) - tf$math$lgamma(a + b)
-      }
-
-      log_prob <- function(x) {
-        df1_x <- df1 * x
-        la <- df1 * log(df1_x) + df2 * log(df2)
-        lb <- (df1 + df2) * log(df1_x + df2)
-        lnumerator <- fl(0.5) * (la - lb)
-        lnumerator - log(x) - tf_lbeta(df1 / fl(2), df2 / fl(2))
-      }
-
-      cdf <- function(x) {
-        df1_x <- df1 * x
-        ratio <- df1_x / (df1_x + df2)
-        tf$math$betainc(df1 / fl(2), df2 / fl(2), ratio)
-      }
-
-      log_cdf <- function(x) {
-        log(cdf(x))
-      }
-
-      sample <- function(seed) {
-
-        # sample as the ratio of two scaled chi squared distributions
-        d1 <- tfp$distributions$Chi2(df = df1)
-        d2 <- tfp$distributions$Chi2(df = df2)
-
-        u1 <- d1$sample(seed = seed)
-        u2 <- d2$sample(seed = seed)
-
-        (u1 / df1) / (u2 / df2)
-      }
-
-      list(
-        log_prob = log_prob,
-        cdf = cdf,
-        log_cdf = log_cdf,
-        sample = sample
-      )
-    }
-  )
-)
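-
-# illustrative sketch (editor's addition, not package code): the sampler above
-# uses the classical representation of an F variate as a ratio of scaled
-# chi-squared draws:
-rf_sketch <- function(n, df1, df2) {
-  (rchisq(n, df1) / df1) / (rchisq(n, df2) / df2)
-}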
-
-dirichlet_distribution <- R6Class(
-  "dirichlet_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(alpha, n_realisations, dimension) {
-      # coerce to greta arrays
-      alpha <- as.greta_array(alpha)
-
-      dim <- check_multivariate_dims(
-        vectors = list(alpha),
-        n_realisations = n_realisations,
-        dimension = dimension
-      )
-
-      # coerce the parameter arguments to nodes and add as parents and
-      # parameters
-      self$bounds <- c(0, Inf)
-      super$initialize("dirichlet", dim,
-        truncation = c(0, Inf),
-        multivariate = TRUE
-      )
-      self$add_parameter(alpha, "alpha")
-    },
-    create_target = function(truncation) {
-      simplex_greta_array <- simplex_variable(self$dim)
-
-      # return the node for the simplex
-      target_node <- get_node(simplex_greta_array)
-      target_node
-    },
-    tf_distrib = function(parameters, dag) {
-      alpha <- parameters$alpha
-      tfp$distributions$Dirichlet(concentration = alpha)
-    }
-  )
-)
-
-dirichlet_multinomial_distribution <- R6Class(
-  "dirichlet_multinomial_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(size, alpha, n_realisations, dimension) {
-
-      # coerce to greta arrays
-      size <- as.greta_array(size)
-      alpha <- as.greta_array(alpha)
-
-      dim <- check_multivariate_dims(
-        scalars = list(size),
-        vectors = list(alpha),
-        n_realisations = n_realisations,
-        dimension = dimension
-      )
-
-
-      # need to handle size as a vector!
-
-      # coerce the parameter arguments to nodes and add as parents and
-      # parameters
-      super$initialize("dirichlet_multinomial",
-        dim = dim,
-        discrete = TRUE,
-        multivariate = TRUE
-      )
-      self$add_parameter(size, "size", shape_matches_output = FALSE)
-      self$add_parameter(alpha, "alpha")
-    },
-
-    # nolint start
-    tf_distrib = function(parameters, dag) {
-      parameters$size <- tf_flatten(parameters$size)
-      distrib <- tfp$distributions$DirichletMultinomial
-      distrib(
-        total_count = parameters$size,
-        concentration = parameters$alpha
-      )
-    }
-    # nolint end
-  )
-)
-
-multinomial_distribution <- R6Class(
-  "multinomial_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(size, prob, n_realisations, dimension) {
-
-      # coerce to greta arrays
-      size <- as.greta_array(size)
-      prob <- as.greta_array(prob)
-
-      dim <- check_multivariate_dims(
-        scalars = list(size),
-        vectors = list(prob),
-        n_realisations = n_realisations,
-        dimension = dimension
-      )
-
-      # need to make sure size is a column vector!
-
-      # coerce the parameter arguments to nodes and add as parents and
-      # parameters
-      super$initialize("multinomial",
-        dim = dim,
-        discrete = TRUE,
-        multivariate = TRUE
-      )
-      self$add_parameter(size, "size", shape_matches_output = FALSE)
-      self$add_parameter(prob, "prob")
-    },
-    tf_distrib = function(parameters, dag) {
-      parameters$size <- tf_flatten(parameters$size)
-      # scale probs to get absolute density correct
-      parameters$prob <- parameters$prob / tf_sum(parameters$prob)
-
-      tfp$distributions$Multinomial(
-        total_count = parameters$size,
-        probs = parameters$prob
-      )
-    }
-  )
-)
-
-categorical_distribution <- R6Class(
-  "categorical_distribution",
-  inherit = distribution_node,
-  public = list(
-    initialize = function(prob, n_realisations, dimension) {
-
-      # coerce to greta arrays
-      prob <- as.greta_array(prob)
-
-      dim <- check_multivariate_dims(
-        vectors = list(prob),
-        n_realisations = n_realisations,
-        dimension = dimension
-      )
-
-      # coerce the parameter arguments to nodes and add as parents and
-      # parameters
-      super$initialize("categorical",
-        dim = dim,
-        discrete = TRUE,
-        multivariate = TRUE
-      )
-      self$add_parameter(prob, "prob")
-    },
-    tf_distrib = function(parameters, dag) {
-      # scale probs to get absolute density correct
-      probs <- parameters$prob
-      probs <- probs / tf_sum(probs)
-      tfp$distributions$Multinomial(
-        total_count = fl(1),
-        probs = probs
-      )
-    }
-  )
-)
-
-multivariate_normal_distribution <- R6Class(
-  "multivariate_normal_distribution",
-  inherit = distribution_node,
-  public = list(
-    sigma_is_cholesky = FALSE,
-    # nolint start
-    initialize = function(mean, Sigma, n_realisations, dimension) {
-      # nolint end
-      # coerce to greta arrays
-      mean <- as.greta_array(mean)
-      sigma <- as.greta_array(Sigma)
-
-      # check dim is a positive scalar integer
-      dim <- check_multivariate_dims(
-        vectors = list(mean),
-        squares = list(sigma),
-        n_realisations = n_realisations,
-        dimension = dimension
-      )
-
-      # check dimensions of Sigma
-      if (nrow(sigma) != ncol(sigma) |
-        length(dim(sigma)) != 2) {
-        msg <- cli::format_error(
-          c(
-            "{.arg Sigma} must be a square 2D greta array",
-            "However {.arg Sigma} has dimensions \\
-            {.val {paste(dim(sigma), collapse = 'x')}}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      # compare possible dimensions
-      dim_mean <- ncol(mean)
-      dim_sigma <- nrow(sigma)
-
-      if (dim_mean != dim_sigma) {
-        msg <- cli::format_error(
-          c(
-            "{.arg mean} and {.arg Sigma} must have the same dimensions",
-            "However they are different: {dim_mean} vs {dim_sigma}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      # coerce the parameter arguments to nodes and add as parents and
-      # parameters
-      super$initialize("multivariate_normal", dim, multivariate = TRUE)
-
-      if (has_representation(sigma, "cholesky")) {
-        sigma <- representation(sigma, "cholesky")
-        self$sigma_is_cholesky <- TRUE
-      }
-      self$add_parameter(mean, "mean")
-      self$add_parameter(sigma, "sigma")
-    },
-    tf_distrib = function(parameters, dag) {
-
-      # if Sigma is a cholesky factor, transpose it to match tensorflow's
-      # expectation, otherwise decompose it
-
-      if (self$sigma_is_cholesky) {
-        l <- tf_transpose(parameters$sigma)
-      } else {
-        l <- tf$linalg$cholesky(parameters$sigma)
-      }
-
-      # add an extra dimension for the observation batch size (otherwise tfp
-      # will try to use the n_chains batch dimension)
-      l <- tf$expand_dims(l, 1L)
-
-      mu <- parameters$mean
-      # nolint start
-      tfp$distributions$MultivariateNormalTriL(
-        loc = mu,
-        scale_tril = l
-      )
-      # nolint end
-    }
-  )
-)
-
-wishart_distribution <- R6Class(
-  "wishart_distribution",
-  inherit = distribution_node,
-  public = list(
-
-    # set when defining the distribution
-    sigma_is_cholesky = FALSE,
-
-    # set when defining the graph
-    target_is_cholesky = FALSE,
-    initialize = function(df, Sigma) { # nolint
-      # add the nodes as parents and parameters
-
-      df <- as.greta_array(df)
-      sigma <- as.greta_array(Sigma)
-
-      # check dimensions of Sigma
-      if (nrow(sigma) != ncol(sigma) |
-        length(dim(sigma)) != 2) {
-        msg <- cli::format_error(
-          c(
-            "{.arg Sigma} must be a square 2D greta array",
-            "However, {.arg Sigma} has dimensions ",
-            "{.val {paste(dim(sigma), collapse = 'x')}}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      dim <- nrow(sigma)
-
-      # initialize with a cholesky factor
-      super$initialize("wishart", dim(sigma), multivariate = TRUE)
-
-      # set parameters
-      if (has_representation(sigma, "cholesky")) {
-        sigma <- representation(sigma, "cholesky")
-        self$sigma_is_cholesky <- TRUE
-      }
-      self$add_parameter(df, "df", shape_matches_output = FALSE)
-      self$add_parameter(sigma, "sigma")
-
-      # make the initial value positive definite (it is unclear whether this
-      # has any effect)
-      self$value(unknowns(dims = c(dim, dim), data = diag(dim)))
-    },
-
-    # create a variable, and transform to a symmetric matrix (with cholesky
-    # factor representation)
-    create_target = function(truncation) {
-
-      # create cholesky factor variable greta array
-      chol_greta_array <- cholesky_variable(self$dim[1])
-
-      # reshape to a symmetric matrix (retaining cholesky representation)
-      matrix_greta_array <- chol2symm(chol_greta_array)
-
-      # return the node for the symmetric matrix
-      target_node <- get_node(matrix_greta_array)
-      target_node
-    },
-
-    # get a cholesky factor for the target if possible
-    get_tf_target_node = function() {
-      target <- self$target
-      if (has_representation(target, "cholesky")) {
-        chol <- representation(target, "cholesky")
-        target <- get_node(chol)
-        self$target_is_cholesky <- TRUE
-      }
-      target
-    },
-
-    # if the target is changed, make sure target_is_cholesky is reset to FALSE
-    # (it can be reset again on graph definition)
-    reset_target_flags = function() {
-      self$target_is_cholesky <- FALSE
-    },
-    tf_distrib = function(parameters, dag) {
-
-      # this is messy, we want to use the tfp wishart, but can't define the
-      # density without expanding the dimension of x
-
-      log_prob <- function(x) {
-
-        # reshape the dimensions
-        df <- tf_flatten(parameters$df)
-        sigma <- tf$expand_dims(parameters$sigma, 1L)
-        x <- tf$expand_dims(x, 1L)
-
-        # get the cholesky factor of Sigma in tf orientation
-        if (self$sigma_is_cholesky) {
-          sigma_chol <- tf$linalg$matrix_transpose(sigma)
-        } else {
-          sigma_chol <- tf$linalg$cholesky(sigma)
-        }
-
-        # get the cholesky factor of the target in tf_orientation
-        if (self$target_is_cholesky) {
-          x_chol <- tf$linalg$matrix_transpose(x)
-        } else {
-          x_chol <- tf$linalg$cholesky(x)
-        }
-
-        # use the density for choleskied x, with choleskied Sigma
-        distrib <- tfp$distributions$Wishart(
-          df = df,
-          scale_tril = sigma_chol,
-          input_output_cholesky = TRUE
-        )
-
-        distrib$log_prob(x_chol)
-      }
-
-      sample <- function(seed) {
-        df <- tf$squeeze(parameters$df, 1:2)
-        sigma <- parameters$sigma
-
-        # get the cholesky factor of Sigma in tf orientation
-        if (self$sigma_is_cholesky) {
-          sigma_chol <- tf$linalg$matrix_transpose(sigma)
-        } else {
-          sigma_chol <- tf$linalg$cholesky(sigma)
-        }
-
-        # use the density for choleskied x, with choleskied Sigma
-        distrib <- tfp$distributions$Wishart(
-          df = df,
-          scale_tril = sigma_chol
-        )
-
-        draws <- distrib$sample(seed = seed)
-
-        if (self$target_is_cholesky) {
-          draws <- tf_chol(draws)
-        }
-
-        draws
-      }
-
-      list(log_prob = log_prob, sample = sample)
-    }
-  )
-)
-
-lkj_correlation_distribution <- R6Class(
-  "lkj_correlation_distribution",
-  inherit = distribution_node,
-  public = list(
-
-    # set when defining the graph
-    target_is_cholesky = FALSE,
-    initialize = function(eta, dimension = 2) {
-      dimension <- check_dimension(target = dimension)
-
-      if (!inherits(eta, "greta_array")) {
-        if (!is.numeric(eta) || !length(eta) == 1 || eta <= 0) {
-          msg <- cli::format_error(
-            "{.arg eta} must be a positive scalar value, or a scalar \\
-            {.cls greta_array}"
-          )
-          stop(
-            msg,
-            call. = FALSE
-          )
-        }
-      }
-
-      # add the nodes as parents and parameters
-      eta <- as.greta_array(eta)
-
-      if (!is_scalar(eta)) {
-        msg <- cli::format_error(
-          c(
-            "{.arg eta} must be a scalar",
-            "However {.arg eta} had dimensions: \\
-            {paste0(dim(eta), collapse = ', ')}"
-          )
-        )
-        stop(
-          msg,
-          call. = FALSE
-        )
-      }
-
-      dim <- c(dimension, dimension)
-      super$initialize("lkj_correlation", dim, multivariate = TRUE)
-
-      # don't try to expand scalar eta out to match the target size
-      self$add_parameter(eta, "eta", shape_matches_output = FALSE)
-
-      # make the initial value PD
-      self$value(unknowns(dims = dim, data = diag(dimension)))
-    },
-
-    # default (cholesky factor, ignores truncation)
-    create_target = function(truncation) {
-
-      # create (correlation matrix) cholesky factor variable greta array
-      chol_greta_array <- cholesky_variable(self$dim[1], correlation = TRUE)
-
-      # reshape to a symmetric matrix (retaining cholesky representation)
-      matrix_greta_array <- chol2symm(chol_greta_array)
-
-      # return the node for the symmetric matrix
-      target_node <- get_node(matrix_greta_array)
-      target_node
-    },
-
-    # get a cholesky factor for the target if possible
-    get_tf_target_node = function() {
-      target <- self$target
-      if (has_representation(target, "cholesky")) {
-        chol <- representation(target, "cholesky")
-        target <- get_node(chol)
-        self$target_is_cholesky <- TRUE
-      }
-      target
-    },
-
-    # if the target is changed, make sure target_is_cholesky is reset to FALSE
-    # (it can be reset again on graph definition)
-    reset_target_flags = function() {
-      self$target_is_cholesky <- FALSE
-    },
-    tf_distrib = function(parameters, dag) {
-      eta <- tf$squeeze(parameters$eta, 1:2)
-      dim <- self$dim[1]
-
-      distrib <- tfp$distributions$LKJ(
-        dimension = dim,
-        concentration = eta,
-        input_output_cholesky = self$target_is_cholesky
-      )
-
-      # tfp's lkj sampling can't detect the size of the output from eta, for
-      # some reason, but we can use tf$map_fn to apply the simulation to each
-      # element of eta
-      sample <- function(seed) {
-        sample_once <- function(eta) {
-          d <- tfp$distributions$LKJ(
-            dimension = dim,
-            concentration = eta,
-            input_output_cholesky = self$target_is_cholesky
-          )
-
-          d$sample(seed = seed)
-        }
-
-        tf$map_fn(sample_once, eta)
-      }
-
-      list(
-        log_prob = distrib$log_prob,
-        sample = sample
-      )
-    }
-  )
-)
-
-# module for export via .internals
-distribution_classes_module <- module(uniform_distribution,
-                                      normal_distribution,
-                                      lognormal_distribution,
-                                      bernoulli_distribution,
-                                      binomial_distribution,
-                                      beta_binomial_distribution,
-                                      negative_binomial_distribution,
-                                      zero_inflated_poisson_distribution,
-                                      zero_inflated_negative_binomial_distribution,
-                                      hypergeometric_distribution,
-                                      poisson_distribution,
-                                      gamma_distribution,
-                                      inverse_gamma_distribution,
-                                      weibull_distribution,
-                                      exponential_distribution,
-                                      pareto_distribution,
-                                      student_distribution,
-                                      laplace_distribution,
-                                      beta_distribution,
-                                      cauchy_distribution,
-                                      chi_squared_distribution,
-                                      logistic_distribution,
-                                      f_distribution,
-                                      multivariate_normal_distribution,
-                                      wishart_distribution,
-                                      lkj_correlation_distribution,
-                                      multinomial_distribution,
-                                      categorical_distribution,
-                                      dirichlet_distribution,
-                                      dirichlet_multinomial_distribution)
-
-# export constructors
-
-# nolint start
-#' @name distributions
-#' @title probability distributions
-#' @description These functions can be used to define random variables in a
-#'   greta model. They return a variable greta array that follows the specified
-#'   distribution. This variable greta array can be used to represent a
-#'   parameter with a prior distribution, combined into a mixture distribution
-#'   using [mixture()], or used with [distribution()] to
-#'   define a distribution over a data greta array.
-#'
-#' @param truncation a length-two vector giving values between which to truncate
-#'   the distribution, similarly to the `lower` and `upper` arguments
-#'   to [variable()]
-#'
-#' @param min,max scalar values giving optional limits to `uniform`
-#'   variables. Like `lower` and `upper`, these must be specified as
-#'   numerics, they cannot be greta arrays (though see details for a
-#'   workaround). Unlike `lower` and `upper`, they must be finite.
-#'   `min` must always be less than `max`.
-#'
-#' @param mean,meanlog,location,mu unconstrained parameters
-#'
-#' @param
-#'   sd,sdlog,sigma,lambda,shape,rate,df,scale,shape1,shape2,alpha,beta,df1,df2,a,b,eta
-#'    positive parameters, `alpha` must be a vector for `dirichlet`
-#'   and `dirichlet_multinomial`.
-#'
-#' @param size,m,n,k positive integer parameter
-#'
-#' @param prob probability parameter (`0 < prob < 1`), must be a vector for
-#'   `multinomial` and `categorical`
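-#'
-#' @param theta the zero-inflation probability (`0 < theta < 1`) of
-#'   `zero_inflated_poisson` and `zero_inflated_negative_binomial`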
-#'
-#' @param Sigma positive definite variance-covariance matrix parameter
-#'
-#' @param dim the dimensions of the greta array to be returned, either a scalar
-#'   or a vector of positive integers. See details.
-#'
-#' @param dimension the dimension of a multivariate distribution
-#'
-#' @param n_realisations the number of independent realisations of a multivariate
-#'   distribution
-#'
-#' @details The discrete probability distributions (`bernoulli`,
-#'   `binomial`, `negative_binomial`, `poisson`,
-#'   `zero_inflated_poisson`, `zero_inflated_negative_binomial`,
-#'   `multinomial`, `categorical`, `dirichlet_multinomial`) can
-#'   be used when they have fixed values (e.g. defined as a likelihood using
-#'   [distribution()]), but not as unknown variables.
-#'
-#'   For univariate distributions `dim` gives the dimensions of the greta
-#'   array to create. Each element of the greta array will be (independently)
-#'   distributed according to the distribution. `dim` can also be left at
-#'   its default of `NULL`, in which case the dimension will be detected
-#'   from the dimensions of the parameters (provided they are compatible with
-#'   one another).
-#'
-#'   For multivariate distributions (`multivariate_normal()`,
-#'   `multinomial()`, `categorical()`, `dirichlet()`, and
-#'   `dirichlet_multinomial()`) each row of the output and parameters
-#'   corresponds to an independent realisation. If a single realisation or
-#'   parameter value is specified, it must therefore be a row vector (see
-#'   example). `n_realisations` gives the number of rows/realisations, and
-#'   `dimension` gives the dimension of the distribution. I.e. a bivariate
-#'   normal distribution would be produced with `multivariate_normal(...,
-#'   dimension = 2)`. The dimension can usually be detected from the parameters.
-#'
-#'   `multinomial()` does not check that observed values sum to
-#'   `size`, and `categorical()` does not check that only one of the
-#'   observed entries is 1. It's the user's responsibility to check their data
-#'   matches the distribution!
-#'
-#'   The parameters of `uniform` must be fixed, not greta arrays. This
-#'   ensures these values can always be transformed to a continuous scale to run
-#'   the samplers efficiently. However, a hierarchical `uniform` parameter
-#'   can always be created by defining a `uniform` variable constrained
-#'   between 0 and 1, and then transforming it to the required scale. See below
-#'   for an example.
-#'
-#'   Wherever possible, the parameterisations and argument names of greta
-#'   distributions match commonly used R functions for distributions, such as
-#'   those in the `stats` or `extraDistr` packages. The following
-#'   table states the distribution function to which greta's implementation
-#'   corresponds:
-#'
-#'   \tabular{ll}{ greta \tab reference\cr `uniform` \tab
-#'   [stats::dunif]\cr `normal` \tab
-#'   [stats::dnorm]\cr `lognormal` \tab
-#'   [stats::dlnorm]\cr `bernoulli` \tab
-#'   [extraDistr::dbern]\cr `binomial` \tab
-#'   [stats::dbinom]\cr `beta_binomial` \tab
-#'   [extraDistr::dbbinom]\cr `negative_binomial`
-#'   \tab [stats::dnbinom]\cr `hypergeometric` \tab
-#'   [stats::dhyper]\cr `poisson` \tab
-#'   [stats::dpois]\cr `zero_inflated_poisson` \tab
-#'   [extraDistr::dzip]\cr `zero_inflated_negative_binomial` \tab
-#'   [extraDistr::dzinb]\cr `gamma` \tab
-#'   [stats::dgamma]\cr `inverse_gamma` \tab
-#'   [extraDistr::dinvgamma]\cr `weibull` \tab
-#'   [stats::dweibull]\cr `exponential` \tab
-#'   [stats::dexp]\cr `pareto` \tab
-#'   [extraDistr::dpareto]\cr `student` \tab
-#'   [extraDistr::dlst]\cr `laplace` \tab
-#'   [extraDistr::dlaplace]\cr `beta` \tab
-#'   [stats::dbeta]\cr `cauchy` \tab
-#'   [stats::dcauchy]\cr `chi_squared` \tab
-#'   [stats::dchisq]\cr `logistic` \tab
-#'   [stats::dlogis]\cr `f` \tab
-#'   [stats::df]\cr `multivariate_normal` \tab
-#'   [mvtnorm::dmvnorm]\cr `multinomial` \tab
-#'   [stats::dmultinom]\cr `categorical` \tab
-#'   {[stats::dmultinom] (size = 1)}\cr `dirichlet`
-#'   \tab [extraDistr::ddirichlet]\cr
-#'   `dirichlet_multinomial` \tab
-#'   [extraDistr::ddirmnom]\cr `wishart` \tab
-#'   [stats::rWishart]\cr `lkj_correlation` \tab
-#'   [rethinking::dlkjcorr](https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html)
-#'   }
-#'
-#' @examples
-#' \dontrun{
-#'
-#' # a uniform parameter constrained to be between 0 and 1
-#' phi <- uniform(min = 0, max = 1)
-#'
-#' # a length-three variable, with each element following a standard normal
-#' # distribution
-#' alpha <- normal(0, 1, dim = 3)
-#'
-#' # a length-three variable of lognormals
-#' sigma <- lognormal(0, 3, dim = 3)
-#'
-#' # a hierarchical uniform, constrained between alpha and alpha + sigma,
-#' eta <- alpha + uniform(0, 1, dim = 3) * sigma
-#'
-#' # a hierarchical distribution
-#' mu <- normal(0, 1)
-#' sigma <- lognormal(0, 1)
-#' theta <- normal(mu, sigma)
-#'
-#' # a vector of 3 variables drawn from the same hierarchical distribution
-#' thetas <- normal(mu, sigma, dim = 3)
-#'
-#' # a matrix of 12 variables drawn from the same hierarchical distribution
-#' thetas <- normal(mu, sigma, dim = c(3, 4))
-#'
-#' # a multivariate normal variable, with correlation between two elements
-#' # note that the parameter must be a row vector
-#' Sig <- diag(4)
-#' Sig[3, 4] <- Sig[4, 3] <- 0.6
-#' theta <- multivariate_normal(t(rep(mu, 4)), Sig)
-#'
-#' # 10 independent replicates of that
-#' theta <- multivariate_normal(t(rep(mu, 4)), Sig, n_realisations = 10)
-#'
-#' # 10 multivariate normal replicates, each with a different mean vector,
-#' # but the same covariance matrix
-#' means <- matrix(rnorm(40), 10, 4)
-#' theta <- multivariate_normal(means, Sig, n_realisations = 10)
-#' dim(theta)
-#'
-#' # a Wishart variable with the same covariance parameter
-#' theta <- wishart(df = 5, Sigma = Sig)
-#' }
-NULL
-# nolint end
-
-#' @rdname distributions
-#' @export
-uniform <- function(min, max, dim = NULL) {
-  distrib("uniform", min, max, dim)
-}
-
-#' @rdname distributions
-#' @export
-normal <- function(mean, sd, dim = NULL, truncation = c(-Inf, Inf)) {
-  distrib("normal", mean, sd, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-lognormal <- function(meanlog, sdlog, dim = NULL, truncation = c(0, Inf)) {
-  distrib("lognormal", meanlog, sdlog, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-bernoulli <- function(prob, dim = NULL) {
-  distrib("bernoulli", prob, dim)
-}
-
-#' @rdname distributions
-#' @export
-binomial <- function(size, prob, dim = NULL) {
-  check_in_family("binomial", size)
-  distrib("binomial", size, prob, dim)
-}
-
-#' @rdname distributions
-#' @export
-beta_binomial <- function(size, alpha, beta, dim = NULL) {
-  distrib("beta_binomial", size, alpha, beta, dim)
-}
-
-#' @rdname distributions
-#' @export
-negative_binomial <- function(size, prob, dim = NULL) {
-  distrib("negative_binomial", size, prob, dim)
-}
-
-#' @rdname distributions
-#' @export
-hypergeometric <- function(m, n, k, dim = NULL) {
-  distrib("hypergeometric", m, n, k, dim)
-}
-
-#' @rdname distributions
-#' @export
-poisson <- function(lambda, dim = NULL) {
-  check_in_family("poisson", lambda)
-  distrib("poisson", lambda, dim)
-}
-
-#' @rdname distributions
-#' @export
-zero_inflated_poisson <- function(theta, lambda, dim = NULL) {
-  distrib("zero_inflated_poisson", theta, lambda, dim)
-}
-
-#' @rdname distributions
-#' @export
-zero_inflated_negative_binomial <- function(theta, size, prob, dim = NULL) {
-  distrib("zero_inflated_negative_binomial", theta, size, prob, dim)
-}
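-
-# usage sketch for the two constructors above (illustration only):
-#   y <- zero_inflated_poisson(theta = 0.2, lambda = 3, dim = 10)
-#   z <- zero_inflated_negative_binomial(theta = 0.2, size = 10, prob = 0.1)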
-
-#' @rdname distributions
-#' @export
-gamma <- function(shape, rate, dim = NULL, truncation = c(0, Inf)) {
-  distrib("gamma", shape, rate, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-inverse_gamma <- function(alpha, beta, dim = NULL, truncation = c(0, Inf)) {
-  distrib("inverse_gamma", alpha, beta, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-weibull <- function(shape, scale, dim = NULL, truncation = c(0, Inf)) {
-  distrib("weibull", shape, scale, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-exponential <- function(rate, dim = NULL, truncation = c(0, Inf)) {
-  distrib("exponential", rate, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-pareto <- function(a, b, dim = NULL, truncation = c(0, Inf)) {
-  distrib("pareto", a, b, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-student <- function(df, mu, sigma, dim = NULL, truncation = c(-Inf, Inf)) {
-  distrib("student", df, mu, sigma, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-laplace <- function(mu, sigma, dim = NULL, truncation = c(-Inf, Inf)) {
-  distrib("laplace", mu, sigma, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-beta <- function(shape1, shape2, dim = NULL, truncation = c(0, 1)) {
-  distrib("beta", shape1, shape2, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-cauchy <- function(location, scale, dim = NULL, truncation = c(-Inf, Inf)) {
-  distrib("cauchy", location, scale, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-chi_squared <- function(df, dim = NULL, truncation = c(0, Inf)) {
-  distrib("chi_squared", df, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-logistic <- function(location, scale, dim = NULL, truncation = c(-Inf, Inf)) {
-  distrib("logistic", location, scale, dim, truncation)
-}
-
-#' @rdname distributions
-#' @export
-f <- function(df1, df2, dim = NULL, truncation = c(0, Inf)) {
-  distrib("f", df1, df2, dim, truncation)
-}
-
-# nolint start
-#' @rdname distributions
-#' @export
-multivariate_normal <- function(mean, Sigma,
-                                n_realisations = NULL, dimension = NULL) {
-  # nolint end
-  distrib(
-    "multivariate_normal", mean, Sigma,
-    n_realisations, dimension
-  )
-}
-
-#' @rdname distributions
-#' @export
-wishart <- function(df, Sigma) { # nolint
-  distrib("wishart", df, Sigma)
-}
-
-#' @rdname distributions
-#' @export
-lkj_correlation <- function(eta, dimension = 2) {
-  distrib("lkj_correlation", eta, dimension)
-}
-
-#' @rdname distributions
-#' @export
-multinomial <- function(size, prob, n_realisations = NULL, dimension = NULL) {
-  distrib("multinomial", size, prob, n_realisations, dimension)
-}
-
-#' @rdname distributions
-#' @export
-categorical <- function(prob, n_realisations = NULL, dimension = NULL) {
-  distrib("categorical", prob, n_realisations, dimension)
-}
-
-#' @rdname distributions
-#' @export
-dirichlet <- function(alpha, n_realisations = NULL, dimension = NULL) {
-  distrib("dirichlet", alpha, n_realisations, dimension)
-}
-
-#' @rdname distributions
-#' @export
-dirichlet_multinomial <- function(size, alpha,
-                                  n_realisations = NULL, dimension = NULL) {
-  distrib(
-    "dirichlet_multinomial",
-    size, alpha, n_realisations, dimension
-  )
-}
diff --git a/R/zero_inflateds.R b/R/zero_inflateds.R
new file mode 100644
index 0000000..d31d5ba
--- /dev/null
+++ b/R/zero_inflateds.R
@@ -0,0 +1,104 @@
+zero_inflated_poisson_distribution <- R6Class(
+  "zero_inflated_poisson_distribution",
+  inherit = greta::.internals$nodes$node_classes$distribution_node,
+  public = list(
+    initialize = function(theta, lambda, dim) {
+      theta <- as.greta_array(theta)
+      lambda <- as.greta_array(lambda)
+      # add the nodes as parents and parameters
+      dim <- check_dims(theta, lambda, target_dim = dim)
+      super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(lambda, "lambda")
+    },
+
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      lambda <- parameters$lambda
+      log_prob <- function(x) {
+
+        # for non-negative integer x, relu(1 - x) is the indicator x == 0, so
+        # this is theta * I(x == 0) + (1 - theta) * dpois(x, lambda)
+        pois_prob <- tf$pow(lambda, x) * tf$exp(-lambda) /
+          tf$exp(tf$math$lgamma(x + fl(1)))
+        tf$math$log(
+          theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * pois_prob
+        )
+      }
+
+      sample <- function(seed) {
+
+        # draw an indicator of structural zeros, then use it to mask a
+        # poisson draw
+        binom <- tfp$distributions$Binomial(total_count = fl(1), probs = theta)
+        pois <- tfp$distributions$Poisson(rate = lambda)
+
+        zi <- binom$sample(seed = seed)
+        counts <- pois$sample(seed = seed)
+
+        (fl(1) - zi) * counts
+      }
+      
+      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
+    },
+    
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
+
+
+zero_inflated_negative_binomial_distribution <- R6Class(
+  "zero_inflated_negative_binomial_distribution",
+  inherit = greta::.internals$nodes$node_classes$distribution_node,
+  public = list(
+    initialize = function(theta, size, prob, dim) {
+      theta <- as.greta_array(theta)
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+      # add the nodes as parents and parameters
+      dim <- check_dims(theta, size, prob, target_dim = dim)
+      super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(size, "size")
+      self$add_parameter(prob, "prob")
+    },
+
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      size <- parameters$size
+      p <- parameters$prob # probability of success
+      q <- fl(1) - parameters$prob # probability of failure
+      log_prob <- function(x) {
+
+        # negative binomial mass with the binomial coefficient written via
+        # lgamma: choose(x + size - 1, x) * p^size * q^x
+        nb_prob <- tf$pow(p, size) * tf$pow(q, x) *
+          tf$exp(
+            tf$math$lgamma(x + size) -
+              tf$math$lgamma(size) -
+              tf$math$lgamma(x + fl(1))
+          )
+        tf$math$log(
+          theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * nb_prob
+        )
+      }
+
+      sample <- function(seed) {
+
+        # draw an indicator of structural zeros, then use it to mask a
+        # negative binomial draw
+        binom <- tfp$distributions$Binomial(total_count = fl(1), probs = theta)
+        # probs here is the complement of the 'stats' success probability,
+        # matching TFP's parameterisation of the negative binomial
+        negbin <- tfp$distributions$NegativeBinomial(
+          total_count = size,
+          probs = q
+        )
+
+        zi <- binom$sample(seed = seed)
+        counts <- negbin$sample(seed = seed)
+
+        (fl(1) - zi) * counts
+      }
+      
+      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
+    },
+    
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
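+
+# illustrative sketch (editor's addition, not package code): both classes
+# above implement the standard zero-inflated mass
+# theta * I(x == 0) + (1 - theta) * f(x); in plain R, for any density dens:
+dzi_sketch <- function(x, theta, dens, ...) {
+  theta * (x == 0) + (1 - theta) * dens(x, ...)
+}
+# e.g. dzi_sketch(0:5, theta = 0.2, dens = dpois, lambda = 3)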
+
+#' @rdname distributions
+#' @export
+zero_inflated_poisson <- function(theta, lambda, dim = NULL) {
+  distrib("zero_inflated_poisson", theta, lambda, dim)
+}
+
+#' @rdname distributions
+#' @export
+zero_inflated_negative_binomial <- function(theta, size, prob, dim = NULL) {
+  distrib("zero_inflated_negative_binomial", theta, size, prob, dim)
+}
+
+# module for export via .internals
+distribution_classes_module <- module(zero_inflated_poisson_distribution,
+                                      zero_inflated_negative_binomial_distribution)
\ No newline at end of file
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 00860e3..6f6b8fb 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -836,6 +836,23 @@ check_samples <- function(x,
   testthat::expect_gte(stat$p.value, 0.01)
 }
 
+# zero-inflated poisson helpers using the 'distributional' package
+
+zero_inflated_pois <- function(lambda, prob) {
+  dist_inflated(
+    dist = dist_poisson(lambda = lambda),
+    prob = prob,
+    x = 0
+  )
+}
+
+sample_zero_inflated_pois <- function(n, lambda, prob) {
+  generate(
+    x = zero_inflated_pois(lambda = lambda, prob = prob),
+    times = n
+  )
+}
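+
+# usage sketch: sample_zero_inflated_pois(10, lambda = 3, prob = 0.2) draws
+# 10 values from a poisson(3) with extra zeros occurring at probability 0.2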
+
 # zero-inflated distribution from rethinking package
 dzipois <- function(x , theta , lambda , log=FALSE ) {
   ll <- ifelse( x==0 , theta + (1-theta)*exp(-lambda) , (1-theta)*dpois(x,lambda,FALSE) )
diff --git a/tests/testthat/test_distributions.R b/tests/testthat/test_distributions.R
deleted file mode 100644
index 23b3c71..0000000
--- a/tests/testthat/test_distributions.R
+++ /dev/null
@@ -1,1178 +0,0 @@
-test_that("normal distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::normal,
-    stats::dnorm,
-    parameters = list(mean = -2, sd = 3),
-    x = rnorm(100, -2, 3)
-  )
-})
-
-
-test_that("multidimensional normal distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::normal,
-    stats::dnorm,
-    parameters = list(mean = -2, sd = 3),
-    x = array(rnorm(100, -2, 3),
-      dim = c(10, 2, 5)
-    ),
-    dim = c(10, 2, 5)
-  )
-})
-
-test_that("uniform distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::uniform,
-    stats::dunif,
-    parameters = list(min = -2.1, max = -1.2),
-    x = runif(100, -2.1, -1.2)
-  )
-})
-
-test_that("lognormal distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::lognormal,
-    stats::dlnorm,
-    parameters = list(meanlog = 1, sdlog = 3),
-    x = rlnorm(100, 1, 3)
-  )
-})
-
-test_that("bernoulli distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::bernoulli,
-    extraDistr::dbern,
-    parameters = list(prob = 0.3),
-    x = rbinom(100, 1, 0.3)
-  )
-})
-
-test_that("binomial distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::binomial,
-    stats::dbinom,
-    parameters = list(size = 10, prob = 0.8),
-    x = rbinom(100, 10, 0.8)
-  )
-})
-
-test_that("beta-binomial distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::beta_binomial,
-    extraDistr::dbbinom,
-    parameters = list(
-      size = 10,
-      alpha = 0.8,
-      beta = 1.2
-    ),
-    x = extraDistr::rbbinom(100, 10, 0.8, 1.2)
-  )
-})
-
-test_that("negative binomial distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::negative_binomial,
-    stats::dnbinom,
-    parameters = list(size = 3.3, prob = 0.2),
-    x = rnbinom(100, 3.3, 0.2)
-  )
-})
-
-test_that("zero inflated poisson distribution has correct density", {
-
-  skip_if_not(check_tf_version())
-  source("helpers.R")
-
-  compare_distribution(zero_inflated_poisson,
-                       extraDistr::dzip,
-                       parameters = list(theta = 0.2, lambda = 2, pi = 0.2),
-                       x = extraDistr::rpois(100, 2, 0.2))
-
-})
-
-test_that("zero inflated negative binomial distribution has correct density", {
-
-  skip_if_not(check_tf_version())
-  source("helpers.R")
-
-  compare_distribution(zero_inflated_negative_binomial,
-                       extraDistr::dzinb,
-                       parameters = list(theta = 2, size = 10, prob = 0.1, pi = 0.2),
-                       x = extraDistr::rzinb(100, 10, 0.1, 0.2))
-
-})
-
-
-test_that("hypergeometric distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::hypergeometric,
-    stats::dhyper,
-    parameters = list(m = 11, n = 8, k = 5),
-    x = rhyper(100, 11, 8, 5)
-  )
-})
-
-test_that("poisson distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::poisson,
-    stats::dpois,
-    parameters = list(lambda = 17.2),
-    x = rpois(100, 17.2)
-  )
-})
-
-test_that("gamma distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::gamma,
-    stats::dgamma,
-    parameters = list(shape = 1.2, rate = 2.3),
-    x = rgamma(100, 1.2, 2.3)
-  )
-})
-
-
-test_that("inverse gamma distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::inverse_gamma,
-    extraDistr::dinvgamma,
-    parameters = list(alpha = 1.2, beta = 0.9),
-    x = extraDistr::rinvgamma(100, 1.2, 0.9)
-  )
-})
-
-test_that("weibull distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::weibull,
-    dweibull,
-    parameters = list(
-      shape = 1.2,
-      scale = 0.9
-    ),
-    x = rweibull(100, 1.2, 0.9)
-  )
-})
-
-test_that("exponential distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::exponential,
-    stats::dexp,
-    parameters = list(rate = 1.9),
-    x = rexp(100, 1.9)
-  )
-})
-
-test_that("pareto distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::pareto,
-    extraDistr::dpareto,
-    parameters = list(a = 1.9, b = 2.3),
-    x = extraDistr::rpareto(100, 1.9, 2.3)
-  )
-})
-
-test_that("student distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-  dstudent <- extraDistr::dlst
-
-  compare_distribution(
-    greta::student,
-    dstudent,
-    parameters = list(
-      df = 3,
-      mu = -0.9,
-      sigma = 2
-    ),
-    x = rnorm(100, -0.9, 2)
-  )
-})
-
-test_that("laplace distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::laplace,
-    extraDistr::dlaplace,
-    parameters = list(mu = -0.9, sigma = 2),
-    x = extraDistr::rlaplace(100, -0.9, 2)
-  )
-})
-
-test_that("beta distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::beta,
-    stats::dbeta,
-    parameters = list(
-      shape1 = 2.3,
-      shape2 = 3.4
-    ),
-    x = rbeta(100, 2.3, 3.4)
-  )
-})
-
-test_that("cauchy distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::cauchy,
-    stats::dcauchy,
-    parameters = list(
-      location = -1.3,
-      scale = 3.4
-    ),
-    x = rcauchy(100, -1.3, 3.4)
-  )
-})
-
-test_that("logistic distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::logistic,
-    stats::dlogis,
-    parameters = list(
-      location = -1.3,
-      scale = 2.1
-    ),
-    x = rlogis(100, -1.3, 2.1)
-  )
-})
-
-test_that("f distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::f,
-    df,
-    parameters = list(df1 = 5.9, df2 = 2),
-    x = rf(100, 5.9, 2)
-  )
-})
-
-test_that("chi squared distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  compare_distribution(greta::chi_squared,
-    stats::dchisq,
-    parameters = list(df = 9.3),
-    x = rchisq(100, 9.3)
-  )
-})
-
-test_that("multivariate normal distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  m <- 5
-  mn <- t(rnorm(m))
-  sig <- rWishart(1, m + 1, diag(m))[, , 1]
-
-  # function converting Sigma to sigma
-  dmvnorm2 <- function(x, mean, Sigma, log = FALSE) { # nolint
-    mvtnorm::dmvnorm(x = x, mean = mean, sigma = Sigma, log = log)
-  }
-
-  compare_distribution(greta::multivariate_normal,
-    dmvnorm2,
-    parameters = list(mean = mn, Sigma = sig),
-    x = mvtnorm::rmvnorm(100, mn, sig),
-    multivariate = TRUE
-  )
-})
-
-test_that("Wishart distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  m <- 5
-  df <- m + 1
-  sig <- rWishart(1, df, diag(m))[, , 1]
-
-  # wrapper for argument names
-  dwishart <- function(x, df, Sigma, log = FALSE) { # nolint
-    ans <- MCMCpack::dwish(W = x, v = df, S = Sigma)
-    if (log) {
-      ans <- log(ans)
-    }
-    ans
-  }
-
-  # no vectorised wishart, so loop through all of these
-  replicate(
-    10,
-    compare_distribution(greta::wishart,
-      dwishart,
-      parameters = list(
-        df = df,
-        Sigma = sig
-      ),
-      x = rWishart(1, df, sig)[, , 1],
-      multivariate = TRUE
-    )
-  )
-})
-
-test_that("lkj distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  m <- 5
-  eta <- 3
-
-  # normalising component of lkj (depends only on eta and dimension)
-  lkj_log_normalising <- function(eta, n) {
-    log_pi <- log(pi)
-    ans <- 0
-    for (k in 1:(n - 1)) {
-      ans <- ans + log_pi * (k / 2)
-      ans <- ans + lgamma(eta + (n - 1 - k) / 2)
-      ans <- ans - lgamma(eta + (n - 1) / 2)
-    }
-    ans
-  }
-
-  # lkj density
-  dlkj_correlation <- function(x, eta, log = FALSE, dimension = NULL) {
-    res <- (eta - 1) * log(det(x)) - lkj_log_normalising(eta, ncol(x))
-    if (!log) {
-      res <- exp(res)
-    }
-    res
-  }
-
-  # no vectorised lkj, so loop through all of these
-  replicate(
-    10,
-    compare_distribution(greta::lkj_correlation,
-      dlkj_correlation,
-      parameters = list(eta = eta, dimension = m),
-      x = rlkjcorr(1, eta = 1, dimension = m),
-      multivariate = TRUE
-    )
-  )
-})
-
-test_that("multinomial distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  m <- 5
-  prob <- t(runif(m))
-  size <- 5
-
-  # vectorise R's density function
-  dmultinom_vec <- function(x, size, prob) {
-    apply(x, 1, stats::dmultinom, size = size, prob = prob)
-  }
-
-  compare_distribution(greta::multinomial,
-    dmultinom_vec,
-    parameters = list(
-      size = size,
-      prob = prob
-    ),
-    x = t(rmultinom(100, size, prob)),
-    multivariate = TRUE
-  )
-})
-
-test_that("categorical distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  m <- 5
-  prob <- t(runif(m))
-
-  # vectorise R's density function
-  dcategorical_vec <- function(x, prob) {
-    apply(x, 1, stats::dmultinom, size = 1, prob = prob)
-  }
-
-  compare_distribution(greta::categorical,
-    dcategorical_vec,
-    parameters = list(prob = prob),
-    x = t(rmultinom(100, 1, prob)),
-    multivariate = TRUE
-  )
-})
-
-test_that("dirichlet distribution has correct density", {
-  skip_if_not(check_tf_version())
-  # parameters to test
-  m <- 5
-  alpha <- t(runif(m))
-
-  compare_distribution(
-    greta_fun = greta::dirichlet,
-    r_fun = extraDistr::ddirichlet,
-    parameters = list(alpha = alpha),
-    x = extraDistr::rdirichlet(100, alpha),
-    multivariate = TRUE
-  )
-})
-
-test_that("dirichlet-multinomial distribution has correct density", {
-  skip_if_not(check_tf_version())
-
-
-  # parameters to test
-  size <- 10
-  m <- 5
-  alpha <- t(runif(m))
-
-  compare_distribution(greta::dirichlet_multinomial,
-    extraDistr::ddirmnom,
-    parameters = list(
-      size = size,
-      alpha = alpha
-    ),
-    x = extraDistr::rdirmnom(
-      100,
-      size,
-      alpha
-    ),
-    multivariate = TRUE
-  )
-})
-
-test_that("scalar-valued distributions can be defined in models", {
-  skip_if_not(check_tf_version())
-
-
-  x <- randn(5)
-  y <- round(randu(5))
-  p <- iprobit(normal(0, 1))
-
-  # variable (need to define a likelihood)
-  a <- variable()
-  distribution(x) <- normal(a, 1)
-  expect_ok(model(a))
-
-  # univariate discrete distributions
-  distribution(y) <- bernoulli(p)
-  expect_ok(model(p))
-
-  distribution(y) <- binomial(1, p)
-  expect_ok(model(p))
-
-  distribution(y) <- beta_binomial(1, p, 0.2)
-  expect_ok(model(p))
-
-  distribution(y) <- negative_binomial(1, p)
-  expect_ok(model(p))
-
-  distribution(y) <- hypergeometric(5, 5, p)
-  expect_ok(model(p))
-
-  distribution(y) <- poisson(p)
-  expect_ok(model(p))
-
-  # multivariate discrete distributions
-  y <- extraDistr::rmnom(1, size = 4, prob = runif(3))
-  p <- iprobit(normal(0, 1, dim = 3))
-  distribution(y) <- multinomial(4, t(p))
-  expect_ok(model(p))
-
-  y <- extraDistr::rmnom(1, size = 1, prob = runif(3))
-  p <- iprobit(normal(0, 1, dim = 3))
-  distribution(y) <- categorical(t(p))
-  expect_ok(model(p))
-
-  y <- extraDistr::rmnom(1, size = 4, prob = runif(3))
-  alpha <- lognormal(0, 1, dim = 3)
-  distribution(y) <- dirichlet_multinomial(4, t(alpha))
-  expect_ok(model(alpha))
-
-  # univariate continuous distributions
-  expect_ok(model(normal(-2, 3)))
-  expect_ok(model(student(5.6, -2, 2.3)))
-  expect_ok(model(laplace(-1.2, 1.1)))
-  expect_ok(model(cauchy(-1.2, 1.1)))
-  expect_ok(model(logistic(-1.2, 1.1)))
-
-  expect_ok(model(lognormal(1.2, 0.2)))
-  expect_ok(model(gamma(0.9, 1.3)))
-  expect_ok(model(exponential(6.3)))
-  expect_ok(model(beta(6.3, 5.9)))
-  expect_ok(model(inverse_gamma(0.9, 1.3)))
-  expect_ok(model(weibull(2, 1.1)))
-  expect_ok(model(pareto(2.4, 1.5)))
-  expect_ok(model(chi_squared(4.3)))
-  expect_ok(model(f(24.3, 2.4)))
-
-  expect_ok(model(uniform(-13, 2.4)))
-
-  # multivariate continuous distributions
-  sig <- rWishart(1, 4, diag(3))[, , 1]
-
-  expect_ok(model(multivariate_normal(t(rnorm(3)), sig)))
-  expect_ok(model(wishart(4, sig)))
-  expect_ok(model(lkj_correlation(5, dimension = 3)))
-  expect_ok(model(dirichlet(t(runif(3)))))
-})
-
-test_that("array-valued distributions can be defined in models", {
-  skip_if_not(check_tf_version())
-
-
-  dim <- c(5, 2)
-  x <- randn(5, 2)
-  y <- round(randu(5, 2))
-
-  # variable (need to define a likelihood)
-  a <- variable(dim = dim)
-  distribution(x) <- normal(a, 1)
-  expect_ok(model(a))
-
-  # univariate discrete distributions
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- bernoulli(p)
-  expect_ok(model(p))
-
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- binomial(1, p)
-  expect_ok(model(p))
-
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- beta_binomial(1, p, 0.2)
-  expect_ok(model(p))
-
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- negative_binomial(1, p)
-  expect_ok(model(p))
-
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- hypergeometric(10, 5, p)
-  expect_ok(model(p))
-
-  p <- iprobit(normal(0, 1, dim = dim))
-  distribution(y) <- poisson(p)
-  expect_ok(model(p))
-
-  # multivariate discrete distributions
-  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
-  p <- iprobit(normal(0, 1, dim = 3))
-  distribution(y) <- multinomial(4, t(p), n_realisations = 5)
-  expect_ok(model(p))
-
-  y <- extraDistr::rmnom(5, size = 1, prob = runif(3))
-  p <- iprobit(normal(0, 1, dim = 3))
-  distribution(y) <- categorical(t(p), n_realisations = 5)
-  expect_ok(model(p))
-
-  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
-  alpha <- lognormal(0, 1, dim = 3)
-  distribution(y) <- dirichlet_multinomial(4, t(alpha), n_realisations = 5)
-  expect_ok(model(alpha))
-
-  # univariate continuous distributions
-  expect_ok(model(normal(-2, 3, dim = dim)))
-  expect_ok(model(student(5.6, -2, 2.3, dim = dim)))
-  expect_ok(model(laplace(-1.2, 1.1, dim = dim)))
-  expect_ok(model(cauchy(-1.2, 1.1, dim = dim)))
-  expect_ok(model(logistic(-1.2, 1.1, dim = dim)))
-
-  expect_ok(model(lognormal(1.2, 0.2, dim = dim)))
-  expect_ok(model(gamma(0.9, 1.3, dim = dim)))
-  expect_ok(model(exponential(6.3, dim = dim)))
-  expect_ok(model(beta(6.3, 5.9, dim = dim)))
-  expect_ok(model(uniform(-13, 2.4, dim = dim)))
-  expect_ok(model(inverse_gamma(0.9, 1.3, dim = dim)))
-  expect_ok(model(weibull(2, 1.1, dim = dim)))
-  expect_ok(model(pareto(2.4, 1.5, dim = dim)))
-  expect_ok(model(chi_squared(4.3, dim = dim)))
-  expect_ok(model(f(24.3, 2.4, dim = dim)))
-
-  # multivariate continuous distributions
-  sig <- rWishart(1, 4, diag(3))[, , 1]
-  expect_ok(
-    model(multivariate_normal(t(rnorm(3)), sig, n_realisations = dim[1]))
-  )
-  expect_ok(model(dirichlet(t(runif(3)), n_realisations = dim[1])))
-  expect_ok(model(wishart(4, sig)))
-  expect_ok(model(lkj_correlation(3, dimension = dim[1])))
-})
-
-test_that("distributions can be sampled from by MCMC", {
-  skip_if_not(check_tf_version())
-
-
-  x <- randn(100)
-  y <- round(randu(100))
-
-  # variable (with a density)
-  a <- variable()
-  distribution(x) <- normal(a, 1)
-  sample_distribution(a)
-
-  b <- variable(lower = -1)
-  distribution(x) <- normal(b, 1)
-  sample_distribution(b)
-
-  c <- variable(upper = -2)
-  distribution(x) <- normal(c, 1)
-  sample_distribution(c)
-
-  d <- variable(lower = 1.2, upper = 1.3)
-  distribution(x) <- normal(d, 1)
-  sample_distribution(d)
-
-  # univariate discrete
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- bernoulli(p)
-  sample_distribution(p)
-
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- binomial(1, p)
-  sample_distribution(p)
-
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- negative_binomial(1, p)
-  sample_distribution(p)
-
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- hypergeometric(10, 5, p)
-  sample_distribution(p)
-
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- poisson(p)
-  sample_distribution(p)
-
-  p <- iprobit(normal(0, 1, dim = 100))
-  distribution(y) <- beta_binomial(1, p, 0.3)
-  sample_distribution(p)
-
-  # multivariate discrete
-  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
-  p <- uniform(0, 1, dim = 3)
-  distribution(y) <- multinomial(4, t(p), n_realisations = 5)
-  sample_distribution(p)
-
-  y <- extraDistr::rmnom(5, size = 1, prob = runif(3))
-  p <- iprobit(normal(0, 1, dim = 3))
-  distribution(y) <- categorical(t(p), n_realisations = 5)
-  sample_distribution(p)
-
-  y <- extraDistr::rmnom(5, size = 4, prob = runif(3))
-  alpha <- lognormal(0, 1, dim = 3)
-  distribution(y) <- dirichlet_multinomial(4, t(alpha), n_realisations = 5)
-  sample_distribution(alpha)
-
-  # univariate continuous
-  sample_distribution(normal(-2, 3))
-  sample_distribution(student(5.6, -2, 2.3))
-  sample_distribution(laplace(-1.2, 1.1))
-  sample_distribution(cauchy(-1.2, 1.1))
-  sample_distribution(logistic(-1.2, 1.1))
-
-  sample_distribution(lognormal(1.2, 0.2), lower = 0)
-  sample_distribution(gamma(0.9, 1.3), lower = 0)
-  sample_distribution(exponential(6.3), lower = 0)
-  sample_distribution(beta(6.3, 5.9), lower = 0, upper = 1)
-  sample_distribution(inverse_gamma(0.9, 1.3), lower = 0)
-  sample_distribution(weibull(2, 1.1), lower = 0)
-  sample_distribution(pareto(2.4, 0.1), lower = 0.1)
-  sample_distribution(chi_squared(4.3), lower = 0)
-  sample_distribution(f(24.3, 2.4), lower = 0)
-
-  sample_distribution(uniform(-13, 2.4), lower = -13, upper = 2.4)
-
-  # multivariate continuous
-  sig <- rWishart(1, 4, diag(3))[, , 1]
-  sample_distribution(multivariate_normal(t(rnorm(3)), sig))
-  sample_distribution(wishart(10L, Sig = diag(2)), warmup = 0)
-  sample_distribution(lkj_correlation(4, dimension = 3))
-  sample_distribution(dirichlet(t(runif(3))))
-})
-
-test_that("uniform distribution errors informatively", {
-  skip_if_not(check_tf_version())
-  skip_on_ci()
-
-
-  # bad types
-  expect_snapshot_error(
-    uniform(min = 0, max = NA)
-  )
-
-  expect_snapshot_error(
-    uniform(min = 0, max = head)
-  )
-
-  expect_snapshot_error(
-    uniform(min = 1:3, max = 5)
-  )
-
-  # good types, bad values
-  expect_snapshot_error(
-    uniform(min = -Inf, max = Inf)
-  )
-
-  # lower not below upper
-  expect_snapshot_error(
-    uniform(min = 1, max = 1)
-  )
-
-})
-
-test_that("poisson() and binomial() error informatively in glm", {
-  skip_on_ci()
-  skip_if_not(check_tf_version())
-
-  # if passed as an object
-  expect_snapshot_error(
-    glm(1 ~ 1, family = poisson)
-  )
-
-  expect_snapshot_error(
-    glm(1 ~ 1, family = binomial)
-  )
-
-  # if executed alone
-  expect_snapshot_error(
-    glm(1 ~ 1, family = poisson())
-  )
-
-  # if given a link
-  expect_snapshot_error(
-    glm(1 ~ 1, family = poisson("sqrt"))
-  )
-})
-
-test_that("wishart distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  a <- randn(3, 3)
-  b <- randn(3, 3, 3)
-  c <- randn(3, 2)
-
-  expect_true(inherits(
-    wishart(3, a),
-    "greta_array"
-  ))
-
-  expect_snapshot_error(
-    wishart(3, b)
-  )
-
-  expect_snapshot_error(
-    wishart(3, c)
-  )
-
-})
-
-
-test_that("lkj_correlation distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  dim <- 3
-
-  expect_true(inherits(
-    lkj_correlation(3, dim),
-    "greta_array"
-  ))
-
-  expect_snapshot_error(
-    lkj_correlation(-1, dim)
-  )
-
-  expect_snapshot_error(
-    lkj_correlation(c(3, 3), dim)
-  )
-
-  expect_snapshot_error(
-    lkj_correlation(uniform(0, 1, dim = 2), dim)
-  )
-
-  expect_snapshot_error(
-    lkj_correlation(4, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    lkj_correlation(4, dim = c(3, 3))
-  )
-
-  expect_snapshot_error(
-    lkj_correlation(4, dim = NA)
-  )
-})
-
-test_that("multivariate_normal distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  m_a <- randn(1, 3)
-  m_b <- randn(2, 3)
-  m_c <- randn(3)
-  m_d <- randn(3, 1)
-
-  a <- randn(3, 3)
-  b <- randn(3, 3, 3)
-  c <- randn(3, 2)
-  d <- randn(4, 4)
-
-  # good means
-  expect_true(inherits(
-    multivariate_normal(m_a, a),
-    "greta_array"
-  ))
-
-  expect_true(inherits(
-    multivariate_normal(m_b, a),
-    "greta_array"
-  ))
-
-  # bad means
-  expect_snapshot_error(
-    multivariate_normal(m_c, a)
-  )
-
-  expect_snapshot_error(
-    multivariate_normal(m_d, a)
-  )
-
-  # good sigmas
-  expect_true(inherits(
-    multivariate_normal(m_a, a),
-    "greta_array"
-  ))
-
-  # bad sigmas
-  expect_snapshot_error(
-    multivariate_normal(m_a, b)
-  )
-
-  expect_snapshot_error(
-    multivariate_normal(m_a, c)
-  )
-
-  # mismatched parameters
-  expect_snapshot_error(
-    multivariate_normal(m_a, d)
-  )
-
-  # scalars
-  expect_snapshot_error(
-    multivariate_normal(0, 1)
-  )
-
-  # bad n_realisations
-  expect_snapshot_error(
-    multivariate_normal(m_a, a, n_realisations = -1)
-  )
-
-  expect_snapshot_error(
-    multivariate_normal(m_a, a, n_realisations = c(1, 3))
-  )
-
-  # bad dimension
-  expect_snapshot_error(
-    multivariate_normal(m_a, a, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    multivariate_normal(m_a, a, dimension = c(1, 3))
-  )
-})
-
-test_that("multinomial distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  p_a <- randu(1, 3)
-  p_b <- randu(2, 3)
-
-  # same size & probs
-  expect_true(inherits(
-    multinomial(size = 10, p_a),
-    "greta_array"
-  ))
-
-  expect_true(inherits(
-    multinomial(size = 1:2, p_b),
-    "greta_array"
-  ))
-
-  # n_realisations from prob
-  expect_true(inherits(
-    multinomial(10, p_b),
-    "greta_array"
-  ))
-
-  # n_realisations from size
-  expect_true(inherits(
-    multinomial(c(1, 2), p_a),
-    "greta_array"
-  ))
-
-  # scalars
-  expect_snapshot_error(
-    multinomial(c(1), 1)
-  )
-
-  # bad n_realisations
-  expect_snapshot_error(
-    multinomial(10, p_a, n_realisations = -1)
-  )
-
-  expect_snapshot_error(
-    multinomial(10, p_a, n_realisations = c(1, 3))
-  )
-
-  # bad dimension
-  expect_snapshot_error(
-    multinomial(10, p_a, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    multinomial(10, p_a, dimension = c(1, 3))
-  )
-})
-
-test_that("categorical distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  p_a <- randu(1, 3)
-  p_b <- randu(2, 3)
-
-  # good probs
-  expect_true(inherits(
-    categorical(p_a),
-    "greta_array"
-  ))
-
-  expect_true(inherits(
-    categorical(p_b),
-    "greta_array"
-  ))
-
-  # scalars
-  expect_snapshot_error(
-    categorical(1),
-  )
-
-  # bad n_realisations
-  expect_snapshot_error(
-    categorical(p_a, n_realisations = -1)
-  )
-
-  expect_snapshot_error(
-    categorical(p_a, n_realisations = c(1, 3))
-  )
-
-  # bad dimension
-  expect_snapshot_error(
-    categorical(p_a, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    categorical(p_a, dimension = c(1, 3))
-  )
-})
-
-test_that("dirichlet distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  alpha_a <- randu(1, 3)
-  alpha_b <- randu(2, 3)
-
-  # good alpha
-  expect_true(inherits(
-    dirichlet(alpha_a),
-    "greta_array"
-  ))
-
-
-  expect_true(inherits(
-    dirichlet(alpha_b),
-    "greta_array"
-  ))
-
-  # scalars
-  expect_snapshot_error(
-    dirichlet(1),
-  )
-
-  # bad n_realisations
-  expect_snapshot_error(
-    dirichlet(alpha_a, n_realisations = -1)
-  )
-
-  expect_snapshot_error(
-    dirichlet(alpha_a, n_realisations = c(1, 3))
-  )
-
-  # bad dimension
-  expect_snapshot_error(
-    dirichlet(alpha_a, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    dirichlet(alpha_a, dimension = c(1, 3))
-  )
-})
-
-
-test_that("dirichlet values sum to one", {
-  skip_if_not(check_tf_version())
-
-
-  alpha <- uniform(0, 10, dim = c(1, 5))
-  x <- dirichlet(alpha)
-  m <- model(x)
-  draws <- mcmc(m, n_samples = 100, warmup = 100, verbose = FALSE)
-
-  sums <- rowSums(as.matrix(draws))
-  compare_op(sums, 1)
-})
-
-test_that("dirichlet-multinomial distribution errors informatively", {
-  skip_if_not(check_tf_version())
-
-
-  alpha_a <- randu(1, 3)
-  alpha_b <- randu(2, 3)
-
-
-  # same size & probs
-  expect_true(inherits(
-    dirichlet_multinomial(size = 10, alpha_a),
-    "greta_array"
-  ))
-
-  expect_true(inherits(
-    dirichlet_multinomial(size = 1:2, alpha_b),
-    "greta_array"
-  ))
-
-  # n_realisations from alpha
-  expect_true(inherits(
-    dirichlet_multinomial(10, alpha_b),
-    "greta_array"
-  ))
-
-  # n_realisations from size
-  expect_true(inherits(
-    dirichlet_multinomial(c(1, 2), alpha_a),
-    "greta_array"
-  ))
-
-  # scalars
-  expect_snapshot_error(
-    dirichlet_multinomial(c(1), 1)
-  )
-
-  # bad n_realisations
-  expect_snapshot_error(
-    dirichlet_multinomial(10, alpha_a, n_realisations = -1)
-  )
-
-  expect_snapshot_error(
-    dirichlet_multinomial(10, alpha_a, n_realisations = c(1, 3))
-  )
-
-  # bad dimension
-  expect_snapshot_error(
-    dirichlet_multinomial(10, alpha_a, dimension = -1)
-  )
-
-  expect_snapshot_error(
-    dirichlet_multinomial(10, alpha_a, dimension = c(1, 3))
-  )
-})
-
-test_that("Wishart can use a choleskied Sigma", {
-  skip_if_not(check_tf_version())
-
-  sig <- lkj_correlation(2, dim = 2)
-  w <- wishart(5, sig)
-  m <- model(w, precision = "double")
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-})
-
-test_that("multivariate distribs with matrix params can be sampled from", {
-  skip_if_not(check_tf_version())
-
-  n <- 10
-  k <- 3
-
-  # multivariate normal
-  x <- randn(n, k)
-  mu <- normal(0, 1, dim = c(n, k))
-  distribution(x) <- multivariate_normal(mu, diag(k))
-  m <- model(mu)
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-
-  # multinomial
-  size <- 5
-  x <- t(rmultinom(n, size, runif(k)))
-  p <- uniform(0, 1, dim = c(n, k))
-  distribution(x) <- multinomial(size, p)
-  m <- model(p)
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-
-  # categorical
-  x <- t(rmultinom(n, 1, runif(k)))
-  p <- uniform(0, 1, dim = c(n, k))
-  distribution(x) <- categorical(p)
-  m <- model(p)
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-
-  # dirichlet
-  x <- randu(n, k)
-  x <- sweep(x, 1, rowSums(x), "/")
-  a <- normal(0, 1, dim = c(n, k))
-  distribution(x) <- dirichlet(a)
-  m <- model(a)
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-
-  # dirichlet multinomial
-  size <- 5
-  x <- t(rmultinom(n, size, runif(k)))
-  a <- normal(0, 1, dim = c(n, k))
-  distribution(x) <- dirichlet_multinomial(size, a)
-  m <- model(a)
-  expect_ok(draws <- mcmc(m, warmup = 0, n_samples = 5, verbose = FALSE))
-})
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
new file mode 100644
index 0000000..0699f3a
--- /dev/null
+++ b/tests/testthat/test_zip_zinb.R
@@ -0,0 +1,23 @@
+test_that("zero inflated poisson distribution has correct density", {
+
+  skip_if_not(check_tf_version())
+  source("helpers.R")
+
+  # wrap dzip so both densities take the same (theta, lambda) arguments
+  compare_distribution(zero_inflated_poisson,
+                       function(x, theta, lambda) {
+                         extraDistr::dzip(x, lambda = lambda, pi = theta)
+                       },
+                       parameters = list(theta = 0.2, lambda = 2),
+                       x = sample_zero_inflated_pois(100, 2, 0.2))
+
+})
+
+test_that("zero inflated negative binomial distribution has correct density", {
+
+  skip_if_not(check_tf_version())
+  source("helpers.R")
+
+  # theta is the zero-inflation probability (0.2, not 2); wrap dzinb so both
+  # densities take the same (theta, size, prob) arguments
+  compare_distribution(zero_inflated_negative_binomial,
+                       function(x, theta, size, prob) {
+                         extraDistr::dzinb(x, size = size, prob = prob, pi = theta)
+                       },
+                       parameters = list(theta = 0.2, size = 10, prob = 0.1),
+                       x = extraDistr::rzinb(100, 10, 0.1, 0.2))
+
+})
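A quick standalone sanity check on the extraDistr benchmark itself (a sketch, assuming extraDistr is installed; not part of the test suite) confirms that dzip's (lambda, pi) arguments line up with the (theta, lambda) convention used above:

x <- 0:10
manual <- 0.2 * (x == 0) + (1 - 0.2) * dpois(x, lambda = 2)
stopifnot(all.equal(manual, extraDistr::dzip(x, lambda = 2, pi = 0.2)))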

From 89eb6a2e2cc66a9a3226b1e7ed0a5cd8e9aa3f9b Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Fri, 29 Jul 2022 22:45:34 +0800
Subject: [PATCH 03/19] explore using distributional for distribution sampling

---
 NAMESPACE                              |   3 +
 R/zero_inflated_negative_binomial.R    |  65 ++++
 R/zero_inflated_poisson.R              |  71 ++++
 R/zero_inflateds.R                     | 104 ------
 man/distributions.Rd                   | 254 --------------
 man/zero_inflated_negative_binomial.Rd |  20 ++
 man/zero_inflated_poisson.Rd           |  18 +
 tests/testthat/helpers.R               | 462 +++++++++++++------------
 tests/testthat/test_zip_zinb.R         |   5 +-
 9 files changed, 428 insertions(+), 574 deletions(-)
 create mode 100644 R/zero_inflated_negative_binomial.R
 create mode 100644 R/zero_inflated_poisson.R
 delete mode 100644 R/zero_inflateds.R
 delete mode 100644 man/distributions.Rd
 create mode 100644 man/zero_inflated_negative_binomial.Rd
 create mode 100644 man/zero_inflated_poisson.Rd

diff --git a/NAMESPACE b/NAMESPACE
index edd1cf8..1a1e9c2 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -1,4 +1,7 @@
 # Generated by roxygen2: do not edit by hand
 
+export(zero_inflated_negative_binomial)
+export(zero_inflated_poisson)
+importFrom(R6,R6Class)
 importFrom(greta,.internals)
 importFrom(tensorflow,tf)
diff --git a/R/zero_inflated_negative_binomial.R b/R/zero_inflated_negative_binomial.R
new file mode 100644
index 0000000..88bd374
--- /dev/null
+++ b/R/zero_inflated_negative_binomial.R
@@ -0,0 +1,65 @@
+#' @name zero_inflated_negative_binomial
+#' @title Zero Inflated Negative Binomial
+#' @description A Zero Inflated Negative Binomial distribution
+#' @param theta proportion of structural (excess) zeros (`0 <= theta <= 1`)
+#' @param size positive integer parameter
+#' @param prob probability parameter (`0 < prob < 1`)
+#' @param dim a scalar giving the number of rows in the resulting greta array
+#' @export
+zero_inflated_negative_binomial <-
+  function(theta, size, prob, dim = NULL) {
+    distrib("zero_inflated_negative_binomial", theta, size, prob, dim)
+  }
+
+zero_inflated_negative_binomial_distribution <- R6::R6Class(
+  "zero_inflated_negative_binomial_distribution",
+  inherit = greta::.internals$nodes$node_classes$distribution_node,
+  public = list(
+    initialize = function(theta, size, prob, dim) {
+      theta <- as.greta_array(theta)
+      size <- as.greta_array(size)
+      prob <- as.greta_array(prob)
+      # add the nodes as children and parameters
+      dim <- check_dims(theta, size, prob, target_dim = dim)
+      super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(size, "size")
+      self$add_parameter(prob, "prob")
+    },
+    
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      size <- parameters$size
+      p <- parameters$prob # probability of success
+      q <- fl(1) - parameters$prob
+      log_prob <- function(x) {
+        # relu(1 - x) equals 1 at x == 0 and 0 for x >= 1, so the first
+        # term is the extra mass theta on zero
+        tf$math$log(
+          theta *
+            tf$nn$relu(fl(1) - x) +
+            (fl(1) - theta) *
+            tf$pow(p, size) *
+            tf$pow(q, x) *
+            tf$exp(tf$math$lgamma(x + size)) /
+            tf$exp(tf$math$lgamma(size)) /
+            tf$exp(tf$math$lgamma(x + fl(1)))
+        )
+        
+      }
+      
+      sample <- function(seed) {
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        # TFP's probs corresponds to 1 - prob under the 'stats'
+        # parameterisation, hence probs = q
+        negbin <-
+          tfp$distributions$NegativeBinomial(total_count = size, probs = q)
+        
+        zi <- binom$sample(seed = seed)
+        lbd <- negbin$sample(seed = seed)
+        
+        (fl(1) - zi) * lbd
+        
+      }
+      
+      list(
+        log_prob = log_prob,
+        sample = sample,
+        cdf = NULL,
+        log_cdf = NULL
+      )
+    },
+    
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
\ No newline at end of file
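For orientation, the log_prob above encodes the usual zero-inflated negative binomial pmf, P(x) = theta * 1(x == 0) + (1 - theta) * NB(x; size, prob). A minimal base-R sketch of the same density (illustrative only; dzinb_ref is a hypothetical name, and the 'stats' success-probability parameterisation is assumed):

# point mass at zero mixed with a negative binomial
dzinb_ref <- function(x, theta, size, prob, log = FALSE) {
  dens <- theta * (x == 0) +
    (1 - theta) * stats::dnbinom(x, size = size, prob = prob)
  if (log) log(dens) else dens
}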
diff --git a/R/zero_inflated_poisson.R b/R/zero_inflated_poisson.R
new file mode 100644
index 0000000..1853406
--- /dev/null
+++ b/R/zero_inflated_poisson.R
@@ -0,0 +1,71 @@
+#' @name zero_inflated_poisson
+#' @title Zero Inflated Poisson distribution
+#'
+#' @description A zero-inflated Poisson distribution.
+#'
+#' @param theta proportion of structural (excess) zeros (`0 <= theta <= 1`)
+#' @param lambda rate parameter
+#' @param dim a scalar giving the number of rows in the resulting greta array
+#' @importFrom R6 R6Class
+#' @export
+zero_inflated_poisson <- function(theta, lambda, dim = NULL) {
+  distrib("zero_inflated_poisson", theta, lambda, dim)
+}
+
+#' @importFrom R6 R6Class
+zero_inflated_poisson_distribution <- R6::R6Class(
+  classname = "zero_inflated_poisson_distribution",
+  inherit = distribution_node,
+  public = list(
+    initialize = function(theta, lambda, dim) {
+      theta <- as.greta_array(theta)
+      lambda <- as.greta_array(lambda)
+      # add the nodes as children and parameters
+      dim <- check_dims(theta, lambda, target_dim = dim)
+      super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
+      self$add_parameter(theta, "theta")
+      self$add_parameter(lambda, "lambda")
+    },
+    
+    tf_distrib = function(parameters, dag) {
+      theta <- parameters$theta
+      lambda <- parameters$lambda
+      log_prob <- function(x) {
+        # relu(1 - x) equals 1 at x == 0 and 0 for x >= 1, so the first
+        # term is the extra mass theta on zero
+        tf$math$log(
+          theta * 
+            tf$nn$relu(fl(1) - x) + 
+            (fl(1) - theta) * 
+            tf$pow(lambda, x) * 
+            tf$exp(-lambda) / tf$exp(tf$math$lgamma(x + fl(1)))
+        )
+      }
+      
+      sample <- function(seed) {
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        pois <- tfp$distributions$Poisson(rate = lambda)
+        
+        zi <- binom$sample(seed = seed)
+        lbd <- pois$sample(seed = seed)
+        
+        (fl(1) - zi) * lbd
+        
+      }
+      
+      list(
+        log_prob = log_prob,
+        sample = sample,
+        cdf = NULL,
+        log_cdf = NULL
+      )
+    },
+    
+    tf_cdf_function = NULL,
+    tf_log_cdf_function = NULL
+  )
+)
+# NOTE - not sure what to do here with the module stuff?
+# distribution_classes_module <-
+#   module(
+#     zero_inflated_poisson_distribution,
+#     zero_inflated_negative_binomial_distribution
+#   )
\ No newline at end of file
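For orientation, the density and sampler above correspond to this base-R sketch (illustrative only; the _ref names are hypothetical): the pmf mixes a point mass at zero with a Poisson, and sampling multiplies a Poisson draw by a Bernoulli zero mask, exactly as sample() does with the TFP distributions:

# P(x) = theta * 1(x == 0) + (1 - theta) * Poisson(x; lambda)
dzip_ref <- function(x, theta, lambda, log = FALSE) {
  dens <- theta * (x == 0) + (1 - theta) * stats::dpois(x, lambda)
  if (log) log(dens) else dens
}

# mask-based sampler mirroring tf_distrib()$sample
rzip_ref <- function(n, theta, lambda) {
  zi <- stats::rbinom(n, size = 1, prob = theta)
  (1 - zi) * stats::rpois(n, lambda)
}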
diff --git a/R/zero_inflateds.R b/R/zero_inflateds.R
deleted file mode 100644
index d31d5ba..0000000
--- a/R/zero_inflateds.R
+++ /dev/null
@@ -1,104 +0,0 @@
-zero_inflated_poisson_distribution <- R6Class(
-  "zero_inflated_poisson_distribution",
-  inherit = greta::.internals$nodes$node_classes$distribution_node,
-  public = list(
-    initialize = function(theta, lambda, dim) {
-      theta <- as.greta_array(theta)
-      lambda <- as.greta_array(lambda)
-      # add the nodes as children and parameters
-      dim <- check_dims(theta, lambda, target_dim = dim)
-      super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
-      self$add_parameter(lambda, "lambda")
-    },
-    
-    tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
-      lambda <- parameters$lambda
-      log_prob <- function(x) {
-        
-        tf$math$log(theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * tf$pow(lambda, x) * tf$exp(-lambda) / tf$exp(tf$math$lgamma(x + fl(1))))
-      }
-      
-      sample <- function(seed) {
-        
-        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
-        pois <- tfp$distributions$Poisson(rate = lambda)
-        
-        zi <- binom$sample(seed = seed)
-        lbd <- pois$sample(seed = seed)
-        
-        (fl(1) - zi) * lbd
-        
-      }
-      
-      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
-    },
-    
-    tf_cdf_function = NULL,
-    tf_log_cdf_function = NULL
-  )
-)
-
-
-zero_inflated_negative_binomial_distribution <- R6Class(
-  "zero_inflated_negative_binomial_distribution",
-  inherit = greta::.internals$nodes$node_classes$distribution_node,
-  public = list(
-    initialize = function(theta, size, prob, dim) {
-      theta <- as.greta_array(theta)
-      size <- as.greta_array(size)
-      prob <- as.greta_array(prob)
-      # add the nodes as children and parameters
-      dim <- check_dims(theta, size, prob, target_dim = dim)
-      super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
-      self$add_parameter(size, "size")
-      self$add_parameter(prob, "prob")
-    },
-    
-    tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
-      size <- parameters$size
-      p <- parameters$prob # probability of success
-      q <- fl(1) - parameters$prob 
-      log_prob <- function(x) {
-        
-        tf$math$log(theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1))))
-        
-      }
-      
-      sample <- function(seed) {
-        
-        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
-        negbin <- tfp$distributions$NegativeBinomial(total_count = size, probs = q) # change of proba / parametrisation in 'stats'
-        
-        zi <- binom$sample(seed = seed)
-        lbd <- negbin$sample(seed = seed)
-        
-        (fl(1) - zi) * lbd
-        
-      }
-      
-      list(log_prob = log_prob, sample = sample, cdf = NULL, log_cdf = NULL)
-    },
-    
-    tf_cdf_function = NULL,
-    tf_log_cdf_function = NULL
-  )
-)
-
-#' @rdname distributions
-#' @export
-zero_inflated_poisson <- function (theta, lambda, dim = NULL) {
-  distrib('zero_inflated_poisson', theta, lambda, dim)
-}
-
-#' @rdname distributions
-#' @export
-zero_inflated_negative_binomial <- function (theta, size, prob, dim = NULL) {
-  distrib('zero_inflated_negative_binomial', theta, size, prob, dim)
-}
-
-distribution_classes_module <- module(zero_inflated_poisson_distribution,
-                                      zero_inflated_negative_binomial_distribution)
\ No newline at end of file
diff --git a/man/distributions.Rd b/man/distributions.Rd
deleted file mode 100644
index c133420..0000000
--- a/man/distributions.Rd
+++ /dev/null
@@ -1,254 +0,0 @@
-% Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/probability_distributions.R
-\name{distributions}
-\alias{distributions}
-\alias{uniform}
-\alias{normal}
-\alias{lognormal}
-\alias{bernoulli}
-\alias{binomial}
-\alias{beta_binomial}
-\alias{negative_binomial}
-\alias{hypergeometric}
-\alias{poisson}
-\alias{zero_inflated_poisson}
-\alias{zero_inflated_negative_binomial}
-\alias{gamma}
-\alias{inverse_gamma}
-\alias{weibull}
-\alias{exponential}
-\alias{pareto}
-\alias{student}
-\alias{laplace}
-\alias{beta}
-\alias{cauchy}
-\alias{chi_squared}
-\alias{logistic}
-\alias{f}
-\alias{multivariate_normal}
-\alias{wishart}
-\alias{lkj_correlation}
-\alias{multinomial}
-\alias{categorical}
-\alias{dirichlet}
-\alias{dirichlet_multinomial}
-\title{probability distributions}
-\usage{
-uniform(min, max, dim = NULL)
-
-normal(mean, sd, dim = NULL, truncation = c(-Inf, Inf))
-
-lognormal(meanlog, sdlog, dim = NULL, truncation = c(0, Inf))
-
-bernoulli(prob, dim = NULL)
-
-binomial(size, prob, dim = NULL)
-
-beta_binomial(size, alpha, beta, dim = NULL)
-
-negative_binomial(size, prob, dim = NULL)
-
-hypergeometric(m, n, k, dim = NULL)
-
-poisson(lambda, dim = NULL)
-
-zero_inflated_poisson(theta, lambda, dim = NULL)
-
-zero_inflated_negative_binomial(theta, size, prob, dim = NULL)
-
-gamma(shape, rate, dim = NULL, truncation = c(0, Inf))
-
-inverse_gamma(alpha, beta, dim = NULL, truncation = c(0, Inf))
-
-weibull(shape, scale, dim = NULL, truncation = c(0, Inf))
-
-exponential(rate, dim = NULL, truncation = c(0, Inf))
-
-pareto(a, b, dim = NULL, truncation = c(0, Inf))
-
-student(df, mu, sigma, dim = NULL, truncation = c(-Inf, Inf))
-
-laplace(mu, sigma, dim = NULL, truncation = c(-Inf, Inf))
-
-beta(shape1, shape2, dim = NULL, truncation = c(0, 1))
-
-cauchy(location, scale, dim = NULL, truncation = c(-Inf, Inf))
-
-chi_squared(df, dim = NULL, truncation = c(0, Inf))
-
-logistic(location, scale, dim = NULL, truncation = c(-Inf, Inf))
-
-f(df1, df2, dim = NULL, truncation = c(0, Inf))
-
-multivariate_normal(mean, Sigma, n_realisations = NULL, dimension = NULL)
-
-wishart(df, Sigma)
-
-lkj_correlation(eta, dimension = 2)
-
-multinomial(size, prob, n_realisations = NULL, dimension = NULL)
-
-categorical(prob, n_realisations = NULL, dimension = NULL)
-
-dirichlet(alpha, n_realisations = NULL, dimension = NULL)
-
-dirichlet_multinomial(size, alpha, n_realisations = NULL, dimension = NULL)
-}
-\arguments{
-\item{min, max}{scalar values giving optional limits to \code{uniform}
-variables. Like \code{lower} and \code{upper}, these must be specified as
-numerics, they cannot be greta arrays (though see details for a
-workaround). Unlike \code{lower} and \code{upper}, they must be finite.
-\code{min} must always be less than \code{max}.}
-
-\item{dim}{the dimensions of the greta array to be returned, either a scalar
-or a vector of positive integers. See details.}
-
-\item{mean, meanlog, location, mu}{unconstrained parameters}
-
-\item{sd, sdlog, sigma, lambda, shape, rate, df, scale, shape1, shape2, alpha, beta, df1, df2, a, b, eta}{positive parameters, \code{alpha} must be a vector for \code{dirichlet}
-and \code{dirichlet_multinomial}.}
-
-\item{truncation}{a length-two vector giving values between which to truncate
-the distribution, similarly to the \code{lower} and \code{upper} arguments
-to \code{\link[=variable]{variable()}}}
-
-\item{prob}{probability parameter (\verb{0 < prob < 1}), must be a vector for
-\code{multinomial} and \code{categorical}}
-
-\item{size, m, n, k}{positive integer parameter}
-
-\item{Sigma}{positive definite variance-covariance matrix parameter}
-
-\item{n_realisations}{the number of independent realisation of a multivariate
-distribution}
-
-\item{dimension}{the dimension of a multivariate distribution}
-}
-\description{
-These functions can be used to define random variables in a
-greta model. They return a variable greta array that follows the specified
-distribution. This variable greta array can be used to represent a
-parameter with prior distribution, combined into a mixture distribution
-using \code{\link[=mixture]{mixture()}}, or used with \code{\link[=distribution]{distribution()}} to
-define a distribution over a data greta array.
-}
-\details{
-The discrete probability distributions (\code{bernoulli},
-\code{binomial}, \code{negative_binomial}, \code{poisson},
-\code{multinomial}, \code{categorical}, \code{dirichlet_multinomial}) can
-be used when they have fixed values (e.g. defined as a likelihood using
-\code{\link[=distribution]{distribution()}}, but not as unknown variables.
-
-For univariate distributions \code{dim} gives the dimensions of the greta
-array to create. Each element of the greta array will be (independently)
-distributed according to the distribution. \code{dim} can also be left at
-its default of \code{NULL}, in which case the dimension will be detected
-from the dimensions of the parameters (provided they are compatible with
-one another).
-
-For multivariate distributions (\code{multivariate_normal()},
-\code{multinomial()}, \code{categorical()}, \code{dirichlet()}, and
-\code{dirichlet_multinomial()}) each row of the output and parameters
-corresponds to an independent realisation. If a single realisation or
-parameter value is specified, it must therefore be a row vector (see
-example). \code{n_realisations} gives the number of rows/realisations, and
-\code{dimension} gives the dimension of the distribution. I.e. a bivariate
-normal distribution would be produced with \code{multivariate_normal(..., dimension = 2)}. The dimension can usually be detected from the parameters.
-
-\code{multinomial()} does not check that observed values sum to
-\code{size}, and \code{categorical()} does not check that only one of the
-observed entries is 1. It's the user's responsibility to check their data
-matches the distribution!
-
-The parameters of \code{uniform} must be fixed, not greta arrays. This
-ensures these values can always be transformed to a continuous scale to run
-the samplers efficiently. However, a hierarchical \code{uniform} parameter
-can always be created by defining a \code{uniform} variable constrained
-between 0 and 1, and then transforming it to the required scale. See below
-for an example.
-
-Wherever possible, the parameterisations and argument names of greta
-distributions match commonly used R functions for distributions, such as
-those in the \code{stats} or \code{extraDistr} packages. The following
-table states the distribution function to which greta's implementation
-corresponds:
-
-\tabular{ll}{ greta \tab reference\cr \code{uniform} \tab
-\link[stats:Uniform]{stats::dunif}\cr \code{normal} \tab
-\link[stats:Normal]{stats::dnorm}\cr \code{lognormal} \tab
-\link[stats:Lognormal]{stats::dlnorm}\cr \code{bernoulli} \tab
-\link[extraDistr:Bernoulli]{extraDistr::dbern}\cr \code{binomial} \tab
-\link[stats:Binomial]{stats::dbinom}\cr \code{beta_binomial} \tab
-\link[extraDistr:BetaBinom]{extraDistr::dbbinom}\cr \code{negative_binomial}
-\tab \link[stats:NegBinomial]{stats::dnbinom}\cr \code{hypergeometric} \tab
-\link[stats:Hypergeometric]{stats::dhyper}\cr \code{poisson} \tab
-\link[stats:Poisson]{stats::dpois}\cr \code{gamma} \tab
-\link[stats:GammaDist]{stats::dgamma}\cr \code{inverse_gamma} \tab
-\link[extraDistr:InvGamma]{extraDistr::dinvgamma}\cr \code{weibull} \tab
-\link[stats:Weibull]{stats::dweibull}\cr \code{exponential} \tab
-\link[stats:Exponential]{stats::dexp}\cr \code{pareto} \tab
-\link[extraDistr:Pareto]{extraDistr::dpareto}\cr \code{student} \tab
-\link[extraDistr:LocationScaleT]{extraDistr::dlst}\cr \code{laplace} \tab
-\link[extraDistr:Laplace]{extraDistr::dlaplace}\cr \code{beta} \tab
-\link[stats:Beta]{stats::dbeta}\cr \code{cauchy} \tab
-\link[stats:Cauchy]{stats::dcauchy}\cr \code{chi_squared} \tab
-\link[stats:Chisquare]{stats::dchisq}\cr \code{logistic} \tab
-\link[stats:Logistic]{stats::dlogis}\cr \code{f} \tab
-\link[stats:Fdist]{stats::df}\cr \code{multivariate_normal} \tab
-\link[mvtnorm:Mvnorm]{mvtnorm::dmvnorm}\cr \code{multinomial} \tab
-\link[stats:Multinom]{stats::dmultinom}\cr \code{categorical} \tab
-{\link[stats:Multinom]{stats::dmultinom} (size = 1)}\cr \code{dirichlet}
-\tab \link[extraDistr:Dirichlet]{extraDistr::ddirichlet}\cr
-\code{dirichlet_multinomial} \tab
-\link[extraDistr:DirMnom]{extraDistr::ddirmnom}\cr \code{wishart} \tab
-\link[stats:rWishart]{stats::rWishart}\cr \code{lkj_correlation} \tab
-\href{https://rdrr.io/github/rmcelreath/rethinking/man/dlkjcorr.html}{rethinking::dlkjcorr}
-}
-}
-\examples{
-\dontrun{
-
-# a uniform parameter constrained to be between 0 and 1
-phi <- uniform(min = 0, max = 1)
-
-# a length-three variable, with each element following a standard normal
-# distribution
-alpha <- normal(0, 1, dim = 3)
-
-# a length-three variable of lognormals
-sigma <- lognormal(0, 3, dim = 3)
-
-# a hierarchical uniform, constrained between alpha and alpha + sigma,
-eta <- alpha + uniform(0, 1, dim = 3) * sigma
-
-# a hierarchical distribution
-mu <- normal(0, 1)
-sigma <- lognormal(0, 1)
-theta <- normal(mu, sigma)
-
-# a vector of 3 variables drawn from the same hierarchical distribution
-thetas <- normal(mu, sigma, dim = 3)
-
-# a matrix of 12 variables drawn from the same hierarchical distribution
-thetas <- normal(mu, sigma, dim = c(3, 4))
-
-# a multivariate normal variable, with correlation between two elements
-# note that the parameter must be a row vector
-Sig <- diag(4)
-Sig[3, 4] <- Sig[4, 3] <- 0.6
-theta <- multivariate_normal(t(rep(mu, 4)), Sig)
-
-# 10 independent replicates of that
-theta <- multivariate_normal(t(rep(mu, 4)), Sig, n_realisations = 10)
-
-# 10 multivariate normal replicates, each with a different mean vector,
-# but the same covariance matrix
-means <- matrix(rnorm(40), 10, 4)
-theta <- multivariate_normal(means, Sig, n_realisations = 10)
-dim(theta)
-
-# a Wishart variable with the same covariance parameter
-theta <- wishart(df = 5, Sigma = Sig)
-}
-}
diff --git a/man/zero_inflated_negative_binomial.Rd b/man/zero_inflated_negative_binomial.Rd
new file mode 100644
index 0000000..972209a
--- /dev/null
+++ b/man/zero_inflated_negative_binomial.Rd
@@ -0,0 +1,20 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/zero_inflated_negative_binomial.R
+\name{zero_inflated_negative_binomial}
+\alias{zero_inflated_negative_binomial}
+\title{Zero Inflated Negative Binomial}
+\usage{
+zero_inflated_negative_binomial(theta, size, prob, dim = NULL)
+}
+\arguments{
+\item{theta}{proportion of structural (excess) zeros (\verb{0 <= theta <= 1})}
+
+\item{size}{positive integer parameter}
+
+\item{prob}{probability parameter (\verb{0 < prob < 1})}
+
+\item{dim}{a scalar giving the number of rows in the resulting greta array}
+}
+\description{
+A Zero Inflated Negative Binomial distribution
+}
diff --git a/man/zero_inflated_poisson.Rd b/man/zero_inflated_poisson.Rd
new file mode 100644
index 0000000..efa7b47
--- /dev/null
+++ b/man/zero_inflated_poisson.Rd
@@ -0,0 +1,18 @@
+% Generated by roxygen2: do not edit by hand
+% Please edit documentation in R/zero_inflated_poisson.R
+\name{zero_inflated_poisson}
+\alias{zero_inflated_poisson}
+\title{Zero Inflated Poisson distribution}
+\usage{
+zero_inflated_poisson(theta, lambda, dim = NULL)
+}
+\arguments{
+\item{theta}{proportion of structural (excess) zeros (\verb{0 <= theta <= 1})}
+
+\item{lambda}{rate parameter}
+
+\item{dim}{a scalar giving the number of rows in the resulting greta array}
+}
+\description{
+A zero-inflated Poisson distribution.
+}
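A sketch of how the constructor might be used as a likelihood in a greta model (illustrative values only; assumes greta and this package are attached):

# toy zero-inflated counts: ~20% structural zeros over a Poisson(2)
y <- as_data(rbinom(50, 1, 0.8) * rpois(50, 2))
theta <- uniform(0, 1)
lambda <- lognormal(0, 1)
distribution(y) <- zero_inflated_poisson(theta, lambda, dim = 50)
m <- model(theta, lambda)
# draws <- mcmc(m, n_samples = 100)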
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 6f6b8fb..46f2537 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -16,13 +16,13 @@ grab <- function(x, dag = NULL) {
   if (inherits(x, "node")) {
     x <- as.greta_array(x)
   }
-
+  
   if (inherits(x, "greta_array")) {
     node <- get_node(x)
     dag <- dag_class$new(list(x))
     dag$define_tf()
   }
-
+  
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
   out <- dag$tf_sess_run(dag$tf_name(node), as_text = TRUE)
@@ -44,19 +44,23 @@ set_distribution <- function(dist, data) {
 get_density <- function(distrib, data) {
   x <- as_data(data)
   distribution(x) <- distrib
-
+  
   # create dag and define the density
   dag <- dag_class$new(list(x))
   get_node(x)$distribution$define_tf(dag)
-
+  
   # get the log density as a vector
   tensor_name <- dag$tf_name(get_node(distrib)$distribution)
   tensor <- get(tensor_name, envir = dag$tf_environment)
   as.vector(grab(tensor, dag))
 }
 
-compare_distribution <- function(greta_fun, r_fun, parameters, x,
-                                 dim = NULL, multivariate = FALSE,
+compare_distribution <- function(greta_fun,
+                                 r_fun,
+                                 parameters,
+                                 x,
+                                 dim = NULL,
+                                 multivariate = FALSE,
                                  tolerance = 1e-4) {
   # calculate the absolute difference in the log density of some data between
   # greta and a r benchmark.
@@ -65,65 +69,65 @@ compare_distribution <- function(greta_fun, r_fun, parameters, x,
   # both of these functions must take the same parameters in the same order
   # 'parameters' is an optionally named list of numeric parameter values
   # x is the vector of values at which to evaluate the log density
-
+  
   # define greta distribution, with fixed values
-  greta_log_density <- greta_density(
-    greta_fun, parameters, x,
-    dim, multivariate
-  )
+  greta_log_density <- greta_density(greta_fun, parameters, x,
+                                     dim, multivariate)
   # get R version
   r_log_density <- log(do.call(r_fun, c(list(x), parameters)))
-
+  
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
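Because the identical parameter list is fed to both densities, a distribution whose R benchmark uses different argument names needs a thin wrapper. For example, with the new zero-inflated Poisson (a sketch, assuming extraDistr is available):

dzip_theta <- function(x, theta, lambda) {
  extraDistr::dzip(x, lambda = lambda, pi = theta)
}
compare_distribution(zero_inflated_poisson,
                     dzip_theta,
                     parameters = list(theta = 0.2, lambda = 2),
                     x = extraDistr::rzip(100, lambda = 2, pi = 0.2))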
 
 # evaluate the log density of x, given 'parameters' and a distribution
 # constructor function 'fun'
-greta_density <- function(fun, parameters, x,
-                          dim = NULL, multivariate = FALSE) {
+greta_density <- function(fun,
+                          parameters,
+                          x,
+                          dim = NULL,
+                          multivariate = FALSE) {
   if (is.null(dim)) {
     dim <- NROW(x)
   }
-
+  
   # add the output dimension to the arguments list
   dim_list <- list(dim = dim)
-
+  
   # if it's a multivariate distribution name it n_realisations
   if (multivariate) {
     names(dim_list) <- "n_realisations"
   }
-
+  
   # don't add it for wishart & lkj, which don't mave multiple realisations
   is_wishart <- identical(names(parameters), c("df", "Sigma"))
   is_lkj <- identical(names(parameters), c("eta", "dimension"))
   if (is_wishart | is_lkj) {
     dim_list <- list()
   }
-
+  
   parameters <- c(parameters, dim_list)
-
+  
   # evaluate greta distribution
   dist <- do.call(fun, parameters)
   distrib_node <- get_node(dist)$distribution
-
+  
   # set density
   x_ <- as.greta_array(x)
   distrib_node$remove_target()
   distrib_node$add_target(get_node(x_))
-
+  
   # create dag
   dag <- dag_class$new(list(x_))
   dag$define_tf()
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
-
+  
   # get the log density as a vector
-  dag$on_graph(
-    result <- dag$evaluate_density(distrib_node, get_node(x_))
-  )
+  dag$on_graph(result <-
+                 dag$evaluate_density(distrib_node, get_node(x_)))
   assign("test_density", result, dag$tf_environment)
-
+  
   density <- dag$tf_sess_run(test_density)
   as.vector(density)
 }
@@ -132,44 +136,43 @@ greta_density <- function(fun, parameters, x,
 # arrays, then converting the result back to R. 'swap_scope' tells eval() how
 # many environments to go up to get the objects for the swap; 1 would be
 # environment above the function, 2 would be the environment above that, etc.
-with_greta <- function(call, swap = c("x"), swap_scope = 1) {
+with_greta <- function(call,
+                       swap = c("x"),
+                       swap_scope = 1) {
   swap_entries <- paste0(swap, " = as_data(", swap, ")")
-  swap_text <- paste0(
-    "list(",
-    paste(swap_entries, collapse = ", "),
-    ")"
-  )
+  swap_text <- paste0("list(",
+                      paste(swap_entries, collapse = ", "),
+                      ")")
   swap_list <- eval(parse(text = swap_text),
-    envir = parent.frame(n = swap_scope)
-  )
-
-  greta_result <- with(
-    swap_list,
-    eval(call)
-  )
+                    envir = parent.frame(n = swap_scope))
+  
+  greta_result <- with(swap_list,
+                       eval(call))
   result <- grab(greta_result)
-
+  
   # account for the fact that greta outputs are 1D arrays; convert them back to
   # R vectors
-  if (is.array(result) && length(dim(result)) == 2 && dim(result)[2] == 1) {
+  if (is.array(result) &&
+      length(dim(result)) == 2 && dim(result)[2] == 1) {
     result <- as.vector(result)
   }
-
+  
   result
 }
 
 # check an expression is equivalent when done in R, and when done on greta
 # arrays with results ported back to R
 # e.g. check_expr(a[1:3], swap = 'a')
-check_expr <- function(expr, swap = c("x"), tolerance = 1e-4) {
+check_expr <- function(expr,
+                       swap = c("x"),
+                       tolerance = 1e-4) {
   call <- substitute(expr)
-
+  
   r_out <- eval(expr)
   greta_out <- with_greta(call,
-    swap = swap,
-    swap_scope = 2
-  )
-
+                          swap = swap,
+                          swap_scope = 2)
+  
   compare_op(r_out, greta_out, tolerance)
 }
 
@@ -190,22 +193,27 @@ gen_opfun <- function(n, ops) {
   for (i in seq_len(n)) {
     string <- add_op_string(string, ops = ops)
   }
-
+  
   fun_string <- sprintf("function(a, b) {%s}", string)
-
+  
   eval(parse(text = fun_string))
 }
 
 # sample n values from a distribution by HMC, check they all have the correct
 # support greta array is defined as a stochastic in the call
-sample_distribution <- function(greta_array, n = 10,
-                                lower = -Inf, upper = Inf,
+sample_distribution <- function(greta_array,
+                                n = 10,
+                                lower = -Inf,
+                                upper = Inf,
                                 warmup = 1) {
   m <- model(greta_array, precision = "double")
-  draws <- mcmc(m, n_samples = n, warmup = warmup, verbose = FALSE)
+  draws <- mcmc(m,
+                n_samples = n,
+                warmup = warmup,
+                verbose = FALSE)
   samples <- as.matrix(draws)
   vectorised <- length(lower) > 1 | length(upper) > 1
-
+  
   if (vectorised) {
     above_lower <- sweep(samples, 2, lower, `>=`)
     below_upper <- sweep(samples, 2, upper, `<=`)
@@ -213,7 +221,7 @@ sample_distribution <- function(greta_array, n = 10,
     above_lower <- samples >= lower
     below_upper <- samples <= upper
   }
-
+  
   expect_true(all(above_lower & below_upper))
 }
 
@@ -227,29 +235,27 @@ compare_truncated_distribution <- function(greta_fun,
   # is a greta array created from a distribution and a constrained variable
   # greta array. 'r_fun' is an r function returning the log density for the same
   # truncated distribution, taking x as its only argument.
-
-  x <- do.call(
-    truncdist::rtrunc,
-    c(
-      n = 100,
-      spec = which,
-      a = truncation[1],
-      b = truncation[2],
-      parameters
-    )
-  )
-
+  
+  x <- do.call(truncdist::rtrunc,
+               c(
+                 n = 100,
+                 spec = which,
+                 a = truncation[1],
+                 b = truncation[2],
+                 parameters
+               ))
+  
   # create truncated R function and evaluate it
   r_fun <- truncfun(which, parameters, truncation)
   r_log_density <- log(r_fun(x))
-
+  
   greta_log_density <- greta_density(
     fun = greta_fun,
     parameters = c(parameters, list(truncation = truncation)),
     x = x,
     dim = 1
   )
-
+  
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
@@ -257,13 +263,11 @@ compare_truncated_distribution <- function(greta_fun,
 # use the truncdist package to create a truncated distribution function for use
 # in compare_truncated_distribution
 truncfun <- function(which = "norm", parameters, truncation) {
-  args <- c(
-    spec = which,
-    a = truncation[1],
-    b = truncation[2],
-    parameters
-  )
-
+  args <- c(spec = which,
+            a = truncation[1],
+            b = truncation[2],
+            parameters)
+  
   function(x) {
     arg_list <- c(x = list(x), args)
     do.call(truncdist::dtrunc, arg_list)
@@ -296,7 +300,9 @@ qt_ls <- function(p, df, location, scale, log.p = FALSE) {
 }
 
 # mock up the progress bar to force its output to stdout for testing
-cpb <- eval(parse(text = capture.output(dput(create_progress_bar))))
+cpb <- eval(parse(text = capture.output(dput(
+  create_progress_bar
+))))
 mock_create_progress_bar <- function(...) {
   cpb(..., stream = stdout())
 }
@@ -309,16 +315,18 @@ get_output <- function(expr) {
   i <- 0
   suppressMessages(withCallingHandlers(
     expr,
-    message = function(e) msgs[[i <<- i + 1]] <<- conditionMessage(e)
+    message = function(e)
+      msgs[[i <<- i + 1]] <<- conditionMessage(e)
   ))
   paste0(msgs, collapse = "")
 }
 
 # mock up mcmc progress bar output for neurotic testing
 mock_mcmc <- function(n_samples = 1010) {
-  pb <- create_progress_bar("sampling", c(0, n_samples),
-    pb_update = 10, width = 50
-  )
+  pb <- create_progress_bar("sampling",
+                            c(0, n_samples),
+                            pb_update = 10,
+                            width = 50)
   iterate_progress_bar(pb, n_samples, rejects = 10, chains = 1)
 }
 
@@ -328,15 +336,15 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
   k <- dimension
   stopifnot(is.numeric(k), k >= 2, k == as.integer(k))
   stopifnot(eta > 0)
-
+  
   f <- function() {
     alpha <- eta + (k - 2) / 2
     r12 <- 2 * stats::rbeta(1, alpha, alpha) - 1
     r <- matrix(0, k, k)
     r[1, 1] <- 1
     r[1, 2] <- r12
-    r[2, 2] <- sqrt(1 - r12^2)
-
+    r[2, 2] <- sqrt(1 - r12 ^ 2)
+    
     if (k > 2) {
       for (m in 2:(k - 1)) {
         alpha <- alpha - 0.5
@@ -347,27 +355,29 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
         r[m + 1, m + 1] <- sqrt(1 - y)
       }
     }
-
+    
     crossprod(r)
   }
-
+  
   r <- replicate(n, f())
-
+  
   if (dim(r)[3] == 1) {
     r <- r[, , 1]
   } else {
     r <- aperm(r, c(3, 1, 2))
   }
-
+  
   r
 }
 
 # helper RNG functions
 rmvnorm <- function(n, mean, Sigma) { # nolint
   mvtnorm::rmvnorm(n = n, mean = mean, sigma = Sigma)
 }
 
 rwish <- function(n, df, Sigma) { # nolint
   draws <- stats::rWishart(n = n, df = df, Sigma = Sigma)
   aperm(draws, c(3, 1, 2))
 }
@@ -429,21 +439,28 @@ rtf <- function(n, df1, df2, truncation) {
 # joint testing functions
 joint_normals <- function(...) {
   params_list <- list(...)
-  components <- lapply(params_list, function(par) do.call(normal, par))
+  components <-
+    lapply(params_list, function(par)
+      do.call(normal, par))
   do.call(joint, components)
 }
 
 rjnorm <- function(n, ...) {
   params_list <- list(...)
-  args_list <- lapply(params_list, function(par) c(n, par))
-  sims <- lapply(args_list, function(par) do.call(stats::rnorm, par))
+  args_list <- lapply(params_list, function(par)
+    c(n, par))
+  sims <-
+    lapply(args_list, function(par)
+      do.call(stats::rnorm, par))
   do.call(cbind, sims)
 }
 
 rjtnorm <- function(n, ...) {
   params_list <- list(...)
-  args_list <- lapply(params_list, function(par) c(n, par))
-  sims <- lapply(args_list, function(par) do.call(rtnorm, par))
+  args_list <- lapply(params_list, function(par)
+    c(n, par))
+  sims <- lapply(args_list, function(par)
+    do.call(rtnorm, par))
   do.call(cbind, sims)
 }
 
@@ -452,7 +469,9 @@ mixture_normals <- function(...) {
   args <- list(...)
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
-  components <- lapply(params_list, function(par) do.call(normal, par))
+  components <-
+    lapply(params_list, function(par)
+      do.call(normal, par))
   do.call(mixture, c(components, args[is_weights]))
 }
 
@@ -460,12 +479,10 @@ mixture_multivariate_normals <- function(...) {
   args <- list(...)
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
-  components <- lapply(
-    params_list,
-    function(par) {
-      do.call(multivariate_normal, par)
-    }
-  )
+  components <- lapply(params_list,
+                       function(par) {
+                         do.call(multivariate_normal, par)
+                       })
   do.call(mixture, c(components, args[is_weights]))
 }
 
@@ -474,10 +491,13 @@ rmixnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par) c(n, par))
-  sims <- lapply(args_list, function(par) do.call(rnorm, par))
+  args_list <- lapply(params_list, function(par)
+    c(n, par))
+  sims <- lapply(args_list, function(par)
+    do.call(rnorm, par))
   draws <- do.call(cbind, sims)
-  components <- sample.int(length(sims), n, prob = weights, replace = TRUE)
+  components <-
+    sample.int(length(sims), n, prob = weights, replace = TRUE)
   idx <- cbind(seq_len(n), components)
   draws[idx]
 }
@@ -487,10 +507,13 @@ rmixtnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par) c(n, par))
-  sims <- lapply(args_list, function(par) do.call(rtnorm, par))
+  args_list <- lapply(params_list, function(par)
+    c(n, par))
+  sims <- lapply(args_list, function(par)
+    do.call(rtnorm, par))
   draws <- do.call(cbind, sims)
-  components <- sample.int(length(sims), n, prob = weights, replace = TRUE)
+  components <-
+    sample.int(length(sims), n, prob = weights, replace = TRUE)
   idx <- cbind(seq_len(n), components)
   draws[idx]
 }
@@ -500,11 +523,14 @@ rmixmvnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par) c(n, par))
-  sims <- lapply(args_list, function(par) do.call(rmvnorm, par))
-
-  components <- sample.int(length(sims), n, prob = weights, replace = TRUE)
-
+  args_list <- lapply(params_list, function(par)
+    c(n, par))
+  sims <- lapply(args_list, function(par)
+    do.call(rmvnorm, par))
+  
+  components <-
+    sample.int(length(sims), n, prob = weights, replace = TRUE)
+  
   # loop through the n observations, pulling out the corresponding slice
   draws_out <- array(NA, dim(sims[[1]]))
   for (i in seq_len(n)) {
@@ -515,10 +541,8 @@ rmixmvnorm <- function(n, ...) {
 
 # a form of two-sample chi squared test for discrete multivariate distributions
 combined_chisq_test <- function(x, y) {
-  stats::chisq.test(
-    x = colSums(x),
-    y = colSums(y)
-  )
+  stats::chisq.test(x = colSums(x),
+                    y = colSums(y))
 }
 
 # flatten unique part of a symmetric matrix
@@ -534,24 +558,25 @@ compare_iid_samples <- function(greta_fun,
                                 nsim = 200,
                                 p_value_threshold = 0.001) {
   greta_array <- do.call(greta_fun, parameters)
-
+  
   # get information about distribution
   distribution <- get_node(greta_array)$distribution
   multivariate <- distribution$multivariate
   discrete <- distribution$discrete
   name <- distribution$distribution_name
-
+  
   greta_samples <- calculate(greta_array, nsim = nsim)[[1]]
   r_samples <- do.call(r_fun, c(n = nsim, parameters))
-
+  
   # reshape to matrix or vector
   if (multivariate) {
-
     # if it's a symmetric matrix, take only a triangle and flatten it
     if (name %in% c("wishart", "lkj_correlation")) {
       include_diag <- name == "wishart"
-      t_greta_samples <- apply(greta_samples, 1, get_upper_tri, include_diag)
-      t_r_samples <- apply(r_samples, 1, get_upper_tri, include_diag)
+      t_greta_samples <-
+        apply(greta_samples, 1, get_upper_tri, include_diag)
+      t_r_samples <-
+        apply(r_samples, 1, get_upper_tri, include_diag)
       greta_samples <- t(t_greta_samples)
       r_samples <- t(t_r_samples)
     } else {
@@ -563,14 +588,14 @@ compare_iid_samples <- function(greta_fun,
   } else {
     greta_samples <- as.vector(greta_samples)
   }
-
+  
   # find a vaguely appropriate test
   if (discrete) {
     test <- ifelse(multivariate, combined_chisq_test, stats::chisq.test)
   } else {
     test <- ifelse(multivariate, cramer::cramer.test, stats::ks.test)
   }
-
+  
   # run the chosen two-sample test on the samples
   suppressWarnings(test_result <- test(greta_samples, r_samples))
   testthat::expect_gte(test_result$p.value, p_value_threshold)
@@ -589,14 +614,17 @@ skip_if_not_release <- function() {
 # the two IID random number generators for the data generating function
 # ('p_theta' = generator for the prior, 'p_x_bar_theta' = generator for the
 # likelihood), 'niter' the number of MCMC samples to compare
-check_geweke <- function(sampler, model, data,
-                         p_theta, p_x_bar_theta,
-                         niter = 2000, warmup = 1000,
+check_geweke <- function(sampler,
+                         model,
+                         data,
+                         p_theta,
+                         p_x_bar_theta,
+                         niter = 2000,
+                         warmup = 1000,
                          title = "Geweke test") {
-
   # sample independently
   target_theta <- p_theta(niter)
-
+  
   # sample with Markov chain
   greta_theta <- p_theta_greta(
     niter = niter,
@@ -607,64 +635,66 @@ check_geweke <- function(sampler, model, data,
     sampler = sampler,
     warmup = warmup
   )
-
+  
   # visualise correspondence
   quants <- (1:99) / 100
   q1 <- stats::quantile(target_theta, quants)
   q2 <- stats::quantile(greta_theta, quants)
   plot(q2, q1, main = title)
   graphics::abline(0, 1)
-
+  
   # do a formal hypothesis test
-  suppressWarnings(stat <- stats::ks.test(target_theta, greta_theta))
+  suppressWarnings(stat <-
+                     stats::ks.test(target_theta, greta_theta))
   testthat::expect_gte(stat$p.value, 0.005)
 }
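+
+# usage sketch (hypothetical toy model, not run here): check a sampler on a
+# conjugate normal-normal model with unit variances
+# check_geweke(
+#   sampler = hmc(),
+#   model = m,                      # m <- model(theta)
+#   data = x,                       # x <- as_data(0); distribution(x) <- normal(theta, 1)
+#   p_theta = function(n) rnorm(n),
+#   p_x_bar_theta = function(theta) rnorm(1, theta)
+# )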
 
 # sample from a prior on theta the long way round, for use in a Geweke test:
 # Gibbs sampling the posterior p(theta | x) and the data generating function
 # p(x | theta), retaining only the samples of theta from the joint distribution
-p_theta_greta <- function(niter, model, data,
-                          p_theta, p_x_bar_theta,
+p_theta_greta <- function(niter,
+                          model,
+                          data,
+                          p_theta,
+                          p_x_bar_theta,
                           sampler = hmc(),
                           warmup = 1000) {
-
   # set up and initialize trace
   theta <- rep(NA, niter)
   theta[1] <- p_theta(1)
-
+  
   # set up and tune sampler
-  draws <- mcmc(model,
+  draws <- mcmc(
+    model,
     warmup = warmup,
     n_samples = 1,
     chains = 1,
     sampler = sampler,
     verbose = FALSE
   )
-
+  
   # now loop through, sampling and updating x and returning theta
   for (i in 2:niter) {
-
     # sample x given theta
     x <- p_x_bar_theta(theta[i - 1])
-
+    
     # put x in the data list
     dag <- model$dag
     target_name <- dag$tf_name(get_node(data))
     x_array <- array(x, dim = c(1, dim(data)))
     dag$tf_environment$data_list[[target_name]] <- x_array
-
+    
     # put theta in the free state
     sampler <- attr(draws, "model_info")$samplers[[1]]
     sampler$free_state <- as.matrix(theta[i - 1])
-
+    
     draws <- extra_samples(draws,
-      n_samples = 1,
-      verbose = FALSE
-    )
-
+                           n_samples = 1,
+                           verbose = FALSE)
+    
     theta[i] <- tail(as.numeric(draws[[1]]), 1)
   }
-
+  
   theta
 }
 
@@ -672,15 +702,13 @@ p_theta_greta <- function(niter, model, data,
 
 not_finished <- function(draws, target_samples = 5000) {
   neff <- coda::effectiveSize(draws)
-  rhats <- coda::gelman.diag(
-    x = draws,
-    multivariate = FALSE,
-    autoburnin = FALSE
-  )
+  rhats <- coda::gelman.diag(x = draws,
+                             multivariate = FALSE,
+                             autoburnin = FALSE)
   rhats <- rhats$psrf[, 1]
   converged <- all(rhats < 1.01)
   enough_samples <- all(neff >= target_samples)
-  !(converged & enough_samples)
+  ! (converged & enough_samples)
 }
 
 new_samples <- function(draws, target_samples = 5000) {
@@ -702,24 +730,23 @@ get_enough_draws <- function(model,
                              one_by_one = FALSE) {
   start_time <- Sys.time()
   draws <- mcmc(model,
-    sampler = sampler,
-    verbose = verbose,
-    one_by_one = one_by_one
-  )
-
+                sampler = sampler,
+                verbose = verbose,
+                one_by_one = one_by_one)
+  
   while (not_finished(draws, n_effective) &
-    not_timed_out(start_time, time_limit)) {
+         not_timed_out(start_time, time_limit)) {
     n_samples <- new_samples(draws, n_effective)
-    draws <- extra_samples(draws, n_samples,
-      verbose = verbose,
-      one_by_one = one_by_one
-    )
+    draws <- extra_samples(draws,
+                           n_samples,
+                           verbose = verbose,
+                           one_by_one = one_by_one)
   }
-
+  
   if (not_finished(draws, n_effective)) {
     stop("could not draws enough effective samples within the time limit")
   }
-
+  
   draws
 }
 
@@ -728,24 +755,22 @@ mcse <- function(draws) {
   n <- nrow(draws)
   b <- floor(sqrt(n))
   a <- floor(n / b)
-
+  
   group <- function(k) {
     idx <- ((k - 1) * b + 1):(k * b)
     colMeans(draws[idx, , drop = FALSE])
   }
-
-  bm <- vapply(
-    seq_len(a),
-    group,
-    draws[1, ]
-  )
-
+  
+  bm <- vapply(seq_len(a),
+               group,
+               draws[1, ])
+  
   if (is.null(dim(bm))) {
     bm <- t(bm)
   }
-
+  
   mu_hat <- as.matrix(colMeans(draws))
-  ss <- sweep(t(bm), 2, mu_hat, "-")^2
+  ss <- sweep(t(bm), 2, mu_hat, "-") ^ 2
   var_hat <- b * colSums(ss) / (a - 1)
   sqrt(var_hat / n)
 }
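+
+# in equation form (batch means): with a = floor(n / b) batches of length
+# b = floor(sqrt(n)) and batch means y_k,
+#   var_hat = b / (a - 1) * sum_k (y_k - mu_hat)^2
+#   mcse    = sqrt(var_hat / n)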
@@ -761,35 +786,31 @@ scaled_error <- function(draws, expectation) {
 # given a sampler (e.g. hmc()) and minimum number of effective samples, ensure
 # that the sampler can draw correct samples from a bivariate normal distribution
 check_mvn_samples <- function(sampler, n_effective = 3000) {
-
   # get multivariate normal samples
   mu <- as_data(t(rnorm(2, 0, 5)))
   sigma <- stats::rWishart(1, 3, diag(2))[, , 1]
   x <- multivariate_normal(mu, sigma)
   m <- model(x, precision = "single")
-
+  
   draws <- get_enough_draws(m,
-    sampler = sampler,
-    n_effective = n_effective,
-    verbose = FALSE
-  )
-
+                            sampler = sampler,
+                            n_effective = n_effective,
+                            verbose = FALSE)
+  
   # get MCMC samples for statistics of the samples (value, variance and
   # correlation of error wrt mean)
   err <- x - mu
-  var <- (err)^2
+  var <- (err) ^ 2
   corr <- prod(err) / prod(sqrt(diag(sigma)))
   err_var_corr <- c(err, var, corr)
   stat_draws <- calculate(err_var_corr, values = draws)
-
+  
   # get true values of these - on average the error should be 0, and the
   # variance and correlation of the errors should be encoded in Sigma
-  stat_truth <- c(
-    rep(0, 2),
-    diag(sigma),
-    cov2cor(sigma)[1, 2]
-  )
-
+  stat_truth <- c(rep(0, 2),
+                  diag(sigma),
+                  cov2cor(sigma)[1, 2])
+  
   # get absolute errors between posterior means and true values, and scale them
   # by time-series Monte Carlo standard errors (the expected amount of
   # uncertainty in the MCMC estimate), to give the number of standard errors
@@ -810,27 +831,28 @@ check_samples <- function(x,
                           title = NULL,
                           one_by_one = FALSE) {
   m <- model(x, precision = "single")
-  draws <- get_enough_draws(m,
+  draws <- get_enough_draws(
+    m,
     sampler = sampler,
     n_effective = n_effective,
     verbose = FALSE,
     one_by_one = one_by_one
   )
-
+  
   neff <- coda::effectiveSize(draws)
   iid_samples <- iid_function(neff)
   mcmc_samples <- as.matrix(draws)
-
+  
   # plot
   if (is.null(title)) {
     distrib <- get_node(x)$distribution$distribution_name
     sampler_name <- class(sampler)[1]
     title <- paste(distrib, "with", sampler_name)
   }
-
+  
   stats::qqplot(mcmc_samples, iid_samples, main = title)
   graphics::abline(0, 1)
-
+  
   # do a formal hypothesis test
   suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
   testthat::expect_gte(stat$p.value, 0.01)
@@ -838,25 +860,41 @@ check_samples <- function(x,
 
 # zero-inflated Poisson using distributional
 
-zero_inflated_pois <- function(lambda,
-                               prob){
-    dist_inflated(
-      dist = dist_poisson(lambda = lambda),
-      prob = prob,
-      x = 0
-    )
+dist_zero_inflated_pois <- function(lambda, prob_zeros) {
+  dist_inflated(dist = dist_poisson(lambda = lambda),
+                prob = prob_zeros,
+                x = 0)
   
 }
 
-sample_zero_inflated_pois <- function(n, lambda, prob){
-  generate(x = zero_inflated_pois(lambda = lambda, prob = prob),
-           n)
+dist_zero_inflated_negative_binomial <-
+  function(size, prob, prob_zeros) {
+    distributional::dist_inflated(
+      dist = distributional::dist_negative_binomial(size = size,
+                                                    prob = prob),
+      prob = prob_zeros,
+      x = 0
+    )
+  }
+
+sample_zero_inflated_pois <- function(n, lambda, prob) {
+  distributional::generate(x = dist_zero_inflated_pois(lambda = lambda, prob = prob),
+                           n)[[1]]
 }
 
+sample_zero_inflated_neg_binomial <-
+  function(n, size, lambda, prob_zeros) {
+    distributional::generate(x = dist_zero_inflated_pois(lambda = lambda, prob = prob_zeros),
+                             n)[[1]]
+  }
+
 # zero-inflated Poisson density from the rethinking package
-dzipois <- function(x , theta , lambda , log=FALSE ) {
-  ll <- ifelse( x==0 , theta + (1-theta)*exp(-lambda) , (1-theta)*dpois(x,lambda,FALSE) )
-  if(log){
+dzipois <- function(x , theta , lambda , log = FALSE) {
+  ll <-
+    ifelse(x == 0 ,
+           theta + (1 - theta) * exp(-lambda) ,
+           (1 - theta) * dpois(x, lambda, FALSE))
+  if (log) {
     return(log(ll))
   }
   else {
@@ -866,9 +904,5 @@ dzipois <- function(x , theta , lambda , log=FALSE ) {
 
 
 # zero-inflated negative binomial likelihood from likelihoodExplore package
-require(likelihoodExplore)
 dzinb <- function(x, theta, size, prob, log = FALSE)
-    return(liknbinom(x, size = size, prob = prob, log = log))
-
-
-
+  return(likelihoodExplore::liknbinom(x, size = size, prob = prob, log = log))
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
index 0699f3a..16eb9c9 100644
--- a/tests/testthat/test_zip_zinb.R
+++ b/tests/testthat/test_zip_zinb.R
@@ -1,7 +1,9 @@
+source("helpers.R")
+
 test_that("zero inflated poisson distribution has correct density", {
 
   skip_if_not(check_tf_version())
-  source("helpers.R")
+  
 
   compare_distribution(zero_inflated_poisson,
                        extraDistr::dzip,
@@ -13,7 +15,6 @@ test_that("zero inflated poisson distribution has correct density", {
 test_that("zero inflated negative binomial distribution has correct density", {
 
   skip_if_not(check_tf_version())
-  source("helpers.R")
 
   compare_distribution(zero_inflated_negative_binomial,
                        extraDistr::dzinb,

From bf77c79cd479ff692ac2d920efb8725233018ad8 Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Mon, 1 Aug 2022 16:18:14 +0800
Subject: [PATCH 04/19] add some missing internal functions

---
 R/internals.R        |   3 ++
 R/testthat-helpers.R | 123 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 126 insertions(+)
 create mode 100644 R/testthat-helpers.R

diff --git a/R/internals.R b/R/internals.R
index 60fe3d6..eb7b30c 100644
--- a/R/internals.R
+++ b/R/internals.R
@@ -24,3 +24,6 @@ fl <- .internals$utils$misc$fl
 tf_as_float <- .internals$tensors$tf_as_float
 tf_rowsums <- .internals$tensors$tf_rowsums
 op <- .internals$nodes$constructors$op
+
+get_node <- .internals$greta_arrays$get_node
+dag_class <- .internals$inference$dag_class
diff --git a/R/testthat-helpers.R b/R/testthat-helpers.R
new file mode 100644
index 0000000..e10a6bd
--- /dev/null
+++ b/R/testthat-helpers.R
@@ -0,0 +1,123 @@
+# an array of random standard normals with the specified dims
+# e.g. randn(3, 2, 1)
+randn <- function(...) {
+  dim <- c(...)
+  array(stats::rnorm(prod(dim)), dim = dim)
+}
+
+# ditto for standard uniforms
+randu <- function(...) {
+  dim <- c(...)
+  array(stats::runif(prod(dim)), dim = dim)
+}
+
+# create a variable with the same dimensions as as_data(x)
+as_variable <- function(x) {
+  x <- as_2d_array(x)
+  variable(dim = dim(x))
+}
+
+
+# check that a greta operation and the equivalent R operation give the same output
+# e.g. check_op(sum, randn(100, 3))
+check_op <- function(op, a, b, greta_op = NULL,
+                     other_args = list(),
+                     tolerance = 1e-3,
+                     only = c("data", "variable", "batched"),
+                     relative_error = FALSE) {
+  if (is.null(greta_op)) {
+    greta_op <- op
+  }
+  
+  r_out <- run_r_op(op, a, b, other_args)
+  
+  for (type in only) {
+    # compare with ops on data greta arrays
+    greta_out <- run_greta_op(greta_op, a, b, other_args, type)
+    compare_op(r_out, greta_out, tolerance, relative_error = relative_error)
+  }
+}
+
+compare_op <- function(r_out, greta_out, tolerance = 1e-4, relative_error = FALSE) {
+  # absolute difference by default, relative difference on request
+  if (relative_error) {
+    difference <- as.vector(abs(r_out - greta_out) / abs(r_out))
+  } else {
+    difference <- as.vector(abs(r_out - greta_out))
+  }
+  testthat::expect_true(all(difference < tolerance))
+}
+
+run_r_op <- function(op, a, b, other_args) {
+  arg_list <- list(a)
+  if (!missing(b)) {
+    arg_list <- c(arg_list, list(b))
+  }
+  arg_list <- c(arg_list, other_args)
+  do.call(op, arg_list)
+}
+
+run_greta_op <- function(greta_op, a, b, other_args,
+                         type = c("data", "variable", "batched")) {
+  type <- match.arg(type)
+  
+  converter <- switch(type,
+                      data = as_data,
+                      variable = as_variable,
+                      batched = as_variable
+  )
+  
+  g_a <- converter(a)
+  
+  arg_list <- list(g_a)
+  values <- list(g_a = a)
+  
+  if (!missing(b)) {
+    g_b <- converter(b)
+    arg_list <- c(arg_list, list(g_b))
+    values <- c(values, list(g_b = b))
+  }
+  
+  arg_list <- c(arg_list, other_args)
+  out <- do.call(greta_op, arg_list)
+  
+  if (type == "data") {
+    # data greta arrays should provide their own values
+    result <- calculate(out, values = list())[[1]]
+  } else if (type == "variable") {
+    result <- grab_via_free_state(out, values)
+  } else if (type == "batched") {
+    result <- grab_via_free_state(out, values, batches = 3)
+  } else {
+    result <- calculate(out, values = values)[[1]]
+  }
+  
+  result
+}
+
+# get the value of the target greta array, by passing values for the named
+# variable greta arrays via the free state parameter, optionally with batches
+grab_via_free_state <- function(target, values, batches = 1) {
+  dag <- dag_class$new(list(target))
+  dag$define_tf()
+  inits <- do.call(initials, values)
+  inits_flat <- prep_initials(inits, 1, dag)[[1]]
+  if (batches > 1) {
+    inits_list <- replicate(batches, inits_flat, simplify = FALSE)
+    inits_flat <- do.call(rbind, inits_list)
+    vals <- dag$trace_values(inits_flat)[1, ]
+  } else {
+    vals <- dag$trace_values(inits_flat)
+  }
+  array(vals, dim = dim(target))
+}
+
+expect_ok <- function(expr) {
+  testthat::expect_error(expr, NA)
+}
+
+is.greta_array <- function(x) { # nolint
+  inherits(x, "greta_array")
+}
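+
+# usage sketch (hypothetical; assumes greta is attached):
+# check_op(exp, randn(5, 2))
+# check_op(`%*%`, randn(2, 3), randn(3, 2))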

From d4d39bab219e75c97340e8b4c1d47f7d30cabac8 Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Mon, 1 Aug 2022 16:18:46 +0800
Subject: [PATCH 05/19] swap theta and lambda arguments to match
 extraDistr::dzip args

---
 R/zero_inflated_poisson.R      |  16 +-
 tests/testthat/helpers.R       | 391 +++++++++++++++++++--------------
 tests/testthat/test_zip_zinb.R |  24 +-
 3 files changed, 242 insertions(+), 189 deletions(-)

diff --git a/R/zero_inflated_poisson.R b/R/zero_inflated_poisson.R
index 1853406..d20b138 100644
--- a/R/zero_inflated_poisson.R
+++ b/R/zero_inflated_poisson.R
@@ -3,13 +3,13 @@
 #'
 #' @description A zero-inflated Poisson distribution.
 #'
-#' @param theta proportion of zeros
 #' @param lambda rate parameter
+#' @param theta proportion of zeros
 #' @param dim a scalar giving the number of rows in the resulting greta array
 #' @importFrom R6 R6Class
 #' @export
-zero_inflated_poisson <- function (theta, lambda, dim = NULL) {
-  distrib('zero_inflated_poisson', theta, lambda, dim)
+zero_inflated_poisson <- function (lambda, theta, dim = NULL) {
+  distrib('zero_inflated_poisson', lambda, theta, dim)
 }
 
 #' @importFrom R6 R6Class
@@ -17,19 +17,19 @@ zero_inflated_poisson_distribution <- R6::R6Class(
   classname = "zero_inflated_poisson_distribution",
   inherit = distribution_node,
   public = list(
-    initialize = function(theta, lambda, dim) {
-      theta <- as.greta_array(theta)
+    initialize = function(lambda, theta, dim) {
       lambda <- as.greta_array(lambda)
+      theta <- as.greta_array(theta)
       # add the nodes as children and parameters
-      dim <- check_dims(theta, lambda, target_dim = dim)
+      dim <- check_dims(lambda, theta, target_dim = dim)
       super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
       self$add_parameter(lambda, "lambda")
+      self$add_parameter(theta, "theta")
     },
     
     tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
       lambda <- parameters$lambda
+      theta <- parameters$theta
       log_prob <- function(x) {
         tf$math$log(
           theta * 
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 46f2537..b900954 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -16,13 +16,13 @@ grab <- function(x, dag = NULL) {
   if (inherits(x, "node")) {
     x <- as.greta_array(x)
   }
-  
+
   if (inherits(x, "greta_array")) {
     node <- get_node(x)
     dag <- dag_class$new(list(x))
     dag$define_tf()
   }
-  
+
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
   out <- dag$tf_sess_run(dag$tf_name(node), as_text = TRUE)
@@ -44,11 +44,11 @@ set_distribution <- function(dist, data) {
 get_density <- function(distrib, data) {
   x <- as_data(data)
   distribution(x) <- distrib
-  
+
   # create dag and define the density
   dag <- dag_class$new(list(x))
   get_node(x)$distribution$define_tf(dag)
-  
+
   # get the log density as a vector
   tensor_name <- dag$tf_name(get_node(distrib)$distribution)
   tensor <- get(tensor_name, envir = dag$tf_environment)
@@ -69,13 +69,15 @@ compare_distribution <- function(greta_fun,
   # both of these functions must take the same parameters in the same order
   # 'parameters' is an optionally named list of numeric parameter values
   # x is the vector of values at which to evaluate the log density
-  
+
   # define greta distribution, with fixed values
-  greta_log_density <- greta_density(greta_fun, parameters, x,
-                                     dim, multivariate)
+  greta_log_density <- greta_density(
+    greta_fun, parameters, x,
+    dim, multivariate
+  )
   # get R version
   r_log_density <- log(do.call(r_fun, c(list(x), parameters)))
-  
+
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
@@ -90,44 +92,44 @@ greta_density <- function(fun,
   if (is.null(dim)) {
     dim <- NROW(x)
   }
-  
+
   # add the output dimension to the arguments list
   dim_list <- list(dim = dim)
-  
+
   # if it's a multivariate distribution name it n_realisations
   if (multivariate) {
     names(dim_list) <- "n_realisations"
   }
-  
+
   # don't add it for wishart & lkj, which don't have multiple realisations
   is_wishart <- identical(names(parameters), c("df", "Sigma"))
   is_lkj <- identical(names(parameters), c("eta", "dimension"))
   if (is_wishart | is_lkj) {
     dim_list <- list()
   }
-  
+
   parameters <- c(parameters, dim_list)
-  
+
   # evaluate greta distribution
   dist <- do.call(fun, parameters)
   distrib_node <- get_node(dist)$distribution
-  
+
   # set density
   x_ <- as.greta_array(x)
   distrib_node$remove_target()
   distrib_node$add_target(get_node(x_))
-  
+
   # create dag
   dag <- dag_class$new(list(x_))
   dag$define_tf()
   dag$set_tf_data_list("batch_size", 1L)
   dag$build_feed_dict()
-  
+
   # get the log density as a vector
   dag$on_graph(result <-
-                 dag$evaluate_density(distrib_node, get_node(x_)))
+    dag$evaluate_density(distrib_node, get_node(x_)))
   assign("test_density", result, dag$tf_environment)
-  
+
   density <- dag$tf_sess_run(test_density)
   as.vector(density)
 }
@@ -140,23 +142,28 @@ with_greta <- function(call,
                        swap = c("x"),
                        swap_scope = 1) {
   swap_entries <- paste0(swap, " = as_data(", swap, ")")
-  swap_text <- paste0("list(",
-                      paste(swap_entries, collapse = ", "),
-                      ")")
+  swap_text <- paste0(
+    "list(",
+    paste(swap_entries, collapse = ", "),
+    ")"
+  )
   swap_list <- eval(parse(text = swap_text),
-                    envir = parent.frame(n = swap_scope))
-  
-  greta_result <- with(swap_list,
-                       eval(call))
+    envir = parent.frame(n = swap_scope)
+  )
+
+  greta_result <- with(
+    swap_list,
+    eval(call)
+  )
   result <- grab(greta_result)
-  
+
   # account for the fact that greta outputs are 1D arrays; convert them back to
   # R vectors
   if (is.array(result) &&
-      length(dim(result)) == 2 && dim(result)[2] == 1) {
+    length(dim(result)) == 2 && dim(result)[2] == 1) {
     result <- as.vector(result)
   }
-  
+
   result
 }
 
@@ -167,12 +174,13 @@ check_expr <- function(expr,
                        swap = c("x"),
                        tolerance = 1e-4) {
   call <- substitute(expr)
-  
+
   r_out <- eval(expr)
   greta_out <- with_greta(call,
-                          swap = swap,
-                          swap_scope = 2)
-  
+    swap = swap,
+    swap_scope = 2
+  )
+
   compare_op(r_out, greta_out, tolerance)
 }
 
@@ -193,9 +201,9 @@ gen_opfun <- function(n, ops) {
   for (i in seq_len(n)) {
     string <- add_op_string(string, ops = ops)
   }
-  
+
   fun_string <- sprintf("function(a, b) {%s}", string)
-  
+
   eval(parse(text = fun_string))
 }
 
@@ -208,12 +216,13 @@ sample_distribution <- function(greta_array,
                                 warmup = 1) {
   m <- model(greta_array, precision = "double")
   draws <- mcmc(m,
-                n_samples = n,
-                warmup = warmup,
-                verbose = FALSE)
+    n_samples = n,
+    warmup = warmup,
+    verbose = FALSE
+  )
   samples <- as.matrix(draws)
   vectorised <- length(lower) > 1 | length(upper) > 1
-  
+
   if (vectorised) {
     above_lower <- sweep(samples, 2, lower, `>=`)
     below_upper <- sweep(samples, 2, upper, `<=`)
@@ -221,7 +230,7 @@ sample_distribution <- function(greta_array,
     above_lower <- samples >= lower
     below_upper <- samples <= upper
   }
-  
+
   expect_true(all(above_lower & below_upper))
 }
 
@@ -235,27 +244,29 @@ compare_truncated_distribution <- function(greta_fun,
   # is a greta array created from a distribution and a constrained variable
   # greta array. 'r_fun' is an r function returning the log density for the same
   # truncated distribution, taking x as its only argument.
-  
-  x <- do.call(truncdist::rtrunc,
-               c(
-                 n = 100,
-                 spec = which,
-                 a = truncation[1],
-                 b = truncation[2],
-                 parameters
-               ))
-  
+
+  x <- do.call(
+    truncdist::rtrunc,
+    c(
+      n = 100,
+      spec = which,
+      a = truncation[1],
+      b = truncation[2],
+      parameters
+    )
+  )
+
   # create truncated R function and evaluate it
   r_fun <- truncfun(which, parameters, truncation)
   r_log_density <- log(r_fun(x))
-  
+
   greta_log_density <- greta_density(
     fun = greta_fun,
     parameters = c(parameters, list(truncation = truncation)),
     x = x,
     dim = 1
   )
-  
+
   # return absolute difference
   compare_op(r_log_density, greta_log_density, tolerance)
 }
@@ -263,11 +274,13 @@ compare_truncated_distribution <- function(greta_fun,
 # use the truncdist package to create a truncated distribution function for use
 # in compare_truncated_distribution
 truncfun <- function(which = "norm", parameters, truncation) {
-  args <- c(spec = which,
-            a = truncation[1],
-            b = truncation[2],
-            parameters)
-  
+  args <- c(
+    spec = which,
+    a = truncation[1],
+    b = truncation[2],
+    parameters
+  )
+
   function(x) {
     arg_list <- c(x = list(x), args)
     do.call(truncdist::dtrunc, arg_list)
@@ -315,8 +328,9 @@ get_output <- function(expr) {
   i <- 0
   suppressMessages(withCallingHandlers(
     expr,
-    message = function(e)
+    message = function(e) {
       msgs[[i <<- i + 1]] <<- conditionMessage(e)
+    }
   ))
   paste0(msgs, collapse = "")
 }
@@ -324,9 +338,10 @@ get_output <- function(expr) {
 # mock up mcmc progress bar output for neurotic testing
 mock_mcmc <- function(n_samples = 1010) {
   pb <- create_progress_bar("sampling",
-                            c(0, n_samples),
-                            pb_update = 10,
-                            width = 50)
+    c(0, n_samples),
+    pb_update = 10,
+    width = 50
+  )
   iterate_progress_bar(pb, n_samples, rejects = 10, chains = 1)
 }
 
@@ -336,15 +351,15 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
   k <- dimension
   stopifnot(is.numeric(k), k >= 2, k == as.integer(k))
   stopifnot(eta > 0)
-  
+
   f <- function() {
     alpha <- eta + (k - 2) / 2
     r12 <- 2 * stats::rbeta(1, alpha, alpha) - 1
     r <- matrix(0, k, k)
     r[1, 1] <- 1
     r[1, 2] <- r12
-    r[2, 2] <- sqrt(1 - r12 ^ 2)
-    
+    r[2, 2] <- sqrt(1 - r12^2)
+
     if (k > 2) {
       for (m in 2:(k - 1)) {
         alpha <- alpha - 0.5
@@ -355,18 +370,18 @@ rlkjcorr <- function(n, eta = 1, dimension = 2) {
         r[m + 1, m + 1] <- sqrt(1 - y)
       }
     }
-    
+
     crossprod(r)
   }
-  
+
   r <- replicate(n, f())
-  
+
   if (dim(r)[3] == 1) {
     r <- r[, , 1]
   } else {
     r <- aperm(r, c(3, 1, 2))
   }
-  
+
   r
 }
 
@@ -440,27 +455,32 @@ rtf <- function(n, df1, df2, truncation) {
 joint_normals <- function(...) {
   params_list <- list(...)
   components <-
-    lapply(params_list, function(par)
-      do.call(normal, par))
+    lapply(params_list, function(par) {
+      do.call(normal, par)
+    })
   do.call(joint, components)
 }
 
 rjnorm <- function(n, ...) {
   params_list <- list(...)
-  args_list <- lapply(params_list, function(par)
-    c(n, par))
+  args_list <- lapply(params_list, function(par) {
+    c(n, par)
+  })
   sims <-
-    lapply(args_list, function(par)
-      do.call(stats::rnorm, par))
+    lapply(args_list, function(par) {
+      do.call(stats::rnorm, par)
+    })
   do.call(cbind, sims)
 }
 
 rjtnorm <- function(n, ...) {
   params_list <- list(...)
-  args_list <- lapply(params_list, function(par)
-    c(n, par))
-  sims <- lapply(args_list, function(par)
-    do.call(rtnorm, par))
+  args_list <- lapply(params_list, function(par) {
+    c(n, par)
+  })
+  sims <- lapply(args_list, function(par) {
+    do.call(rtnorm, par)
+  })
   do.call(cbind, sims)
 }
 
@@ -470,8 +490,9 @@ mixture_normals <- function(...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   components <-
-    lapply(params_list, function(par)
-      do.call(normal, par))
+    lapply(params_list, function(par) {
+      do.call(normal, par)
+    })
   do.call(mixture, c(components, args[is_weights]))
 }
 
@@ -479,10 +500,12 @@ mixture_multivariate_normals <- function(...) {
   args <- list(...)
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
-  components <- lapply(params_list,
-                       function(par) {
-                         do.call(multivariate_normal, par)
-                       })
+  components <- lapply(
+    params_list,
+    function(par) {
+      do.call(multivariate_normal, par)
+    }
+  )
   do.call(mixture, c(components, args[is_weights]))
 }
 
@@ -491,10 +514,12 @@ rmixnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par)
-    c(n, par))
-  sims <- lapply(args_list, function(par)
-    do.call(rnorm, par))
+  args_list <- lapply(params_list, function(par) {
+    c(n, par)
+  })
+  sims <- lapply(args_list, function(par) {
+    do.call(rnorm, par)
+  })
   draws <- do.call(cbind, sims)
   components <-
     sample.int(length(sims), n, prob = weights, replace = TRUE)
@@ -507,10 +532,12 @@ rmixtnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par)
-    c(n, par))
-  sims <- lapply(args_list, function(par)
-    do.call(rtnorm, par))
+  args_list <- lapply(params_list, function(par) {
+    c(n, par)
+  })
+  sims <- lapply(args_list, function(par) {
+    do.call(rtnorm, par)
+  })
   draws <- do.call(cbind, sims)
   components <-
     sample.int(length(sims), n, prob = weights, replace = TRUE)
@@ -523,14 +550,16 @@ rmixmvnorm <- function(n, ...) {
   is_weights <- names(args) == "weights"
   params_list <- args[!is_weights]
   weights <- args[[which(is_weights)]]
-  args_list <- lapply(params_list, function(par)
-    c(n, par))
-  sims <- lapply(args_list, function(par)
-    do.call(rmvnorm, par))
-  
+  args_list <- lapply(params_list, function(par) {
+    c(n, par)
+  })
+  sims <- lapply(args_list, function(par) {
+    do.call(rmvnorm, par)
+  })
+
   components <-
     sample.int(length(sims), n, prob = weights, replace = TRUE)
-  
+
   # loop through the n observations, pulling out the corresponding slice
   draws_out <- array(NA, dim(sims[[1]]))
   for (i in seq_len(n)) {
@@ -541,8 +570,10 @@ rmixmvnorm <- function(n, ...) {
 
 # a form of two-sample chi squared test for discrete multivariate distributions
 combined_chisq_test <- function(x, y) {
-  stats::chisq.test(x = colSums(x),
-                    y = colSums(y))
+  stats::chisq.test(
+    x = colSums(x),
+    y = colSums(y)
+  )
 }
 
 # flatten unique part of a symmetric matrix
@@ -558,16 +589,16 @@ compare_iid_samples <- function(greta_fun,
                                 nsim = 200,
                                 p_value_threshold = 0.001) {
   greta_array <- do.call(greta_fun, parameters)
-  
+
   # get information about distribution
   distribution <- get_node(greta_array)$distribution
   multivariate <- distribution$multivariate
   discrete <- distribution$discrete
   name <- distribution$distribution_name
-  
+
   greta_samples <- calculate(greta_array, nsim = nsim)[[1]]
   r_samples <- do.call(r_fun, c(n = nsim, parameters))
-  
+
   # reshape to matrix or vector
   if (multivariate) {
     # if it's a symmetric matrix, take only a triangle and flatten it
@@ -588,14 +619,14 @@ compare_iid_samples <- function(greta_fun,
   } else {
     greta_samples <- as.vector(greta_samples)
   }
-  
+
   # find a vaguely appropriate test
   if (discrete) {
     test <- ifelse(multivariate, combined_chisq_test, stats::chisq.test)
   } else {
     test <- ifelse(multivariate, cramer::cramer.test, stats::ks.test)
   }
-  
+
   # run the chosen two-sample test on the samples
   suppressWarnings(test_result <- test(greta_samples, r_samples))
   testthat::expect_gte(test_result$p.value, p_value_threshold)
@@ -624,7 +655,7 @@ check_geweke <- function(sampler,
                          title = "Geweke test") {
   # sample independently
   target_theta <- p_theta(niter)
-  
+
   # sample with Markov chain
   greta_theta <- p_theta_greta(
     niter = niter,
@@ -635,17 +666,17 @@ check_geweke <- function(sampler,
     sampler = sampler,
     warmup = warmup
   )
-  
+
   # visualise correspondence
   quants <- (1:99) / 100
   q1 <- stats::quantile(target_theta, quants)
   q2 <- stats::quantile(greta_theta, quants)
   plot(q2, q1, main = title)
   graphics::abline(0, 1)
-  
+
   # do a formal hypothesis test
   suppressWarnings(stat <-
-                     stats::ks.test(target_theta, greta_theta))
+    stats::ks.test(target_theta, greta_theta))
   testthat::expect_gte(stat$p.value, 0.005)
 }
 
@@ -662,7 +693,7 @@ p_theta_greta <- function(niter,
   # set up and initialize trace
   theta <- rep(NA, niter)
   theta[1] <- p_theta(1)
-  
+
   # set up and tune sampler
   draws <- mcmc(
     model,
@@ -672,29 +703,30 @@ p_theta_greta <- function(niter,
     sampler = sampler,
     verbose = FALSE
   )
-  
+
   # now loop through, sampling and updating x and returning theta
   for (i in 2:niter) {
     # sample x given theta
     x <- p_x_bar_theta(theta[i - 1])
-    
+
     # put x in the data list
     dag <- model$dag
     target_name <- dag$tf_name(get_node(data))
     x_array <- array(x, dim = c(1, dim(data)))
     dag$tf_environment$data_list[[target_name]] <- x_array
-    
+
     # put theta in the free state
     sampler <- attr(draws, "model_info")$samplers[[1]]
     sampler$free_state <- as.matrix(theta[i - 1])
-    
+
     draws <- extra_samples(draws,
-                           n_samples = 1,
-                           verbose = FALSE)
-    
+      n_samples = 1,
+      verbose = FALSE
+    )
+
     theta[i] <- tail(as.numeric(draws[[1]]), 1)
   }
-  
+
   theta
 }
 
@@ -702,13 +734,15 @@ p_theta_greta <- function(niter,
 
 not_finished <- function(draws, target_samples = 5000) {
   neff <- coda::effectiveSize(draws)
-  rhats <- coda::gelman.diag(x = draws,
-                             multivariate = FALSE,
-                             autoburnin = FALSE)
+  rhats <- coda::gelman.diag(
+    x = draws,
+    multivariate = FALSE,
+    autoburnin = FALSE
+  )
   rhats <- rhats$psrf[, 1]
   converged <- all(rhats < 1.01)
   enough_samples <- all(neff >= target_samples)
-  ! (converged & enough_samples)
+  !(converged & enough_samples)
 }
 
 new_samples <- function(draws, target_samples = 5000) {
@@ -730,23 +764,25 @@ get_enough_draws <- function(model,
                              one_by_one = FALSE) {
   start_time <- Sys.time()
   draws <- mcmc(model,
-                sampler = sampler,
-                verbose = verbose,
-                one_by_one = one_by_one)
-  
+    sampler = sampler,
+    verbose = verbose,
+    one_by_one = one_by_one
+  )
+
   while (not_finished(draws, n_effective) &
-         not_timed_out(start_time, time_limit)) {
+    not_timed_out(start_time, time_limit)) {
     n_samples <- new_samples(draws, n_effective)
     draws <- extra_samples(draws,
-                           n_samples,
-                           verbose = verbose,
-                           one_by_one = one_by_one)
+      n_samples,
+      verbose = verbose,
+      one_by_one = one_by_one
+    )
   }
-  
+
   if (not_finished(draws, n_effective)) {
     stop("could not draws enough effective samples within the time limit")
   }
-  
+
   draws
 }
 
@@ -755,22 +791,24 @@ mcse <- function(draws) {
   n <- nrow(draws)
   b <- floor(sqrt(n))
   a <- floor(n / b)
-  
+
   group <- function(k) {
     idx <- ((k - 1) * b + 1):(k * b)
     colMeans(draws[idx, , drop = FALSE])
   }
-  
-  bm <- vapply(seq_len(a),
-               group,
-               draws[1, ])
-  
+
+  bm <- vapply(
+    seq_len(a),
+    group,
+    draws[1, ]
+  )
+
   if (is.null(dim(bm))) {
     bm <- t(bm)
   }
-  
+
   mu_hat <- as.matrix(colMeans(draws))
-  ss <- sweep(t(bm), 2, mu_hat, "-") ^ 2
+  ss <- sweep(t(bm), 2, mu_hat, "-")^2
   var_hat <- b * colSums(ss) / (a - 1)
   sqrt(var_hat / n)
 }
@@ -791,26 +829,29 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
   sigma <- stats::rWishart(1, 3, diag(2))[, , 1]
   x <- multivariate_normal(mu, sigma)
   m <- model(x, precision = "single")
-  
+
   draws <- get_enough_draws(m,
-                            sampler = sampler,
-                            n_effective = n_effective,
-                            verbose = FALSE)
-  
+    sampler = sampler,
+    n_effective = n_effective,
+    verbose = FALSE
+  )
+
   # get MCMC samples for statistics of the samples (value, variance and
   # correlation of error wrt mean)
   err <- x - mu
-  var <- (err) ^ 2
+  var <- (err)^2
   corr <- prod(err) / prod(sqrt(diag(sigma)))
   err_var_corr <- c(err, var, corr)
   stat_draws <- calculate(err_var_corr, values = draws)
-  
+
   # get true values of these - on average the error should be 0, and the
   # variance and correlation of the errors should be encoded in Sigma
-  stat_truth <- c(rep(0, 2),
-                  diag(sigma),
-                  cov2cor(sigma)[1, 2])
-  
+  stat_truth <- c(
+    rep(0, 2),
+    diag(sigma),
+    cov2cor(sigma)[1, 2]
+  )
+
   # get absolute errors between posterior means and true values, and scale them
   # by time-series Monte Carlo standard errors (the expected amount of
   # uncertainty in the MCMC estimate), to give the number of standard errors
@@ -838,21 +879,21 @@ check_samples <- function(x,
     verbose = FALSE,
     one_by_one = one_by_one
   )
-  
+
   neff <- coda::effectiveSize(draws)
   iid_samples <- iid_function(neff)
   mcmc_samples <- as.matrix(draws)
-  
+
   # plot
   if (is.null(title)) {
     distrib <- get_node(x)$distribution$distribution_name
     sampler_name <- class(sampler)[1]
     title <- paste(distrib, "with", sampler_name)
   }
-  
+
   stats::qqplot(mcmc_samples, iid_samples, main = title)
   graphics::abline(0, 1)
-  
+
   # do a formal hypothesis test
   suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
   testthat::expect_gte(stat$p.value, 0.01)
@@ -861,48 +902,56 @@ check_samples <- function(x,
 # zero-inflated Poisson using distributional
 
 dist_zero_inflated_pois <- function(lambda, prob_zeros) {
-  dist_inflated(dist = dist_poisson(lambda = lambda),
-                prob = prob_zeros,
-                x = 0)
-  
+  distributional::dist_inflated(
+    dist = distributional::dist_poisson(lambda = lambda),
+    prob = prob_zeros,
+    x = 0
+  )
 }
 
 dist_zero_inflated_negative_binomial <-
   function(size, prob, prob_zeros) {
     distributional::dist_inflated(
-      dist = distributional::dist_negative_binomial(size = size,
-                                                    prob = prob),
+      dist = distributional::dist_negative_binomial(
+        size = size,
+        prob = prob
+      ),
       prob = prob_zeros,
       x = 0
     )
   }
 
 sample_zero_inflated_pois <- function(n, lambda, prob) {
-  distributional::generate(x = dist_zero_inflated_pois(lambda = lambda, prob = prob),
-                           n)[[1]]
+  distributional::generate(
+    x = dist_zero_inflated_pois(lambda = lambda, prob = prob),
+    n
+  )[[1]]
 }
 
 sample_zero_inflated_neg_binomial <-
   function(n, size, lambda, prob_zeros) {
-    distributional::generate(x = dist_zero_inflated_pois(lambda = lambda, prob = prob_zeros),
-                             n)[[1]]
+    distributional::generate(
+      x = dist_zero_inflated_pois(lambda = lambda, prob = prob_zeros),
+      n
+    )[[1]]
   }
 
 # zero-inflated Poisson density from the rethinking package
-dzipois <- function(x , theta , lambda , log = FALSE) {
+dzipois <- function(x, theta, lambda, log = FALSE) {
   ll <-
-    ifelse(x == 0 ,
-           theta + (1 - theta) * exp(-lambda) ,
-           (1 - theta) * dpois(x, lambda, FALSE))
+    ifelse(x == 0,
+      theta + (1 - theta) * exp(-lambda),
+      (1 - theta) * dpois(x, lambda, FALSE)
+    )
   if (log) {
     return(log(ll))
-  }
-  else {
+  } else {
     return(ll)
   }
 }
 
 
-# zero-inflated negative binomial likelihood from likelihoodExplore package
-dzinb <- function(x, theta, size, prob, log = FALSE)
-  return(likelihoodExplore::liknbinom(x, size = size, prob = prob, log = log))
+# zero-inflated negative binomial likelihood: the standard ZINB mixture,
+# with theta the probability of structural zeros
+dzinb <- function(x, theta, size, prob, log = FALSE) {
+  ll <- theta * (x == 0) +
+    (1 - theta) * stats::dnbinom(x, size = size, prob = prob)
+  if (log) log(ll) else ll
+}
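+
+# quick cross-check (hypothetical values): should agree with extraDistr, e.g.
+# all.equal(dzinb(0:5, theta = 0.2, size = 10, prob = 0.1),
+#           extraDistr::dzinb(0:5, size = 10, prob = 0.1, pi = 0.2))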
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
index 16eb9c9..fa2da81 100644
--- a/tests/testthat/test_zip_zinb.R
+++ b/tests/testthat/test_zip_zinb.R
@@ -4,21 +4,25 @@ test_that("zero inflated poisson distribution has correct density", {
 
   skip_if_not(check_tf_version())
   
-
-  compare_distribution(zero_inflated_poisson,
-                       extraDistr::dzip,
-                       parameters = list(theta = 0.2, lambda = 2, pi = 0.2),
-                       x = sample_zero_inflated_pois(100, 2, 0.2))
+  compare_distribution(greta_fun = zero_inflated_poisson,
+                       r_fun = extraDistr::dzip,
+                       parameters = list(2, 0.2),
+                       x = sample_zero_inflated_pois(
+                         n = 100, 
+                         lambda = 2, 
+                         prob = 0.2)
+                       )
 
 })
 
 test_that("zero inflated negative binomial distribution has correct density", {
-
   skip_if_not(check_tf_version())
 
   compare_distribution(zero_inflated_negative_binomial,
-                       extraDistr::dzinb,
-                       parameters = list(theta = 2, size = 10, prob = 0.1, pi = 0.2),
-                       x = extraDistr::rzinb(100, 10, 0.1, 0.2))
-
+    extraDistr::dzinb,
+    parameters = list(theta = 2, size = 10, prob = 0.1, pi = 0.2),
+    x = extraDistr::rzinb(
+      n = 100, size = 10, prob = 0.1, pi = 0.2
+    )
+  )
 })

From 10c1af713d6ca6087025cef51dae1e59bb27a9f5 Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Mon, 1 Aug 2022 16:28:30 +0800
Subject: [PATCH 06/19] the greta_log_density is currently returning a vector
 with NaN values; need to explore how the distribution is defined and why this
 could be happening from the TF end

---
 R/zero_inflated_negative_binomial.R | 33 ++++++++++++-----------------
 tests/testthat/test_zip_zinb.R      |  7 +++---
 2 files changed, 18 insertions(+), 22 deletions(-)

diff --git a/R/zero_inflated_negative_binomial.R b/R/zero_inflated_negative_binomial.R
index 88bd374..377b82a 100644
--- a/R/zero_inflated_negative_binomial.R
+++ b/R/zero_inflated_negative_binomial.R
@@ -1,56 +1,52 @@
 #' @name zero_inflated_negative_binomial
 #' @title Zero Inflated Negative Binomial
 #' @description A Zero Inflated Negative Binomial distribution
-#' @param theta proportion of zeros
 #' @param size positive integer parameter
 #' @param prob probability parameter (`0 < prob < 1`)
+#' @param theta proportion of zeros
 #' @param dim a scalar giving the number of rows in the resulting greta array
 #' @export
-zero_inflated_negative_binomial <-
-  function (theta, size, prob, dim = NULL) {
-    distrib('zero_inflated_negative_binomial', theta, size, prob, dim)
-  }
+zero_inflated_negative_binomial <- function(size, prob, theta, dim = NULL) {
+  distrib("zero_inflated_negative_binomial", size, prob, theta, dim)
+}
 
 zero_inflated_negative_binomial_distribution <- R6::R6Class(
   "zero_inflated_negative_binomial_distribution",
-  inherit = greta::.internals$nodes$node_classes$distribution_node,
+  inherit = distribution_node,
   public = list(
-    initialize = function(theta, size, prob, dim) {
-      theta <- as.greta_array(theta)
+    initialize = function(size, prob, theta, dim) {
       size <- as.greta_array(size)
       prob <- as.greta_array(prob)
+      theta <- as.greta_array(theta)
       # add the nodes as children and parameters
-      dim <- check_dims(theta, size, prob, target_dim = dim)
+      dim <- check_dims(size, prob, theta, target_dim = dim)
       super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
-      self$add_parameter(theta, "theta")
       self$add_parameter(size, "size")
       self$add_parameter(prob, "prob")
+      self$add_parameter(theta, "theta")
     },
-    
     tf_distrib = function(parameters, dag) {
-      theta <- parameters$theta
       size <- parameters$size
       p <- parameters$prob # probability of success
+      theta <- parameters$theta
       q <- fl(1) - parameters$prob
       log_prob <- function(x) {
         tf$math$log(
           theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
         )
-        
       }
-      
+
       sample <- function(seed) {
         binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
         negbin <-
          tfp$distributions$NegativeBinomial(total_count = size, probs = q) # q = 1 - prob: different parametrisation from 'stats'
-        
+
         zi <- binom$sample(seed = seed)
         lbd <- negbin$sample(seed = seed)
-        
+
         (fl(1) - zi) * lbd
-        
       }
-      
+
       list(
         log_prob = log_prob,
         sample = sample,
@@ -58,7 +54,6 @@ zero_inflated_negative_binomial_distribution <- R6::R6Class(
         log_cdf = NULL
       )
     },
-    
     tf_cdf_function = NULL,
     tf_log_cdf_function = NULL
   )
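+
+# reference form of the density implemented in log_prob above (the standard
+# ZINB mixture), handy when chasing the NaN issue:
+#   P(x) = theta * I(x == 0) + (1 - theta) * dnbinom(x, size, prob)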
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
index fa2da81..f5bf33e 100644
--- a/tests/testthat/test_zip_zinb.R
+++ b/tests/testthat/test_zip_zinb.R
@@ -18,9 +18,10 @@ test_that("zero inflated poisson distribution has correct density", {
 test_that("zero inflated negative binomial distribution has correct density", {
   skip_if_not(check_tf_version())
 
-  compare_distribution(zero_inflated_negative_binomial,
-    extraDistr::dzinb,
-    parameters = list(theta = 2, size = 10, prob = 0.1, pi = 0.2),
+  compare_distribution(
+    greta_fun = zero_inflated_negative_binomial,
+    r_fun = extraDistr::dzinb,
+    parameters = list(10, 0.1, 0.2),
     x = extraDistr::rzinb(
       n = 100, size = 10, prob = 0.1, pi = 0.2
     )

From 2e965b38c49f397ac34fe6d312003d280b2ab9a4 Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Mon, 1 Aug 2022 16:54:23 +0800
Subject: [PATCH 07/19] change parameter name from theta --> pi

---
 R/zero_inflated_negative_binomial.R | 20 ++++++++++----------
 R/zero_inflated_poisson.R           | 22 +++++++++++-----------
 tests/testthat/helpers.R            | 16 ++++++++--------
 tests/testthat/test_zip_zinb.R      | 22 +++++++++++-----------
 4 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/R/zero_inflated_negative_binomial.R b/R/zero_inflated_negative_binomial.R
index 377b82a..7cb9a8d 100644
--- a/R/zero_inflated_negative_binomial.R
+++ b/R/zero_inflated_negative_binomial.R
@@ -3,41 +3,41 @@
 #' @description A Zero Inflated Negative Binomial distribution
 #' @param size positive integer parameter
 #' @param prob probability parameter (`0 < prob < 1`)
-#' @param theta proportion of zeros
+#' @param pi proportion of zeros
 #' @param dim a scalar giving the number of rows in the resulting greta array
 #' @export
-zero_inflated_negative_binomial <- function(size, prob, theta, dim = NULL) {
-  distrib("zero_inflated_negative_binomial", size, prob, theta, dim)
+zero_inflated_negative_binomial <- function(size, prob, pi, dim = NULL) {
+  distrib("zero_inflated_negative_binomial", size, prob, pi, dim)
 }
 
 zero_inflated_negative_binomial_distribution <- R6::R6Class(
   "zero_inflated_negative_binomial_distribution",
   inherit = distribution_node,
   public = list(
-    initialize = function(size, prob, theta, dim) {
+    initialize = function(size, prob, pi, dim) {
       size <- as.greta_array(size)
       prob <- as.greta_array(prob)
-      theta <- as.greta_array(theta)
+      pi <- as.greta_array(pi)
       # add the nodes as children and parameters
-      dim <- check_dims(size, prob, theta, target_dim = dim)
+      dim <- check_dims(size, prob, pi, target_dim = dim)
       super$initialize("zero_inflated_negative_binomial", dim, discrete = TRUE)
       self$add_parameter(size, "size")
       self$add_parameter(prob, "prob")
-      self$add_parameter(theta, "theta")
+      self$add_parameter(pi, "pi")
     },
     tf_distrib = function(parameters, dag) {
       size <- parameters$size
       p <- parameters$prob # probability of success
-      theta <- parameters$theta
+      pi <- parameters$pi
       q <- fl(1) - parameters$prob
       log_prob <- function(x) {
         tf$math$log(
-          theta * tf$nn$relu(fl(1) - x) + (fl(1) - theta) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
+          pi * tf$nn$relu(fl(1) - x) + (fl(1) - pi) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
         )
       }
 
       sample <- function(seed) {
-        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = pi)
         negbin <-
          tfp$distributions$NegativeBinomial(total_count = size, probs = q) # q = 1 - prob: different parametrisation from 'stats'
 
diff --git a/R/zero_inflated_poisson.R b/R/zero_inflated_poisson.R
index d20b138..1b0655a 100644
--- a/R/zero_inflated_poisson.R
+++ b/R/zero_inflated_poisson.R
@@ -4,12 +4,12 @@
 #' @description A zero-inflated Poisson distribution.
 #'
 #' @param lambda rate parameter
-#' @param theta proportion of zeros
+#' @param pi proportion of zeros
 #' @param dim a scalar giving the number of rows in the resulting greta array
 #' @importFrom R6 R6Class
 #' @export
-zero_inflated_poisson <- function (lambda, theta, dim = NULL) {
-  distrib('zero_inflated_poisson', lambda, theta, dim)
+zero_inflated_poisson <- function (lambda, pi, dim = NULL) {
+  distrib('zero_inflated_poisson', lambda, pi, dim)
 }
 
 #' @importFrom R6 R6Class
@@ -17,31 +17,31 @@ zero_inflated_poisson_distribution <- R6::R6Class(
   classname = "zero_inflated_poisson_distribution",
   inherit = distribution_node,
   public = list(
-    initialize = function(lambda, theta, dim) {
+    initialize = function(lambda, pi, dim) {
       lambda <- as.greta_array(lambda)
-      theta <- as.greta_array(theta)
+      pi <- as.greta_array(pi)
       # add the nodes as children and parameters
-      dim <- check_dims(lambda, theta, target_dim = dim)
+      dim <- check_dims(lambda, pi, target_dim = dim)
       super$initialize("zero_inflated_poisson", dim, discrete = TRUE)
       self$add_parameter(lambda, "lambda")
-      self$add_parameter(theta, "theta")
+      self$add_parameter(pi, "pi")
     },
     
     tf_distrib = function(parameters, dag) {
       lambda <- parameters$lambda
-      theta <- parameters$theta
+      pi <- parameters$pi
       log_prob <- function(x) {
         tf$math$log(
-          theta * 
+          pi * 
             tf$nn$relu(fl(1) - x) + 
-            (fl(1) - theta) * 
+            (fl(1) - pi) * 
             tf$pow(lambda, x) * 
             tf$exp(-lambda) / tf$exp(tf$math$lgamma(x + fl(1)))
         )
       }
       
       sample <- function(seed) {
-        binom <- tfp$distributions$Binomial(total_count = 1, probs = theta)
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = pi)
         pois <- tfp$distributions$Poisson(rate = lambda)
         
         zi <- binom$sample(seed = seed)
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index b900954..aa2de6b 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -901,37 +901,37 @@ check_samples <- function(x,
 
 # zero inflated poisson using distributional
 
-dist_zero_inflated_pois <- function(lambda, prob_zeros) {
+dist_zero_inflated_pois <- function(lambda, pi) {
   distributional::dist_inflated(
     dist = distributional::dist_poisson(lambda = lambda),
-    prob = prob_zeros,
+    prob = pi,
     x = 0
   )
 }
 
 dist_zero_inflated_negative_binomial <-
-  function(size, prob, prob_zeros) {
+  function(size, prob, pi) {
     distributional::dist_inflated(
       dist = distributional::dist_negative_binomial(
         size = size,
         prob = prob
       ),
-      prob = prob_zeros,
+      prob = pi,
       x = 0
     )
   }
 
-sample_zero_inflated_pois <- function(n, lambda, prob) {
+sample_zero_inflated_pois <- function(n, lambda, pi) {
   distributional::generate(
-    x = dist_zero_inflated_pois(lambda = lambda, prob = prob),
+    x = dist_zero_inflated_pois(lambda = lambda, pi = pi),
     n
   )[[1]]
 }
 
 sample_zero_inflated_neg_binomial <-
-  function(n, size, lambda, prob_zeros) {
+  function(n, size, prob, pi) {
     distributional::generate(
-      x = dist_zero_inflated_pois(lambda = lambda, prob = prob_zeros),
+      x = dist_zero_inflated_negative_binomial(size = size, prob = prob, pi = pi),
       n
     )[[1]]
   }
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
index f5bf33e..887437c 100644
--- a/tests/testthat/test_zip_zinb.R
+++ b/tests/testthat/test_zip_zinb.R
@@ -1,18 +1,18 @@
 source("helpers.R")
 
 test_that("zero inflated poisson distribution has correct density", {
-
   skip_if_not(check_tf_version())
-  
-  compare_distribution(greta_fun = zero_inflated_poisson,
-                       r_fun = extraDistr::dzip,
-                       parameters = list(2, 0.2),
-                       x = sample_zero_inflated_pois(
-                         n = 100, 
-                         lambda = 2, 
-                         prob = 0.2)
-                       )
 
+  compare_distribution(
+    greta_fun = zero_inflated_poisson,
+    r_fun = extraDistr::dzip,
+    parameters = list(lambda = 2, pi = 0.2),
+    x = sample_zero_inflated_pois(
+      n = 100,
+      lambda = 2,
+      pi = 0.2
+    )
+  )
 })
 
 test_that("zero inflated negative binomial distribution has correct density", {
@@ -21,7 +21,7 @@ test_that("zero inflated negative binomial distribution has correct density", {
   compare_distribution(
     greta_fun = zero_inflated_negative_binomial,
     r_fun = extraDistr::dzinb,
-    parameters = list(10, 0.1, 0.2),
+    parameters = list(size = 10, prob = 0.1, pi = 0.2),
     x = extraDistr::rzinb(
       n = 100, size = 10, prob = 0.1, pi = 0.2
     )

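A minimal usage sketch of the renamed argument (assuming greta and
greta.distributions are loaded; extraDistr supplies the reference densities
the tests compare against):

library(greta.distributions)

# the proportion of structural zeros is now `pi`, matching extraDistr's
# dzip()/dzinb() argument names
y_zip  <- zero_inflated_poisson(lambda = 2, pi = 0.2, dim = 10)
y_zinb <- zero_inflated_negative_binomial(size = 10, prob = 0.1, pi = 0.2, dim = 10)

# the plain-R reference densities used in the tests
extraDistr::dzip(0:5, lambda = 2, pi = 0.2)
extraDistr::dzinb(0:5, size = 10, prob = 0.1, pi = 0.2)
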
From 05664eeb6848ffeb716362452b35a4c490731f1f Mon Sep 17 00:00:00 2001
From: Nicholas Tierney <nicholas.tierney@gmail.com>
Date: Mon, 1 Aug 2022 17:07:19 +0800
Subject: [PATCH 08/19] try `pi_var` not `pi` in case `pi` is being interpreted
 as 3.1415...

---
 R/zero_inflated_negative_binomial.R | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/R/zero_inflated_negative_binomial.R b/R/zero_inflated_negative_binomial.R
index 7cb9a8d..35cfce8 100644
--- a/R/zero_inflated_negative_binomial.R
+++ b/R/zero_inflated_negative_binomial.R
@@ -28,16 +28,16 @@ zero_inflated_negative_binomial_distribution <- R6::R6Class(
     tf_distrib = function(parameters, dag) {
       size <- parameters$size
       p <- parameters$prob # probability of success
-      pi <- parameters$pi
+      pi_var <- parameters$pi
       q <- fl(1) - parameters$prob
       log_prob <- function(x) {
         tf$math$log(
-          pi * tf$nn$relu(fl(1) - x) + (fl(1) - pi) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
+          pi_var * tf$nn$relu(fl(1) - x) + (fl(1) - pi_var) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
         )
       }
 
       sample <- function(seed) {
-        binom <- tfp$distributions$Binomial(total_count = 1, probs = pi)
+        binom <- tfp$distributions$Binomial(total_count = 1, probs = pi_var)
         negbin <-
           tfp$distributions$NegativeBinomial(total_count = size, probs = q) # change of proba / parametrisation in 'stats'
 

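For context on the rename: `pi` is a built-in constant in R, so a parameter of
the same name masks it inside the function body. Masking is legal, which is
why the rename is defensive rather than strictly required; a quick plain-R
illustration:

pi
#> [1] 3.141593

f <- function(pi) pi + 1  # the argument `pi` masks base::pi inside f
f(0.2)
#> [1] 1.2
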
From 08c82fdc4992b31f7ba908ddf0ef934a420146da Mon Sep 17 00:00:00 2001
From: Hao Ran Lai <hrlai.ecology@gmail.com>
Date: Wed, 30 Nov 2022 14:48:16 +1300
Subject: [PATCH 09/19] refactor code to (1) use sign(abs(x)) instead of relu
 (relu seems less robust here, since it is unclear how greta handles negative
 integers for the ZINB, which only supports non-negative integers), (2)
 leverage tf_lchoose, which is already used in some greta distributions, and
 (3) calculate in log space before converting back to regular space

---
 R/zero_inflated_negative_binomial.R | 17 ++++++++++++-----
 1 file changed, 12 insertions(+), 5 deletions(-)

diff --git a/R/zero_inflated_negative_binomial.R b/R/zero_inflated_negative_binomial.R
index 35cfce8..d051e58 100644
--- a/R/zero_inflated_negative_binomial.R
+++ b/R/zero_inflated_negative_binomial.R
@@ -32,21 +32,28 @@ zero_inflated_negative_binomial_distribution <- R6::R6Class(
       q <- fl(1) - parameters$prob
       log_prob <- function(x) {
         tf$math$log(
-          pi_var * tf$nn$relu(fl(1) - x) + (fl(1) - pi_var) * tf$pow(p, size) * tf$pow(q, x) * tf$exp(tf$math$lgamma(x + size)) / tf$exp(tf$math$lgamma(size)) / tf$exp(tf$math$lgamma(x + fl(1)))
+          (pi_var * (fl(1) - tf$math$sign(tf$math$abs(x))) + 
+             tf$math$exp(
+               tf$math$log1p(-pi_var) + 
+                 tf_lchoose(x+size-fl(1), x) + 
+                 size * tf$math$log(p) + 
+                 x * tf$math$log1p(-p)
+             )
+          )
         )
       }
-
+      
       sample <- function(seed) {
         binom <- tfp$distributions$Binomial(total_count = 1, probs = pi_var)
         negbin <-
           tfp$distributions$NegativeBinomial(total_count = size, probs = q) # change of proba / parametrisation in 'stats'
-
+        
         zi <- binom$sample(seed = seed)
         lbd <- negbin$sample(seed = seed)
-
+        
         (fl(1) - zi) * lbd
       }
-
+      
       list(
         log_prob = log_prob,
         sample = sample,

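The refactored density can be sanity-checked outside TensorFlow. Since
sign(abs(x)) is 0 only at x = 0, the term 1 - sign(abs(x)) acts as the zero
indicator; a plain-R sketch of the same expression, checked against
extraDistr::dzinb:

zinb_density <- function(x, size, prob, pi) {
  pi * (1 - sign(abs(x))) +                    # point mass at zero
    exp(log1p(-pi) +                           # log(1 - pi), stable for small pi
          lchoose(x + size - 1, x) +           # log binomial coefficient
          size * log(prob) +                   # negative binomial kernel,
          x * log1p(-prob))                    # assembled in log space
}

x <- 0:5
all.equal(
  zinb_density(x, size = 10, prob = 0.1, pi = 0.2),
  extraDistr::dzinb(x, size = 10, prob = 0.1, pi = 0.2)
)
#> [1] TRUE
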
From 69ac9d5484521de16434e9205a333c03362ed2d9 Mon Sep 17 00:00:00 2001
From: Hao Ran Lai <hrlai.ecology@gmail.com>
Date: Wed, 30 Nov 2022 14:49:30 +1300
Subject: [PATCH 10/19] update documentation with new argument names and
 roxygen version

---
 DESCRIPTION                            | 2 +-
 man/zero_inflated_negative_binomial.Rd | 6 +++---
 man/zero_inflated_poisson.Rd           | 6 +++---
 3 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index e4e588c..c3a6081 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -44,7 +44,7 @@ Encoding: UTF-8
 Language: en-GB
 LazyData: true
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.0
+RoxygenNote: 7.2.1
 SystemRequirements: Python (>= 2.7.0) with header files and shared
     library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow
     Probability (v0.7.0; https://www.tensorflow.org/probability/)
diff --git a/man/zero_inflated_negative_binomial.Rd b/man/zero_inflated_negative_binomial.Rd
index 972209a..708e5d4 100644
--- a/man/zero_inflated_negative_binomial.Rd
+++ b/man/zero_inflated_negative_binomial.Rd
@@ -4,15 +4,15 @@
 \alias{zero_inflated_negative_binomial}
 \title{Zero Inflated Negative Binomial}
 \usage{
-zero_inflated_negative_binomial(theta, size, prob, dim = NULL)
+zero_inflated_negative_binomial(size, prob, pi, dim = NULL)
 }
 \arguments{
-\item{theta}{proportion of zeros}
-
 \item{size}{positive integer parameter}
 
 \item{prob}{probability parameter (\verb{0 < prob < 1}),}
 
+\item{pi}{proportion of zeros}
+
 \item{dim}{a scalar giving the number of rows in the resulting greta array}
 }
 \description{
diff --git a/man/zero_inflated_poisson.Rd b/man/zero_inflated_poisson.Rd
index efa7b47..4916698 100644
--- a/man/zero_inflated_poisson.Rd
+++ b/man/zero_inflated_poisson.Rd
@@ -4,13 +4,13 @@
 \alias{zero_inflated_poisson}
 \title{Zero Inflated Poisson distribution}
 \usage{
-zero_inflated_poisson(theta, lambda, dim = NULL)
+zero_inflated_poisson(lambda, pi, dim = NULL)
 }
 \arguments{
-\item{theta}{proportion of zeros}
-
 \item{lambda}{rate parameter}
 
+\item{pi}{proportion of zeros}
+
 \item{dim}{a scalar giving the number of rows in the resulting greta array}
 }
 \description{

From 093ae0b22ea96360f24eabc6dd0c1daba8cf2900 Mon Sep 17 00:00:00 2001
From: Hao Ran Lai <hrlai.ecology@gmail.com>
Date: Wed, 30 Nov 2022 15:01:24 +1300
Subject: [PATCH 11/19] apply the same changes as the ZINB, and rename pi to
 pi_var in tf_distrib (still not sure whether this is necessary, see
 https://github.com/greta-dev/greta.distributions/pull/15#issuecomment-1200926368)

---
 R/zero_inflated_poisson.R | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/R/zero_inflated_poisson.R b/R/zero_inflated_poisson.R
index 1b0655a..f089aa4 100644
--- a/R/zero_inflated_poisson.R
+++ b/R/zero_inflated_poisson.R
@@ -29,14 +29,14 @@ zero_inflated_poisson_distribution <- R6::R6Class(
     
     tf_distrib = function(parameters, dag) {
       lambda <- parameters$lambda
-      pi <- parameters$pi
+      pi_var <- parameters$pi
       log_prob <- function(x) {
         tf$math$log(
-          pi * 
-            tf$nn$relu(fl(1) - x) + 
-            (fl(1) - pi) * 
-            tf$pow(lambda, x) * 
-            tf$exp(-lambda) / tf$exp(tf$math$lgamma(x + fl(1)))
+          (pi_var * (fl(1) - tf$math$sign(tf$math$abs(x))) + 
+             tf$math$exp(
+               tf$math$log1p(-pi_var) - lambda + 
+                 x * tf$math$log(lambda) - tf$math$lgamma(x + fl(1)))
+          )
         )
       }
       

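The same plain-R check works for the refactored ZIP density:

zip_density <- function(x, lambda, pi) {
  pi * (1 - sign(abs(x))) +                    # point mass at zero
    exp(log1p(-pi) - lambda +                  # log(1 - pi) plus the Poisson
          x * log(lambda) - lgamma(x + 1))     # log-kernel, exponentiated once
}

x <- 0:5
all.equal(
  zip_density(x, lambda = 2, pi = 0.2),
  extraDistr::dzip(x, lambda = 2, pi = 0.2)
)
#> [1] TRUE
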
From 6a6ab7b59c9e817b5f75f4072e0f9ef4c53a804e Mon Sep 17 00:00:00 2001
From: Hao Ran Lai <hrlai.ecology@gmail.com>
Date: Thu, 1 Dec 2022 09:02:07 +1300
Subject: [PATCH 12/19] load tfp in zzz.R, see also
 https://github.com/greta-dev/greta.distributions/pull/16#discussion_r936220933

---
 R/zzz.R | 2 ++
 1 file changed, 2 insertions(+)
 create mode 100644 R/zzz.R

diff --git a/R/zzz.R b/R/zzz.R
new file mode 100644
index 0000000..f5f8ae3
--- /dev/null
+++ b/R/zzz.R
@@ -0,0 +1,2 @@
+# load tf probability
+tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)

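What `delay_load = TRUE` buys here, per reticulate's documented semantics: the
call returns a module proxy immediately, without initialising Python, and the
real import only happens on first access, so the package can be attached on
machines where TensorFlow Probability is not yet installed. A sketch:

tfp <- reticulate::import("tensorflow_probability", delay_load = TRUE)
# nothing has been imported yet; Python initialises on the first `$` access
tfp$distributions$Poisson
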
From 011fbed51049e9cf114df7fd36e41373e4f63179 Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Wed, 8 Feb 2023 19:43:25 +1100
Subject: [PATCH 13/19] add CRAN badge

---
 README.Rmd | 2 +-
 README.md  | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/README.Rmd b/README.Rmd
index 416bdc5..8bc33f0 100644
--- a/README.Rmd
+++ b/README.Rmd
@@ -15,8 +15,8 @@ knitr::opts_chunk$set(
 # greta.distributions
 
 <!-- badges: start -->
-<!-- once you've signed into travis and set it to wath your new repository, you can edit the following badges to point to your repo -->
 [![Codecov test coverage](https://codecov.io/gh/greta-dev/greta.distributions/branch/main/graph/badge.svg)](https://codecov.io/gh/greta-dev/greta.distributions?branch=main)
+[![CRAN status](https://www.r-pkg.org/badges/version/greta.distributions)](https://CRAN.R-project.org/package=greta.distributions)
 [![R-CMD-check](https://github.com/njtierney/greta.distributions/workflows/R-CMD-check/badge.svg)](https://github.com/njtierney/greta.distributions/actions)
 <!-- badges: end -->
 
diff --git a/README.md b/README.md
index c1060d4..7369232 100644
--- a/README.md
+++ b/README.md
@@ -4,10 +4,11 @@
 # greta.distributions
 
 <!-- badges: start -->
-<!-- once you've signed into travis and set it to wath your new repository, you can edit the following badges to point to your repo -->
 
 [![Codecov test
 coverage](https://codecov.io/gh/greta-dev/greta.distributions/branch/main/graph/badge.svg)](https://codecov.io/gh/greta-dev/greta.distributions?branch=main)
+[![CRAN
+status](https://www.r-pkg.org/badges/version/greta.distributions)](https://CRAN.R-project.org/package=greta.distributions)
 [![R-CMD-check](https://github.com/njtierney/greta.distributions/workflows/R-CMD-check/badge.svg)](https://github.com/njtierney/greta.distributions/actions)
 <!-- badges: end -->
 

From 16291fff772498a19109088ed7e5e0c49575e01d Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Fri, 22 Mar 2024 11:18:33 +1000
Subject: [PATCH 14/19] Get package to build without warnings * add coda,
 cramer, distributional, extraDistr, likelihoodExplore, mvtnorm, and truncdist
 to Suggests * Update to Roxygen 7.3.1 * use globalVariables to silence NOTE *
 use new sentinel package structure - "greta-distributions-package.R" * import
 some greta internals directly into helpers.R

---
 DESCRIPTION                                   | 12 +++++++--
 NAMESPACE                                     |  1 +
 ...ackage.R => greta-distributions-package.R} | 27 ++++++++++++-------
 man/greta.distributions.Rd                    | 20 +++++++++++---
 tests/testthat/helpers.R                      |  3 ++-
 5 files changed, 48 insertions(+), 15 deletions(-)
 rename R/{package.R => greta-distributions-package.R} (50%)

diff --git a/DESCRIPTION b/DESCRIPTION
index c3a6081..8fcbdae 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -29,14 +29,22 @@ Depends:
 Imports: 
     cli,
     glue,
+    progress,
     R6,
     tensorflow (>= 1.14.0)
 Suggests: 
+    coda,
     covr,
+    cramer,
+    distributional,
+    extraDistr,
     knitr,
+    likelihoodExplore,
+    mvtnorm,
     rmarkdown,
     spelling,
-    testthat (>= 3.1.0)
+    testthat (>= 3.1.0),
+    truncdist
 VignetteBuilder: 
     knitr
 Config/testthat/edition: 3
@@ -44,7 +52,7 @@ Encoding: UTF-8
 Language: en-GB
 LazyData: true
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.2.1
+RoxygenNote: 7.3.1
 SystemRequirements: Python (>= 2.7.0) with header files and shared
     library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow
     Probability (v0.7.0; https://www.tensorflow.org/probability/)
diff --git a/NAMESPACE b/NAMESPACE
index 1a1e9c2..b5e6eab 100644
--- a/NAMESPACE
+++ b/NAMESPACE
@@ -5,3 +5,4 @@ export(zero_inflated_poisson)
 importFrom(R6,R6Class)
 importFrom(greta,.internals)
 importFrom(tensorflow,tf)
+importFrom(utils,globalVariables)
diff --git a/R/package.R b/R/greta-distributions-package.R
similarity index 50%
rename from R/package.R
rename to R/greta-distributions-package.R
index eb50f27..fdb8fa2 100644
--- a/R/package.R
+++ b/R/greta-distributions-package.R
@@ -2,15 +2,24 @@
 #' @name greta.distributions
 #' 
 #' @description describe your package here, you can re-use the text from DESCRIPTION
-#' 
-#' @docType package
-#' 
+#' @keywords internal
+"_PACKAGE"
+
+## usethis namespace: start
 #' @importFrom tensorflow tf
 #' @importFrom greta .internals
 #' @importFrom R6 R6Class
-#' 
-#' @examples
-#' 
-#' # add a simple example here to introduce the package!
-#' 
-NULL
\ No newline at end of file
+#' @importFrom utils globalVariables
+## usethis namespace: end
+NULL
+
+globalVariables(
+  c(
+    "as_2d_array",
+    "as_data",
+    "calculate",
+    "initials",
+    "prep_initials",
+    "variable"
+  )
+)
\ No newline at end of file
diff --git a/man/greta.distributions.Rd b/man/greta.distributions.Rd
index b36bcd1..625d4ba 100644
--- a/man/greta.distributions.Rd
+++ b/man/greta.distributions.Rd
@@ -1,14 +1,28 @@
 % Generated by roxygen2: do not edit by hand
-% Please edit documentation in R/package.R
+% Please edit documentation in R/greta-distributions-package.R
 \docType{package}
 \name{greta.distributions}
+\alias{greta.distributions-package}
 \alias{greta.distributions}
 \title{Extends Distributions Available in the \code{greta} package}
 \description{
 describe your package here, you can re-use the text from DESCRIPTION
 }
-\examples{
+\seealso{
+Useful links:
+\itemize{
+  \item \url{https://github.com/greta-dev/greta.distributions}
+  \item Report bugs at \url{https://github.com/greta-dev/greta.distributions/issues}
+}
+
+}
+\author{
+\strong{Maintainer}: Nicholas Tierney \email{nicholas.tierney@gmail.com} (\href{https://orcid.org/0000-0003-1460-8722}{ORCID})
 
-# add a simple example here to introduce the package!
+Authors:
+\itemize{
+  \item Nick Golding \email{nick.golding.research@gmail.com} (\href{https://orcid.org/0000-0001-8916-5570}{ORCID})
+}
 
 }
+\keyword{internal}
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index aa2de6b..622d79a 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -1,5 +1,6 @@
 # test functions
-
+check_tf_version <- .internals$checks$check_tf_version
+create_progress_bar <- .internals$inference$progress_bar$create_progress_bar
 # set the seed and flush the graph before running tests
 if (check_tf_version()) {
   tensorflow::tf$compat$v1$reset_default_graph()

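Two idioms in this patch are worth spelling out. globalVariables() only
silences R CMD check's "no visible binding for global variable" NOTE, with no
run-time effect, and greta's exported `.internals` object is the supported
route to internal helpers for extension packages. A sketch of both:

# check-time only: register names R CMD check would otherwise flag
utils::globalVariables(c("as_data", "calculate", "variable"))

# pull helpers out of greta's exported .internals list
library(greta)
check_tf_version <- .internals$checks$check_tf_version
check_tf_version()
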
From b6af3d5e1dd79e4d71e8d83178731992345f412c Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Mon, 2 Dec 2024 12:10:43 +1100
Subject: [PATCH 15/19] small tweaks to get R CMD check to pass

---
 DESCRIPTION              |  9 +++---
 tests/testthat/helpers.R | 70 +++++++++++++++++++++-------------------
 2 files changed, 41 insertions(+), 38 deletions(-)

diff --git a/DESCRIPTION b/DESCRIPTION
index 8fcbdae..a918b45 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -24,14 +24,15 @@ License: Apache License 2.0
 URL: https://github.com/greta-dev/greta.distributions
 BugReports: https://github.com/greta-dev/greta.distributions/issues
 Depends: 
-    greta (>= 0.4.2),
-    R (>= 3.1.0)
+    greta (>= 0.5.0),
+    R (>= 4.1.0)
 Imports: 
     cli,
     glue,
     progress,
     R6,
-    tensorflow (>= 1.14.0)
+    tensorflow (== 2.16.0),
+    rlang
 Suggests: 
     coda,
     covr,
@@ -52,7 +53,7 @@ Encoding: UTF-8
 Language: en-GB
 LazyData: true
 Roxygen: list(markdown = TRUE)
-RoxygenNote: 7.3.1
+RoxygenNote: 7.3.2
 SystemRequirements: Python (>= 2.7.0) with header files and shared
     library; TensorFlow (v1.14; https://www.tensorflow.org/); TensorFlow
     Probability (v0.7.0; https://www.tensorflow.org/probability/)
diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 622d79a..18bbd41 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -56,13 +56,15 @@ get_density <- function(distrib, data) {
   as.vector(grab(tensor, dag))
 }
 
-compare_distribution <- function(greta_fun,
-                                 r_fun,
-                                 parameters,
-                                 x,
-                                 dim = NULL,
-                                 multivariate = FALSE,
-                                 tolerance = 1e-4) {
+compare_distribution <- function(
+    greta_fun,
+    r_fun,
+    parameters,
+    x,
+    dim = NULL,
+    multivariate = FALSE,
+    tolerance = 1e-4
+) {
   # calculate the absolute difference in the log density of some data between
   # greta and a r benchmark.
   # 'greta_fun' is the greta distribution constructor function (e.g. normal())
@@ -73,8 +75,11 @@ compare_distribution <- function(greta_fun,
 
   # define greta distribution, with fixed values
   greta_log_density <- greta_density(
-    greta_fun, parameters, x,
-    dim, multivariate
+    greta_fun,
+    parameters,
+    x,
+    dim,
+    multivariate
   )
   # get R version
   r_log_density <- log(do.call(r_fun, c(list(x), parameters)))
@@ -85,54 +90,51 @@ compare_distribution <- function(greta_fun,
 
 # evaluate the log density of x, given 'parameters' and a distribution
 # constructor function 'fun'
-greta_density <- function(fun,
-                          parameters,
-                          x,
-                          dim = NULL,
-                          multivariate = FALSE) {
-  if (is.null(dim)) {
-    dim <- NROW(x)
-  }
-
+greta_density <- function(
+    fun,
+    parameters,
+    x,
+    dim = NULL,
+    multivariate = FALSE
+) {
+  
+  dim <- dim %||% NROW(x)
+  
   # add the output dimension to the arguments list
   dim_list <- list(dim = dim)
-
+  
   # if it's a multivariate distribution name it n_realisations
   if (multivariate) {
     names(dim_list) <- "n_realisations"
   }
-
+  
   # don't add it for wishart & lkj, which don't mave multiple realisations
   is_wishart <- identical(names(parameters), c("df", "Sigma"))
   is_lkj <- identical(names(parameters), c("eta", "dimension"))
   if (is_wishart | is_lkj) {
     dim_list <- list()
   }
-
+  
   parameters <- c(parameters, dim_list)
-
+  
   # evaluate greta distribution
   dist <- do.call(fun, parameters)
   distrib_node <- get_node(dist)$distribution
-
+  
   # set density
   x_ <- as.greta_array(x)
   distrib_node$remove_target()
   distrib_node$add_target(get_node(x_))
-
+  
   # create dag
   dag <- dag_class$new(list(x_))
-  dag$define_tf()
-  dag$set_tf_data_list("batch_size", 1L)
-  dag$build_feed_dict()
-
+  
+  dag$tf_environment$batch_size <- 1L
+  distrib_node$define_tf(dag)
+  
   # get the log density as a vector
-  dag$on_graph(result <-
-    dag$evaluate_density(distrib_node, get_node(x_)))
-  assign("test_density", result, dag$tf_environment)
-
-  density <- dag$tf_sess_run(test_density)
-  as.vector(density)
+  result <- dag$evaluate_density(distrib_node, get_node(x_))
+  as.vector(result)
 }
 
 # execute a call via greta, swapping the objects named in 'swap' to greta

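The `dim <- dim %||% NROW(x)` rewrite uses the null-default operator, which is
why rlang joins Imports (recent R versions also ship `%||%` in base). The
idiom in isolation:

library(rlang)
dim <- NULL
dim %||% NROW(matrix(0, 3, 2))  # NULL left-hand side, so fall back
#> [1] 3
dim <- 5
dim %||% NROW(matrix(0, 3, 2))  # non-NULL left-hand side wins
#> [1] 5
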
From c39ff5c9413af985624e04eac267d52587db7d36 Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Wed, 4 Dec 2024 18:42:24 +1100
Subject: [PATCH 16/19] port over the R-CMD-check GH Actions workflow from greta main

---
 .github/workflows/R-CMD-check.yaml | 133 +++++++++++++++--------------
 1 file changed, 71 insertions(+), 62 deletions(-)

diff --git a/.github/workflows/R-CMD-check.yaml b/.github/workflows/R-CMD-check.yaml
index 2bfc291..c222df8 100644
--- a/.github/workflows/R-CMD-check.yaml
+++ b/.github/workflows/R-CMD-check.yaml
@@ -1,9 +1,3 @@
-# NOTE: This workflow is overkill for most R packages
-# check-standard.yaml is likely a better choice
-# usethis::use_github_action("check-standard") will install it.
-#
-# For help debugging build failures open an issue on the RStudio community with the 'github-actions' tag.
-# https://community.rstudio.com/new-topic?category=Package%20development&tags=github-actions
 on:
   push:
     branches:
@@ -13,96 +7,111 @@ on:
     branches:
       - main
       - master
+  schedule:
+    - cron: '1 23 * * Sun'
 
 name: R-CMD-check
 
-concurrency:
-  group: ${{ github.workflow }}-${{ github.head_ref }}
-  cancel-in-progress: true
+defaults:
+  run:
+    shell: Rscript {0}
 
 jobs:
   R-CMD-check:
-    runs-on: ${{ matrix.config.os }}
-
-    name: ${{ matrix.config.os }} (${{ matrix.config.r }})
-
+    name: ${{ matrix.os }}, tf-${{ matrix.tf }}, R-${{ matrix.r }}
+    timeout-minutes: 30
     strategy:
       fail-fast: false
       matrix:
-        config:
-          - {os: macOS-latest,   r: 'release'}
-          - {os: windows-latest, r: 'release'}
-          - {os: windows-latest, r: 'oldrel'}
-          - {os: ubuntu-18.04,   r: 'devel', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest", http-user-agent: "R/4.0.0 (ubuntu-18.04) R (4.0.0 x86_64-pc-linux-gnu x86_64 linux-gnu) on GitHub Actions" }
-          - {os: ubuntu-18.04,   r: 'release', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
-          - {os: ubuntu-18.04,   r: 'oldrel-1', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
-          - {os: ubuntu-18.04,   r: 'oldrel-2', rspm: "https://packagemanager.rstudio.com/cran/__linux__/bionic/latest"}
+        include:
+          - {os: 'ubuntu-latest' , tf: 'default', r: 'release'}
+          - {os: 'windows-latest', tf: 'default', r: 'release'}
+          - {os: 'macOS-latest'  , tf: 'default', r: 'release'}
 
+    runs-on: ${{ matrix.os }}
+    continue-on-error: ${{ matrix.tf == 'nightly' || contains(matrix.tf, 'rc') || matrix.r == 'devel' }}
     env:
-      RSPM: ${{ matrix.config.rspm }}
+      R_REMOTES_NO_ERRORS_FROM_WARNINGS: 'true'
+      R_COMPILE_AND_INSTALL_PACKAGES: 'never'
       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
-      RETICULATE_AUTOCONFIGURE: 'FALSE'
-      TF_VERSION: '1.14.0'
 
     steps:
 
       - uses: actions/checkout@v2
 
       - uses: r-lib/actions/setup-r@v2
-        id: install-r
+        id: setup-r
         with:
-          r-version: ${{ matrix.config.r }}
-          http-user-agent: ${{ matrix.config.http-user-agent }}
+          r-version: ${{ matrix.r }}
+          Ncpus: '2L'
+          use-public-rspm: true
 
       - uses: r-lib/actions/setup-pandoc@v2
 
-      - uses: r-lib/actions/setup-r-dependencies@v2
+      - name: Get Date
+        id: get-date
+        shell: bash
+        run: |
+          echo "::set-output name=year-week::$(date -u "+%Y-%U")"
+          echo "::set-output name=date::$(date -u "+%F")"
+
+      - name: Restore R package cache
+        uses: actions/cache@v2
+        id: r-package-cache
         with:
-          cache-version: 2
-          extra-packages: |
-            local::.
-            any::keras
-            any::rcmdcheck
+          path: ${{ env.R_LIBS_USER }}
+          key: ${{ matrix.os }}-${{ steps.setup-r.outputs.installed-r-version }}-${{ steps.get-date.outputs.year-week }}-1
 
-      - name: Install Miniconda
-        run: |
-          reticulate::install_miniconda()
-        shell: Rscript {0}
+      - name: Install remotes
+        if: steps.r-package-cache.outputs.cache-hit != 'true'
+        run: install.packages("remotes")
 
-      - name: Set options for conda binary for macOS
-        if: runner.os == 'macOS'
+      - name: Install system dependencies
+        if: runner.os == 'Linux'
+        shell: bash
         run: |
-          echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile
+          . /etc/os-release
+          while read -r cmd
+          do
+            echo "$cmd"
+            sudo $cmd
+          done < <(Rscript -e "writeLines(remotes::system_requirements('$ID-$VERSION_ID'))")
+
+      - name: Install package + deps
+        run: remotes::install_local(dependencies = TRUE, force = TRUE)
 
-      - name: Install TensorFlow
+      - name: Install greta deps
         run: |
-          cat("::group::Create Environment", sep = "\n")
-          reticulate::conda_create('r-reticulate', packages = c('python==3.7'))
-          cat("::endgroup::", sep = "\n")
+          library(greta)
+          greta::install_greta_deps(timeout = 50)
 
-          cat("::group::Install Tensorflow", sep = "\n")
-          keras::install_keras(tensorflow = Sys.getenv('TF_VERSION'),
-              extra_packages = c('IPython', 'requests', 'certifi', 'urllib3', 'tensorflow-probability==0.7.0', 'numpy==1.16.4'))
-          cat("::endgroup::", sep = "\n")
-        shell: Rscript {0}
+      - name: Situation Report on greta install
+        run: greta::greta_sitrep()
 
+      - name: Install rcmdcheck
+        run: remotes::install_cran("rcmdcheck")
 
-      - name: Python + TF details
-        run: |
-          tensorflow::tf_config()
-          tensorflow::tf_version()
-          reticulate::py_module_available("tensorflow_probability")
-          reticulate::py_config()
-        shell: Rscript {0}
+      - name: Check
+        run: rcmdcheck::rcmdcheck(args = '--no-manual', error_on = 'warning', check_dir = 'check')
+
+      - name: Show testthat output
+        if: always()
+        shell: bash
+        run: find check -name 'testthat.Rout*' -exec cat '{}' \; || true
+
+      - name: Don't use tar from old Rtools to store the cache
+        if: ${{ runner.os == 'Windows' && startsWith(steps.setup-r.outputs.installed-r-version, '3') }}
+        shell: bash
+        run: echo "C:/Program Files/Git/usr/bin" >> $GITHUB_PATH
+
+      - name: Check on single core machine
+        if: runner.os != 'Windows'
+        env:
+          R_PARALLELLY_AVAILABLE_CORES: 1
+        run: rcmdcheck::rcmdcheck(args = c("--no-manual", "--as-cran", "--no-multiarch"))
 
       - name: Session info
         run: |
           options(width = 100)
           pkgs <- installed.packages()[, "Package"]
           sessioninfo::session_info(pkgs, include_base = TRUE)
-        shell: Rscript {0}
-
-      - uses: r-lib/actions/check-r-package@v2
-        with:
-          args: 'c("--no-manual", "--as-cran", "--no-multiarch")'
-

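Run locally, the workflow's first check step reduces to a single rcmdcheck
call (a sketch of the same invocation as the Check step above):

# fail on any warning; keep the check dir so testthat.Rout* can be inspected
rcmdcheck::rcmdcheck(
  args = "--no-manual",
  error_on = "warning",
  check_dir = "check"
)
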
From 0e6fdfc7ffcc0e59f1fcf76a57534eb8743251d5 Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Thu, 5 Dec 2024 18:31:02 +1100
Subject: [PATCH 17/19] also update test coverage GH action

---
 .github/workflows/test-coverage.yaml | 113 +++++++++++++++------------
 1 file changed, 63 insertions(+), 50 deletions(-)

diff --git a/.github/workflows/test-coverage.yaml b/.github/workflows/test-coverage.yaml
index 9644845..558fef7 100644
--- a/.github/workflows/test-coverage.yaml
+++ b/.github/workflows/test-coverage.yaml
@@ -1,76 +1,89 @@
+# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
+# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
 on:
   push:
-    branches:
-      - main
-      - master
+    branches: [main, master]
   pull_request:
-    branches:
-      - main
-      - master
+    branches: [main, master]
 
-name: test-coverage
+name: test-coverage.yaml
+
+permissions: read-all
 
 jobs:
   test-coverage:
-    runs-on: macOS-latest
+    runs-on: ubuntu-latest
     env:
       GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
-    steps:
-      - uses: actions/checkout@v2
 
-      - uses: r-lib/actions/setup-r@v1
-
-      - uses: r-lib/actions/setup-pandoc@v1
+    steps:
+      - uses: actions/checkout@v4
 
-      - name: Query dependencies
-        run: |
-          install.packages('remotes')
-          saveRDS(remotes::dev_package_deps(dependencies = TRUE), ".github/depends.Rds", version = 2)
-          writeLines(sprintf("R-%i.%i", getRversion()$major, getRversion()$minor), ".github/R-version")
-        shell: Rscript {0}
+      - uses: r-lib/actions/setup-r@v2
+        with:
+          use-public-rspm: true
 
-      - name: Restore R package cache
-        uses: actions/cache@v2
+      - uses: r-lib/actions/setup-r-dependencies@v2
         with:
-          path: ${{ env.R_LIBS_USER }}
-          key: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-${{ hashFiles('.github/depends.Rds') }}
-          restore-keys: ${{ runner.os }}-${{ hashFiles('.github/R-version') }}-1-
+          extra-packages: |
+            any::covr
+            any::xml2
+            any::remotes
+          needs: coverage
 
-      - name: Install dependencies
+      - name: Install system dependencies
+        if: runner.os == 'Linux'
+        shell: bash
         run: |
-          install.packages(c("remotes"))
-          remotes::install_deps(dependencies = TRUE)
-          remotes::install_cran("covr")
+          . /etc/os-release
+          while read -r cmd
+          do
+            echo "$cmd"
+            sudo $cmd
+          done < <(Rscript -e "writeLines(remotes::system_requirements('$ID-$VERSION_ID'))")
+
+      - name: Install package + deps
+        run: remotes::install_local(dependencies = TRUE, force = TRUE)
         shell: Rscript {0}
 
-      ###
-      - name: Install Miniconda
+      - name: Install greta deps
         run: |
-          install.packages(c("remotes", "keras"))
-          reticulate::install_miniconda()
+          library(greta)
+          greta::install_greta_deps(timeout = 50)
         shell: Rscript {0}
 
-      - name: Set options for conda binary for macOS
-        if: runner.os == 'macOS'
-        run: |
-          echo "options(reticulate.conda_binary = reticulate:::miniconda_conda())" >> .Rprofile
+      - name: Situation Report on greta install
+        run: greta::greta_sitrep()
+        shell: Rscript {0}
 
-#  Perhaps here is where we can install / change the environment that we are
-# installing into? Can we call our own greta install functions here?
-      - name: Install TensorFlow
+      - name: Test coverage
         run: |
-          reticulate::conda_create(envname = "greta-env",python_version = "3.7")
-          reticulate::conda_install(envname = "greta-env", packages = c("numpy==1.16.4", "tensorflow-probability==0.7.0", "tensorflow==1.14.0"))
+          cov <- covr::package_coverage(
+            quiet = FALSE,
+            clean = FALSE,
+            install_path = file.path(normalizePath(Sys.getenv("RUNNER_TEMP"), winslash = "/"), "package")
+          )
+          covr::to_cobertura(cov)
         shell: Rscript {0}
 
-      - name: Python + TF details
+      - uses: codecov/codecov-action@v4
+        with:
+          fail_ci_if_error: ${{ github.event_name != 'pull_request' && true || false }}
+          file: ./cobertura.xml
+          plugin: noop
+          disable_search: true
+          token: ${{ secrets.CODECOV_TOKEN }}
+
+      - name: Show testthat output
+        if: always()
         run: |
-          Rscript -e 'tensorflow::tf_config()'
-          Rscript -e 'tensorflow::tf_version()'
-          Rscript -e 'reticulate::py_module_available("tensorflow_probability")'
-          Rscript -e 'reticulate::py_config()'
-      ###
+          ## --------------------------------------------------------------------
+          find '${{ runner.temp }}/package' -name 'testthat.Rout*' -exec cat '{}' \; || true
+        shell: bash
 
-      - name: Test coverage
-        run: covr::codecov()
-        shell: Rscript {0}
+      - name: Upload test results
+        if: failure()
+        uses: actions/upload-artifact@v4
+        with:
+          name: coverage-test-failures
+          path: ${{ runner.temp }}/package

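The coverage step, run locally, boils down to the following (a sketch; the
install_path argument is dropped because RUNNER_TEMP is only set on the
Actions runner):

cov <- covr::package_coverage(quiet = FALSE, clean = FALSE)
covr::to_cobertura(cov)  # writes cobertura.xml for the codecov upload step
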
From dbdbadda00e9eb65b9672448a1a1059ce2ac2f96 Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Sat, 7 Dec 2024 15:22:26 +1100
Subject: [PATCH 18/19] set up some helpers for future, more robust testing
 (#23)

---
 tests/testthat/helpers.R       | 85 ++++++++++++++++++++++++++--------
 tests/testthat/test_zip_zinb.R | 23 +++++++++
 2 files changed, 89 insertions(+), 19 deletions(-)

diff --git a/tests/testthat/helpers.R b/tests/testthat/helpers.R
index 18bbd41..7e1fac5 100644
--- a/tests/testthat/helpers.R
+++ b/tests/testthat/helpers.R
@@ -868,38 +868,85 @@ check_mvn_samples <- function(sampler, n_effective = 3000) {
 # compare the samples with iid samples returned by iid_function (which takes the
 # number of arguments as its sole argument), producing a labelled qqplot, and
 # running a KS test for differences between the two samples
-check_samples <- function(x,
-                          iid_function,
-                          sampler = hmc(),
-                          n_effective = 3000,
-                          title = NULL,
-                          one_by_one = FALSE) {
+# unlike the earlier version, this helper now returns the (thinned) mcmc and
+# iid samples plus metadata (distribution name, sampler class), so plotting
+# and the KS test can be run as separate steps via qqplot_checked_samples()
+# and ks_test_mcmc_vs_iid()
+check_samples <- function(
+    x,
+    iid_function,
+    sampler = hmc(),
+    n_effective = 3000,
+    title = NULL,
+    one_by_one = FALSE,
+    time_limit = 300
+) {
   m <- model(x, precision = "single")
   draws <- get_enough_draws(
-    m,
+    model = m,
     sampler = sampler,
     n_effective = n_effective,
-    verbose = FALSE,
-    one_by_one = one_by_one
+    verbose = TRUE,
+    one_by_one = one_by_one,
+    time_limit = time_limit
   )
-
+  
   neff <- coda::effectiveSize(draws)
   iid_samples <- iid_function(neff)
   mcmc_samples <- as.matrix(draws)
+  
+  thin_amount <- find_thinning(draws)
+  
+  mcmc_samples <- do_thinning(mcmc_samples, thin_amount)
+  iid_samples <- do_thinning(iid_samples, thin_amount)
+  
+  list(
+    mcmc_samples = mcmc_samples,
+    iid_samples = iid_samples,
+    distrib = get_distribution_name(x),
+    sampler_name = class(sampler)[1]
+  )
+}
 
-  # plot
-  if (is.null(title)) {
-    distrib <- get_node(x)$distribution$distribution_name
-    sampler_name <- class(sampler)[1]
-    title <- paste(distrib, "with", sampler_name)
-  }
 
-  stats::qqplot(mcmc_samples, iid_samples, main = title)
+qqplot_checked_samples <- function(checked_samples, title = NULL){
+  
+  distrib <- checked_samples$distrib
+  sampler_name <- checked_samples$sampler_name
+  title <- title %||% paste(distrib, "with", sampler_name)
+  
+  mcmc_samples <- checked_samples$mcmc_samples
+  iid_samples <- checked_samples$iid_samples
+  
+  stats::qqplot(
+    x = mcmc_samples,
+    y = iid_samples,
+    main = title
+  )
+  
   graphics::abline(0, 1)
+}
 
+
+## helpers for running Kolmogorov-Smirnov test for MCMC samples vs IID samples
+ks_test_mcmc_vs_iid <- function(checked_samples){
   # do a formal hypothesis test
-  suppressWarnings(stat <- ks.test(mcmc_samples, iid_samples))
-  testthat::expect_gte(stat$p.value, 0.01)
+  suppressWarnings(stat <- ks.test(checked_samples$mcmc_samples,
+                                   checked_samples$iid_samples))
+  stat
+}
+
+## helpers for looping through optimisers
+run_opt <- function(
+    m,
+    optmr,
+    max_iterations = 200
+) {
+  opt(
+    m,
+    optimiser = optmr(),
+    max_iterations = max_iterations
+  )
 }
 
 # zero inflated poisson using distributional
diff --git a/tests/testthat/test_zip_zinb.R b/tests/testthat/test_zip_zinb.R
index 887437c..4f12135 100644
--- a/tests/testthat/test_zip_zinb.R
+++ b/tests/testthat/test_zip_zinb.R
@@ -27,3 +27,26 @@ test_that("zero inflated negative binomial distribution has correct density", {
     )
   )
 })
+
+# test_that("samplers are unbiased for zip", {
+#   skip_if_not(check_tf_version())
+#   
+#   x <- zero_inflated_poisson(0.1, 0.2)
+#   iid <- function(n) {
+#     extraDistr::rzip(n = n, lambda = 0.1, pi = 0.2)
+#   }
+#   
+#   zip_checked <- check_samples(
+#     x = x,
+#     iid_function = iid,
+#     one_by_one = TRUE
+#   )
+#   
+#   # do the plotting
+#   qqplot_checked_samples(zip_checked)
+#   
+#   # do a formal hypothesis test
+#   stat <- ks_test_mcmc_vs_iid(zip_checked)
+#   
+#   expect_gte(stat$p.value, 0.01)
+# })

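Once enabled, the commented-out test above would compose the new helpers
roughly as follows (a sketch; check_samples(), qqplot_checked_samples(), and
ks_test_mcmc_vs_iid() are the helpers added to helpers.R in this patch):

x <- zero_inflated_poisson(lambda = 0.1, pi = 0.2)
iid <- function(n) extraDistr::rzip(n = n, lambda = 0.1, pi = 0.2)

zip_checked <- check_samples(x = x, iid_function = iid, one_by_one = TRUE)
qqplot_checked_samples(zip_checked)       # visual qq check against iid draws
stat <- ks_test_mcmc_vs_iid(zip_checked)  # formal Kolmogorov-Smirnov test
testthat::expect_gte(stat$p.value, 0.01)
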
From 9b72b8a0cee89664e0944e22dd7155b05a6bda7e Mon Sep 17 00:00:00 2001
From: njtierney <nicholas.tierney@gmail.com>
Date: Mon, 9 Dec 2024 17:08:48 +1100
Subject: [PATCH 19/19] establish pkgdown site

---
 .Rbuildignore                  |  3 +++
 .github/workflows/pkgdown.yaml | 49 ++++++++++++++++++++++++++++++++++
 .gitignore                     |  1 +
 DESCRIPTION                    |  2 +-
 _pkgdown.yml                   |  4 +++
 5 files changed, 58 insertions(+), 1 deletion(-)
 create mode 100644 .github/workflows/pkgdown.yaml
 create mode 100644 _pkgdown.yml

diff --git a/.Rbuildignore b/.Rbuildignore
index ead851f..96a070a 100644
--- a/.Rbuildignore
+++ b/.Rbuildignore
@@ -6,3 +6,6 @@
 ^codecov\.yml$
 ^CODE_OF_CONDUCT\.md$
 ^\.github$
+^_pkgdown\.yml$
+^docs$
+^pkgdown$
diff --git a/.github/workflows/pkgdown.yaml b/.github/workflows/pkgdown.yaml
new file mode 100644
index 0000000..bfc9f4d
--- /dev/null
+++ b/.github/workflows/pkgdown.yaml
@@ -0,0 +1,49 @@
+# Workflow derived from https://github.com/r-lib/actions/tree/v2/examples
+# Need help debugging build failures? Start at https://github.com/r-lib/actions#where-to-find-help
+on:
+  push:
+    branches: [main, master]
+  pull_request:
+  release:
+    types: [published]
+  workflow_dispatch:
+
+name: pkgdown.yaml
+
+permissions: read-all
+
+jobs:
+  pkgdown:
+    runs-on: ubuntu-latest
+    # Only restrict concurrency for non-PR jobs
+    concurrency:
+      group: pkgdown-${{ github.event_name != 'pull_request' || github.run_id }}
+    env:
+      GITHUB_PAT: ${{ secrets.GITHUB_TOKEN }}
+    permissions:
+      contents: write
+    steps:
+      - uses: actions/checkout@v4
+
+      - uses: r-lib/actions/setup-pandoc@v2
+
+      - uses: r-lib/actions/setup-r@v2
+        with:
+          use-public-rspm: true
+
+      - uses: r-lib/actions/setup-r-dependencies@v2
+        with:
+          extra-packages: any::pkgdown, local::.
+          needs: website
+
+      - name: Build site
+        run: pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)
+        shell: Rscript {0}
+
+      - name: Deploy to GitHub pages 🚀
+        if: github.event_name != 'pull_request'
+        uses: JamesIves/github-pages-deploy-action@v4.5.0
+        with:
+          clean: false
+          branch: gh-pages
+          folder: docs
diff --git a/.gitignore b/.gitignore
index d387fe3..6926d34 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,3 +1,4 @@
 .Rhistory
 .RData
 .Rproj.user
+docs
diff --git a/DESCRIPTION b/DESCRIPTION
index a918b45..1c9bdbb 100644
--- a/DESCRIPTION
+++ b/DESCRIPTION
@@ -21,7 +21,7 @@ Description: Provides extra distributions for use with the 'greta' package.
   These will include distributions such as zero inflated negative binomial,
   zero inflated poisson, interval censored lognormal, and more.
 License: Apache License 2.0
-URL: https://github.com/greta-dev/greta.distributions
+URL: https://github.com/greta-dev/greta.distributions, https://greta-dev.github.io/greta.distributions/
 BugReports: https://github.com/greta-dev/greta.distributions/issues
 Depends: 
     greta (>= 0.5.0),
diff --git a/_pkgdown.yml b/_pkgdown.yml
new file mode 100644
index 0000000..2d2f904
--- /dev/null
+++ b/_pkgdown.yml
@@ -0,0 +1,4 @@
+url: https://greta-dev.github.io/greta.distributions/
+template:
+  bootstrap: 5
+
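
To preview the site locally, the workflow's Build step can be run as-is from
an R session at the package root:

pkgdown::build_site_github_pages(new_process = FALSE, install = FALSE)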