Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Create native Dirichlet multinomial family #1729

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ export(dhurdle_negbinomial)
export(dhurdle_poisson)
export(dinv_gaussian)
export(dirichlet)
export(dirichlet_multinomial)
export(dlogistic_normal)
export(dmulti_normal)
export(dmulti_student_t)
Expand Down
12 changes: 12 additions & 0 deletions R/distributions.R
Original file line number Diff line number Diff line change
Expand Up @@ -2185,6 +2185,18 @@ dmultinomial <- function(x, eta, log = FALSE) {
out
}

# density of the Dirichlet-multinomial distribution with the softmax transform
# @param x vector of non-negative integer counts, one entry per category;
#   a matrix with one row of counts per observation is also accepted
# @param eta the linear predictor (of length or ncol ncat)
# @param phi the precision parameter; it scales the softmax-transformed
#   linear predictor to obtain the Dirichlet parameters alpha
# @param log return values on the log scale?
ddirichletmultinomial <- function(x, eta, phi, log = FALSE) {
  require_package("extraDistr")
  alpha <- softmax(eta) * phi
  # the number of trials is implied by the counts; take row-wise totals so
  # that several count vectors passed as a matrix keep their own sizes
  size <- if (is.matrix(x)) rowSums(x) else sum(x)
  extraDistr::ddirmnom(x, size = size, alpha = alpha, log = log)
}

# density of the cumulative distribution
#
# @param x Integer vector containing response category indices to return the
Expand Down
25 changes: 18 additions & 7 deletions R/families.R
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
#' \code{inverse.gaussian}, \code{exponential}, \code{weibull},
#' \code{frechet}, \code{Beta}, \code{dirichlet}, \code{von_mises},
#' \code{asym_laplace}, \code{gen_extreme_value}, \code{categorical},
#' \code{multinomial}, \code{cumulative}, \code{cratio}, \code{sratio},
#' \code{multinomial}, \code{dirichlet_multinomial}, \code{cumulative},
#' \code{cratio}, \code{sratio},
#' \code{acat}, \code{hurdle_poisson}, \code{hurdle_negbinomial},
#' \code{hurdle_gamma}, \code{hurdle_lognormal}, \code{hurdle_cumulative},
#' \code{zero_inflated_binomial}, \code{zero_inflated_beta_binomial},
Expand Down Expand Up @@ -51,8 +51,9 @@
#' consecutive thresholds to the same value, and
#' \code{"sum_to_zero"} ensures the thresholds sum to zero.
#' @param refcat Optional name of the reference response category used in
#' \code{categorical}, \code{multinomial}, \code{dirichlet} and
#' \code{logistic_normal} models. If \code{NULL} (the default), the first
#' \code{categorical}, \code{multinomial}, \code{dirichlet},
#' \code{dirichlet_multinomial} and \code{logistic_normal} models.
#' If \code{NULL} (the default), the first
#' category is used as the reference. If \code{NA}, all categories will be
#' predicted, which requires strong priors or carefully specified predictor
#' terms in order to lead to an identified model.
Expand All @@ -76,8 +77,9 @@
#' can be used for binary regression (i.e., most commonly logistic
#' regression).}
#'
#' \item{Families \code{categorical} and \code{multinomial} can be used for
#' multi-logistic regression when there are more than two possible outcomes.}
#' \item{Families \code{categorical}, \code{multinomial} and
#' \code{dirichlet_multinomial} can be used for multi-logistic regression
#' when there are more than two possible outcomes.}
#'
#' \item{Families \code{cumulative}, \code{cratio} ('continuation ratio'),
#' \code{sratio} ('stopping ratio'), and \code{acat} ('adjacent category')
Expand Down Expand Up @@ -150,8 +152,8 @@
#' \code{acat}, and \code{hurdle_cumulative} support \code{logit},
#' \code{probit}, \code{probit_approx}, \code{cloglog}, and \code{cauchit}.}
#'
#' \item{Families \code{categorical}, \code{multinomial}, and \code{dirichlet}
#' support \code{logit}.}
#' \item{Families \code{categorical}, \code{multinomial},
#' \code{dirichlet_multinomial} and \code{dirichlet} support \code{logit}.}
#'
#' \item{Families \code{Gamma}, \code{weibull}, \code{exponential},
#' \code{frechet}, and \code{hurdle_gamma} support
Expand Down Expand Up @@ -812,6 +814,15 @@ multinomial <- function(link = "logit", refcat = NULL) {
.brmsfamily("multinomial", link = link, slink = slink, refcat = refcat)
}

