Skip to content

Commit

Permalink
Add mixture model (#25)
Browse files Browse the repository at this point in the history
* intermediate work

* docs

* vector

* typing

* a

* if array

* batch

* mask

* mask2

* dot

* vmap

* categorical distribution (#18)

* categorical distribution

* categorical distribution v2

---------

Co-authored-by: Owen Lockwood <[email protected]>

* example + categorical

* add tests

* test

* cat

* a

* a2

* doc strings

* doc

---------

Co-authored-by: Mayalen Etcheverry <[email protected]>
  • Loading branch information
lockwo and mayalenE authored Jul 28, 2024
1 parent f91eadc commit 6fd83c9
Show file tree
Hide file tree
Showing 18 changed files with 1,810 additions and 14 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.315
rev: v1.1.368
hooks:
- id: pyright
additional_dependencies: ["equinox", "pytest", "jax", "jaxtyping"]
1 change: 1 addition & 0 deletions distreqx/bijectors/block.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Wrapper to turn independent Bijectors into block Bijectors."""

from jaxtyping import Array

from ..utils import sum_last
Expand Down
1 change: 1 addition & 0 deletions distreqx/bijectors/tanh.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Tanh bijector."""

import jax
import jax.numpy as jnp
from jaxtyping import Array
Expand Down
2 changes: 2 additions & 0 deletions distreqx/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
AbstractSurivialDistribution as AbstractSurivialDistribution,
)
from .bernoulli import Bernoulli as Bernoulli
from .categorical import Categorical as Categorical
from .independent import Independent as Independent
from .mixture_same_family import MixtureSameFamily as MixtureSameFamily
from .mvn_diag import MultivariateNormalDiag as MultivariateNormalDiag
from .mvn_from_bijector import (
AbstractMultivariateNormalFromBijector as AbstractMultivariateNormalFromBijector,
Expand Down
4 changes: 2 additions & 2 deletions distreqx/distributions/bernoulli.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ def probs(self) -> Array:
return jax.nn.sigmoid(self._logits)

@property
def event_shape(self) -> tuple[int]:
return self.prob.shape
def event_shape(self) -> tuple[int, ...]:
return self.probs.shape

def _log_probs_parameter(self) -> tuple[Array, Array]:
if self._logits is None:
Expand Down
232 changes: 232 additions & 0 deletions distreqx/distributions/categorical.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,232 @@
"""Categorical distribution."""

from typing import Optional, Union

import jax
import jax.numpy as jnp
from jaxtyping import Array, PRNGKeyArray

from ..utils.math import mul_exp, multiply_no_nan, normalize
from ._distribution import (
AbstractSampleLogProbDistribution,
AbstractSTDDistribution,
AbstractSurivialDistribution,
)


def _categorical_index_dtype(num_categories: int) -> str:
    """Return the integer dtype name used to store category indices.

    `int8` is kept whenever every index in `{0, ..., K-1}` fits in it
    (`K <= 128`), so results for small supports are unchanged. Larger
    supports fall back to `int32`, because an index above 127 cast to `int8`
    wraps around to a negative value.
    """
    return "int8" if num_categories <= 128 else "int32"


class Categorical(
    AbstractSTDDistribution,
    AbstractSampleLogProbDistribution,
    AbstractSurivialDistribution,
    strict=True,
):
    """Categorical distribution over integers.

    The Categorical distribution is parameterized by either probabilities (`probs`) or
    unnormalized log-probabilities (`logits`) of a set of `K` classes.
    It is defined over the integers `{0, 1, ..., K-1}`.
    """

    # Exactly one of these is populated by `__init__`; the other stays `None`
    # and is derived on demand by the `logits` / `probs` properties.
    _logits: Union[Array, None]
    _probs: Union[Array, None]

    def __init__(self, logits: Optional[Array] = None, probs: Optional[Array] = None):
        """Initializes a Categorical distribution.

        **Arguments:**

        - `logits`: Logit transform of the probability of each category. Only one
            of `logits` or `probs` can be specified.
        - `probs`: Probability of each category. Only one of `logits` or `probs` can
            be specified.

        **Raises:**

        `ValueError` if both or neither of `logits` and `probs` are given, or if
        the given parameter is not a `jax.Array`.
        """
        if (logits is None) == (probs is None):
            raise ValueError(
                f"One and exactly one of `logits` and `probs` should be `None`, "
                f"but `logits` is {logits} and `probs` is {probs}."
            )
        # At this point exactly one argument is set, so this only fires when
        # that argument is not a jax array.
        if (not isinstance(logits, jax.Array)) and (not isinstance(probs, jax.Array)):
            raise ValueError("`logits` and `probs` are not jax arrays.")

        # The supplied parameter is stored after being passed through `normalize`.
        self._probs = None if probs is None else normalize(probs=probs)
        self._logits = None if logits is None else normalize(logits=logits)

    @property
    def event_shape(self) -> tuple:
        """Shape of event of distribution samples."""
        return ()

    @property
    def logits(self) -> Array:
        """The logits for each event."""
        if self._logits is not None:
            return self._logits
        if self._probs is None:
            raise ValueError(
                "_probs and _logits are None!"
            )  # TODO: useless but needed for pyright
        return jnp.log(self._probs)

    @property
    def probs(self) -> Array:
        """The probabilities for each event."""
        if self._probs is not None:
            return self._probs
        if self._logits is None:
            raise ValueError(
                "_probs and _logits are None!"
            )  # TODO: useless but needed for pyright
        return jax.nn.softmax(self._logits, axis=-1)

    @property
    def num_categories(self) -> int:
        """Number of categories."""
        if self._probs is not None:
            return self._probs.shape[-1]
        if self._logits is None:
            raise ValueError(
                "_probs and _logits are None!"
            )  # TODO: useless but needed for pyright
        return self._logits.shape[-1]

    def sample(self, key: PRNGKeyArray) -> Array:
        """See `Distribution.sample`.

        Returns `-1` wherever the probabilities along the last axis contain a
        non-finite or negative entry.
        """
        is_valid = jnp.logical_and(
            jnp.all(jnp.isfinite(self.probs), axis=-1),
            jnp.all(self.probs >= 0, axis=-1),
        )
        # Use a dtype wide enough for index K-1; a fixed `int8` would wrap
        # indices above 127 to negative values (possibly the -1 sentinel).
        draws = jax.random.categorical(key=key, logits=self.logits, axis=-1).astype(
            _categorical_index_dtype(self.num_categories)
        )
        return jnp.where(is_valid, draws, jnp.ones_like(draws) * -1)

    def log_prob(self, value: Array) -> Array:
        """See `Distribution.log_prob`.

        Values outside `{0, ..., K-1}` get a log-probability of `-inf`.
        """
        value_one_hot = jax.nn.one_hot(
            value, self.num_categories, dtype=self.logits.dtype
        )
        mask_outside_domain = jnp.logical_or(value < 0, value > self.num_categories - 1)
        return jnp.where(
            mask_outside_domain,
            -jnp.inf,
            jnp.sum(multiply_no_nan(self.logits, value_one_hot), axis=-1),
        )

    def prob(self, value: Array) -> Array:
        """See `Distribution.prob`."""
        value_one_hot = jax.nn.one_hot(
            value, self.num_categories, dtype=self.probs.dtype
        )
        return jnp.sum(multiply_no_nan(self.probs, value_one_hot), axis=-1)

    def entropy(self) -> Array:
        """See `Distribution.entropy`."""
        if self._logits is None:
            if self._probs is None:
                raise ValueError(
                    "_probs and _logits are None!"
                )  # TODO: useless but needed for pyright
            log_probs = jnp.log(self._probs)
        else:
            log_probs = jax.nn.log_softmax(self._logits)
        return -jnp.sum(mul_exp(log_probs, log_probs), axis=-1)

    def mode(self) -> Array:
        """See `Distribution.mode`."""
        if self._logits is None:
            if self._probs is None:
                raise ValueError(
                    "_probs and _logits are None!"
                )  # TODO: useless but needed for pyright
            parameter = self.probs
        else:
            parameter = self.logits
        # Same dtype rule as `sample`, so the argmax index cannot overflow.
        return jnp.argmax(parameter, axis=-1).astype(
            _categorical_index_dtype(self.num_categories)
        )

    def cdf(self, value: Array) -> Array:
        """See `Distribution.cdf`."""
        # For value < 0 the output should be zero because support = {0, ..., K-1}.
        should_be_zero = value < 0
        # For value >= K-1 the output should be one. Explicitly accounting for this
        # case addresses potential numerical issues that may arise when evaluating
        # derived methods (mainly, `log_survival_function`) for `value >= K-1`.
        should_be_one = value >= self.num_categories - 1
        # Will use value as an index below, so clip it to {0, ..., K-1}.
        value = jnp.clip(value, 0, self.num_categories - 1)
        value_one_hot = jax.nn.one_hot(
            value, self.num_categories, dtype=self.probs.dtype
        )
        # Select the cumulative probability at index `value` via the one-hot mask.
        cdf = jnp.sum(
            multiply_no_nan(jnp.cumsum(self.probs, axis=-1), value_one_hot), axis=-1
        )
        return jnp.where(should_be_zero, 0.0, jnp.where(should_be_one, 1.0, cdf))

    def log_cdf(self, value: Array) -> Array:
        """See `Distribution.log_cdf`."""
        return jnp.log(self.cdf(value))

    def median(self):
        """Not implemented for the Categorical distribution."""
        raise NotImplementedError

    def variance(self):
        """Not implemented for the Categorical distribution."""
        raise NotImplementedError

    def mean(self):
        """Not implemented for the Categorical distribution."""
        raise NotImplementedError

    def kl_divergence(self, other_dist, **kwargs) -> Array:
        """Calculates the KL divergence to another distribution.

        **Arguments:**

        - `other_dist`: A compatible distreqx distribution.
        - `kwargs`: Additional kwargs. They are accepted but not forwarded.

        **Returns:**

        The KL divergence `KL(self || other_dist)`.
        """
        return _kl_divergence_categorical_categorical(self, other_dist)


def _kl_divergence_categorical_categorical(
    dist1: Categorical,
    dist2: Categorical,
    *unused_args,
    **unused_kwargs,
) -> Array:
    """Computes `KL(dist1 || dist2)` for a pair of Categorical distributions.

    The computation treats `0 * log(0)` as `0`, so `dist1` is allowed to
    assign zero probability to some categories.

    **Arguments:**

    - `dist1`: A Categorical distribution.
    - `dist2`: A Categorical distribution.

    Extra positional and keyword arguments are accepted and ignored.

    **Returns:**

    `KL(dist1 || dist2)`, summed over the last (category) axis.

    **Raises:**

    `ValueError` if the two distributions have different number of categories.
    """
    logits_p, logits_q = dist1.logits, dist2.logits
    k_p, k_q = logits_p.shape[-1], logits_q.shape[-1]

    # Guard clause: the category axes must match.
    if k_p != k_q:
        message = (
            "Cannot obtain the KL between two Categorical distributions "
            "with different number of categories: the first distribution has "
            f"{k_p} categories, while the second distribution has "
            f"{k_q} categories."
        )
        raise ValueError(message)

    log_p = jax.nn.log_softmax(logits_p, axis=-1)
    log_q = jax.nn.log_softmax(logits_q, axis=-1)
    log_ratio = log_p - log_q
    return jnp.sum(mul_exp(log_ratio, log_p), axis=-1)
3 changes: 1 addition & 2 deletions distreqx/distributions/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax.tree_util as jtu
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import EventT
from ._distribution import (
AbstractCDFDistribution,
AbstractDistribution,
Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(
self._distribution = distribution

@property
def event_shape(self) -> EventT:
def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return self._distribution.event_shape

Expand Down
Loading

0 comments on commit 6fd83c9

Please sign in to comment.