Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lockwo committed Jul 28, 2024
1 parent 8e35717 commit 64f4a78
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 24 deletions.
2 changes: 1 addition & 1 deletion distreqx/distributions/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, logits: Optional[Array] = None, probs: Optional[Array] = None
self._logits = None if logits is None else normalize(logits=logits)

@property
def event_shape(self):
def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return ()

Expand Down
3 changes: 1 addition & 2 deletions distreqx/distributions/independent.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import jax.tree_util as jtu
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import EventT
from ._distribution import (
AbstractCDFDistribution,
AbstractDistribution,
Expand Down Expand Up @@ -43,7 +42,7 @@ def __init__(
self._distribution = distribution

@property
def event_shape(self) -> EventT:
def event_shape(self) -> tuple:
"""Shape of event of distribution samples."""
return self._distribution.event_shape

Expand Down
34 changes: 13 additions & 21 deletions examples/02_mixture_models.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,23 @@
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"from jax import numpy as jnp\n",
"from jax.scipy.special import expit, logit\n",
"import equinox as eqx\n",
"import jax\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import optax\n",
"import tensorflow_datasets as tfds\n",
"from jax import numpy as jnp\n",
"from jax.scipy.special import expit, logit\n",
"from tqdm.notebook import tqdm\n",
"\n",
"from distreqx.distributions import (\n",
" MixtureSameFamily,\n",
" Bernoulli,\n",
" Categorical,\n",
" Independent,\n",
" Bernoulli,\n",
" MixtureSameFamily,\n",
" Normal,\n",
")\n",
"from tqdm.notebook import tqdm\n",
"import numpy as np"
")"
]
},
{
Expand Down Expand Up @@ -125,7 +126,8 @@
" def plot(self, n_row, n_col):\n",
" if n_row * n_col != len(self.mixing_coeffs):\n",
" raise TypeError(\n",
" \"The number of rows and columns does not match with the number of component distribution.\"\n",
" \"The number of rows and columns does not match with \"\n",
" \"the number of component distribution.\"\n",
" )\n",
" fig, axes = plt.subplots(n_row, n_col)\n",
"\n",
Expand Down Expand Up @@ -489,7 +491,8 @@
" def plot(self, n_row, n_col):\n",
" if n_row * n_col != len(self.mixing_coeffs):\n",
" raise TypeError(\n",
" \"The number of rows and columns does not match with the number of component distribution.\"\n",
" \"The number of rows and columns does not match with the \"\n",
" \"number of component distribution.\"\n",
" )\n",
" fig, axes = plt.subplots(n_row, n_col)\n",
"\n",
Expand Down Expand Up @@ -700,17 +703,6 @@
" return model\n",
"\n",
"\n",
"def loss_fn(model, params, inp):\n",
" model = update_model_params(model, params)\n",
" return -model.expected_log_likelihood(inp)\n",
"\n",
"\n",
"def vmap_loss(params, model, batch):\n",
" return jnp.mean(\n",
" eqx.filter_vmap(loss_fn, in_axes=(None, None, 0))(model, params, batch)\n",
" )\n",
"\n",
"\n",
"@eqx.filter_jit\n",
"def step(model, params, batch, opt_state):\n",
" loss, grads = eqx.filter_value_and_grad(vmap_loss)(params, model, batch)\n",
Expand Down
58 changes: 58 additions & 0 deletions tests/mixture_same_family_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import unittest
from parameterized import parameterized # type: ignore

import jax
import jax.numpy as jnp
import equinox as eqx

from distreqx.distributions import MixtureSameFamily, Categorical
from distreqx.distributions import MultivariateNormalDiag, Normal


class MixtureSameFamilyTest(unittest.TestCase):

def setUp(self):
super().setUp()
self.key = jax.random.PRNGKey(0)
self.num_components = 3
self.logits_shape = (self.num_components,)
self.logits = jax.random.normal(key=self.key, shape=self.logits_shape)
self.probs = jax.nn.softmax(self.logits, axis=-1)

key_loc, key_scale = jax.random.split(self.key)
self.components_shape = (5,)
self.loc = jax.random.normal(key=key_loc, shape=self.logits_shape + self.components_shape)
self.scale_diag = jax.random.uniform(key=key_scale, shape=self.logits_shape + self.components_shape) + 0.5

def test_event_shape(self):
mixture_dist = Categorical(logits=self.logits)
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag)
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist)
self.assertEqual(dist.event_shape, self.logits_shape + self.components_shape)

def test_sample_shape(self):
mixture_dist = Categorical(logits=self.logits)
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag)
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist)
samples = dist.sample(self.key)
self.assertEqual(samples.shape, self.components_shape)

@parameterized.expand([
("mean", "mean"),
("variance", "variance"),
("stddev", "stddev"),
])
def test_method(self, name, method_name):
mixture_dist = Categorical(logits=self.logits)
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag)
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist)
method = getattr(dist, method_name)
result = method()
self.assertIsInstance(result, jnp.ndarray)

def test_jittable(self):
mixture_dist = Categorical(logits=self.logits)
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag)
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist)
sample = eqx.filter_jit(dist.sample)(self.key)
self.assertIsInstance(sample, jnp.ndarray)

0 comments on commit 64f4a78

Please sign in to comment.