#' @rdname brmsfamily
#' @export
dirichlet_multinomial <- function(link = "logit", link_phi = "log",
                                  refcat = NULL) {
  # keep the unevaluated expression supplied as 'link' alongside its value
  link_expr <- substitute(link)
  .brmsfamily(
    "dirichlet_multinomial",
    link = link,
    slink = link_expr,
    link_phi = link_phi,
    refcat = refcat
  )
}

#' @rdname brmsfamily
#' @export
cumulative <- function(link = "logit", link_disc = "log",
Expand Down
14 changes: 14 additions & 0 deletions R/family-lists.R
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,20 @@
)
}

# Static description of the dirichlet_multinomial family: supported links,
# distributional parameters, response type, allowed addition terms, special
# flags and the Stan chunk that has to be included.
.family_dirichlet_multinomial <- function() {
  addition_terms <- c("weights", "subset", "trials", "index")
  special_flags <- c("multinomial", "joint_link")
  list(
    links = "logit",
    dpars = "phi",
    # size determined by the data
    multi_dpars = "mu",
    type = "int",
    ybounds = c(-Inf, Inf),
    closed = c(NA, NA),
    ad = addition_terms,
    specials = special_flags,
    include = "fun_dirichlet_multinomial_logit.stan",
    normalized = ""
  )
}

.family_beta <- function() {
list(
links = c(
Expand Down
11 changes: 11 additions & 0 deletions R/log_lik.R
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,17 @@ log_lik_multinomial <- function(i, prep) {
log_lik_weight(out, i = i, prep = prep)
}

# pointwise log-likelihood of the dirichlet_multinomial family
# @param i index of the observation
# @param prep the prepared draws object; only the logit link is supported
# @return the log density of the i-th row of the response counts, passed
#   through log_lik_weight()
log_lik_dirichlet_multinomial <- function(i, prep) {
  stopifnot(prep$family$link == "logit")
  eta <- get_Mu(prep, i = i)
  eta <- insert_refcat(eta, refcat = prep$refcat)
  phi <- get_dpar(prep, "phi", i = i)
  # ddirichletmultinomial() derives the Dirichlet parameters from eta and
  # phi itself, so nothing needs to be precomputed here
  out <- ddirichletmultinomial(
    prep$data$Y[i, ], eta = eta, phi = phi, log = TRUE
  )
  log_lik_weight(out, i = i, prep = prep)
}

log_lik_dirichlet <- function(i, prep) {
stopifnot(prep$family$link == "logit")
eta <- get_Mu(prep, i = i)
Expand Down
15 changes: 15 additions & 0 deletions R/posterior_epred.R
Original file line number Diff line number Diff line change
Expand Up @@ -587,6 +587,21 @@ posterior_epred_multinomial <- function(prep) {
out
}

# expected counts per category for the dirichlet_multinomial family
# @param prep the prepared draws object
# @return a 3D array with the category names from prep$cats on the third
#   dimension
posterior_epred_dirichlet_multinomial <- function(prep) {
  # dirichlet part included in mu
  eta <- get_Mu(prep)
  n_trials <- prep$data$trials
  cat_ids <- seq_len(prep$data$ncat)
  # category probabilities of one observation scaled by its number of trials
  expected_counts <- function(i) {
    eta_i <- insert_refcat(slice_col(eta, i), refcat = prep$refcat)
    dcategorical(cat_ids, eta = eta_i) * n_trials[i]
  }
  per_obs <- lapply(seq_cols(eta), expected_counts)
  out <- aperm(abind(per_obs, along = 3), perm = c(1, 3, 2))
  dimnames(out)[[3]] <- prep$cats
  out
}

posterior_epred_dirichlet <- function(prep) {
get_probs <- function(i) {
eta <- insert_refcat(slice_col(eta, i), refcat = prep$refcat)
Expand Down
10 changes: 10 additions & 0 deletions R/posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,16 @@ posterior_predict_multinomial <- function(i, prep, ...) {
rblapply(seq_rows(p), function(s) t(rmultinom(1, size, p[s, ])))
}

# posterior predictive draws for the dirichlet_multinomial family
# @param i index of the observation
# @param prep the prepared draws object
# @param ... currently unused
# @return the row-bound multinomial draws, one row per row of the sampled
#   probability matrix
posterior_predict_dirichlet_multinomial <- function(i, prep, ...) {
  eta <- insert_refcat(get_Mu(prep, i = i), refcat = prep$refcat)
  phi <- get_dpar(prep, "phi", i = i)
  cat_ids <- seq_len(prep$data$ncat)
  alpha <- dcategorical(cat_ids, eta = eta) * phi
  # first draw category probabilities, then counts given those probabilities
  probs <- rdirichlet(prep$ndraws, alpha = alpha)
  n_trials <- prep$data$trials[i]
  draw_counts <- function(s) {
    t(rmultinom(1, n_trials, probs[s, ]))
  }
  rblapply(seq_rows(probs), draw_counts)
}

posterior_predict_dirichlet <- function(i, prep, ...) {
eta <- get_Mu(prep, i = i)
eta <- insert_refcat(eta, refcat = prep$refcat)
Expand Down
8 changes: 8 additions & 0 deletions R/stan-likelihood.R
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,14 @@ stan_log_lik_multinomial <- function(bterms, ...) {
sdist("multinomial_logit2", p$mu, vec = FALSE)
}

# Stan likelihood statement of the dirichlet_multinomial family
# @param bterms the terms object; only the logit link is supported
# @param ... currently unused
stan_log_lik_dirichlet_multinomial <- function(bterms, ...) {
  stopifnot(bterms$family$link == "logit")
  mu_pars <- stan_log_lik_dpars(
    bterms, reqn = TRUE, dpars = "mu", type = "multi"
  )
  # 'reqn' for phi follows whether phi is a predicted parameter
  phi_is_predicted <- is_pred_dpar(bterms, "phi")
  phi_pars <- stan_log_lik_dpars(
    bterms, reqn = phi_is_predicted, dpars = "phi"
  )
  sdist(
    "dirichlet_multinomial_logit2", mu_pars$mu, phi_pars$phi,
    vec = FALSE
  )
}

stan_log_lik_dirichlet <- function(bterms, ...) {
stopifnot(bterms$family$link == "logit")
mu <- stan_log_lik_dpars(bterms, reqn = TRUE, dpars = "mu", type = "multi")$mu
Expand Down
20 changes: 20 additions & 0 deletions inst/chunks/fun_dirichlet_multinomial_logit.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
/* dirichlet-multinomial-logit log-PMF
* Args:
* y: array of integer response values
* mu: vector of category logit probabilities
* phi: precision parameter (sum of Dirichlet alphas)
* Returns:
* a scalar to be added to the log posterior
*/
real dirichlet_multinomial_logit2_lpmf(array[] int y, vector mu, real phi) {
// get Dirichlet alphas
int N = num_elements(mu);
vector[N] alpha = phi * softmax(mu);

// get trials from y
real T = sum(y);

real output = lgamma(phi) + lgamma(T + 1.0) - lgamma(T + phi) +
sum(lgamma(to_vector(y) + alpha)) - sum(lgamma(alpha)) - sum(lgamma(to_vector(y) + 1));
return output;
}
19 changes: 12 additions & 7 deletions man/brmsfamily.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

25 changes: 25 additions & 0 deletions tests/local/tests.models-4.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,31 @@ test_that("multinomial models work correctly", suppressWarnings({
expect_ggplot(plot(ce, ask = FALSE)[[1]])
}))

test_that("dirichlet_multinomial models work correctly", suppressWarnings({
  # skip explicitly when the optional dependency is missing; require() would
  # only return FALSE and let the test fail later with an unrelated error
  skip_if_not_installed("extraDistr")
  set.seed(1245)
  N <- 100
  # simulate counts of 10 trials each from a Dirichlet-multinomial
  dat <- as.data.frame(extraDistr::rdirmnom(N, 10, c(10, 5, 1)))
  names(dat) <- paste0("y", 1:3)
  dat$size <- with(dat, y1 + y2 + y3)
  dat$x <- rnorm(N)
  dat$y <- with(dat, cbind(y1, y2, y3))

  fit <- brm(
    y | trials(size) ~ x, data = dat,
    family = dirichlet_multinomial(),
    prior = prior("exponential(0.01)", "phi")
  )
  print(summary(fit))
  pred <- predict(fit)
  expect_equal(dim(pred), c(nobs(fit), 4, 3))
  expect_equal(dimnames(pred)[[3]], c("y1", "y2", "y3"))
  waic <- waic(fit)
  expect_range(waic$estimates[3, 1], 550, 650)
  ce <- conditional_effects(fit, categorical = TRUE)
  expect_ggplot(plot(ce, ask = FALSE)[[1]])
}))

test_that("dirichlet models work correctly", suppressWarnings({
set.seed(1246)
N <- 100
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/tests.log_lik.R
Original file line number Diff line number Diff line change
Expand Up @@ -445,6 +445,12 @@ test_that("log_lik for categorical and related models runs without erros", {
ll <- sapply(1:nobs, brms:::log_lik_multinomial, prep = prep)
expect_equal(dim(ll), c(ns, nobs))

prep$data$trials <- sample(1:20, nobs)
prep$dpars$phi <- rexp(ns, 10)
prep$family <- dirichlet_multinomial()
ll <- sapply(1:nobs, brms:::log_lik_dirichlet_multinomial, prep = prep)
expect_equal(dim(ll), c(ns, nobs))

prep$data$Y <- prep$data$Y / rowSums(prep$data$Y)
prep$dpars$phi <- rexp(ns, 10)
prep$family <- dirichlet()
Expand Down
6 changes: 5 additions & 1 deletion tests/testthat/tests.posterior_epred.R
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ test_that("posterior_epred for advanced count data distributions runs without er
expect_equal(dim(pred), c(ns, nobs))
})

test_that("posterior_epred for multinomial and dirichlet models runs without errors", {
test_that("posterior_epred for multinomial, dirichlet_multinomial and dirichlet models runs without errors", {
ns <- 15
nobs <- 8
ncat <- 3
Expand All @@ -198,6 +198,10 @@ test_that("posterior_epred for multinomial and dirichlet models runs without err
pred <- brms:::posterior_epred_multinomial(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))

prep$family <- dirichlet_multinomial()
pred <- brms:::posterior_epred_dirichlet_multinomial(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))

prep$family <- dirichlet()
pred <- brms:::posterior_epred_dirichlet(prep = prep)
expect_equal(dim(pred), c(ns, nobs, ncat))
Expand Down
6 changes: 6 additions & 0 deletions tests/testthat/tests.posterior_predict.R
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,12 @@ test_that("posterior_predict for categorical and related models runs without err
pred <- brms:::posterior_predict_multinomial(i = sample(1:nobs, 1), prep = prep)
expect_equal(dim(pred), c(ns, ncat))

prep$data$trials <- sample(1:20, nobs)
prep$dpars$phi <- rexp(ns, 1)
prep$family <- dirichlet_multinomial()
pred <- brms:::posterior_predict_dirichlet_multinomial(i = sample(1:nobs, 1), prep = prep)
expect_equal(dim(pred), c(ns, ncat))

prep$dpars$phi <- rexp(ns, 1)
prep$family <- dirichlet()
pred <- brms:::posterior_predict_dirichlet(i = sample(1:nobs, 1), prep = prep)
Expand Down
23 changes: 23 additions & 0 deletions tests/testthat/tests.stancode.R
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,29 @@ test_that("Stan code for multinomial models is correct", {
expect_match2(scode, "lprior += normal_lpdf(Intercept_muy3 | 0, 2);")
})

test_that("Stan code for dirichlet_multinomial models is correct", {
  # three count columns plus one predictor
  N <- 15
  dat <- data.frame(
    y1 = rbinom(N, 10, 0.3),
    y2 = rbinom(N, 10, 0.5),
    y3 = rbinom(N, 10, 0.7),
    x = rnorm(N)
  )
  dat$size <- dat$y1 + dat$y2 + dat$y3
  dat$y <- cbind(y1 = dat$y1, y2 = dat$y2, y3 = dat$y3)

  priors <- prior(normal(0, 10), "b", dpar = muy2) +
    prior(cauchy(0, 1), "Intercept", dpar = muy2) +
    prior(normal(0, 2), "Intercept", dpar = muy3) +
    prior(exponential(10), "phi")
  model_formula <- bf(y | trials(size) ~ 1, muy2 ~ x)
  scode <- stancode(
    model_formula, data = dat,
    family = dirichlet_multinomial(), prior = priors
  )

  # every snippet below has to appear in the generated Stan code
  expected_snippets <- c(
    "array[N, ncat] int Y;",
    "target += dirichlet_multinomial_logit2_lpmf(Y[n] | mu[n], phi);",
    "muy2 += Intercept_muy2 + Xc_muy2 * b_muy2;",
    "lprior += normal_lpdf(b_muy2 | 0, 10);",
    "lprior += cauchy_lpdf(Intercept_muy2 | 0, 1);",
    "lprior += normal_lpdf(Intercept_muy3 | 0, 2);",
    "lprior += exponential_lpdf(phi | 10);"
  )
  for (snippet in expected_snippets) {
    expect_match2(scode, snippet)
  }
})

test_that("Stan code for dirichlet models is correct", {
N <- 15
dat <- as.data.frame(rdirichlet(N, c(3, 2, 1)))
Expand Down
Loading