-
-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
4 changed files
with
73 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import unittest | ||
from parameterized import parameterized # type: ignore | ||
|
||
import jax | ||
import jax.numpy as jnp | ||
import equinox as eqx | ||
|
||
from distreqx.distributions import MixtureSameFamily, Categorical | ||
from distreqx.distributions import MultivariateNormalDiag, Normal | ||
|
||
|
||
class MixtureSameFamilyTest(unittest.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.key = jax.random.PRNGKey(0) | ||
self.num_components = 3 | ||
self.logits_shape = (self.num_components,) | ||
self.logits = jax.random.normal(key=self.key, shape=self.logits_shape) | ||
self.probs = jax.nn.softmax(self.logits, axis=-1) | ||
|
||
key_loc, key_scale = jax.random.split(self.key) | ||
self.components_shape = (5,) | ||
self.loc = jax.random.normal(key=key_loc, shape=self.logits_shape + self.components_shape) | ||
self.scale_diag = jax.random.uniform(key=key_scale, shape=self.logits_shape + self.components_shape) + 0.5 | ||
|
||
def test_event_shape(self): | ||
mixture_dist = Categorical(logits=self.logits) | ||
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag) | ||
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist) | ||
self.assertEqual(dist.event_shape, self.logits_shape + self.components_shape) | ||
|
||
def test_sample_shape(self): | ||
mixture_dist = Categorical(logits=self.logits) | ||
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag) | ||
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist) | ||
samples = dist.sample(self.key) | ||
self.assertEqual(samples.shape, self.components_shape) | ||
|
||
@parameterized.expand([ | ||
("mean", "mean"), | ||
("variance", "variance"), | ||
("stddev", "stddev"), | ||
]) | ||
def test_method(self, name, method_name): | ||
mixture_dist = Categorical(logits=self.logits) | ||
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag) | ||
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist) | ||
method = getattr(dist, method_name) | ||
result = method() | ||
self.assertIsInstance(result, jnp.ndarray) | ||
|
||
def test_jittable(self): | ||
mixture_dist = Categorical(logits=self.logits) | ||
components_dist = eqx.filter_vmap(MultivariateNormalDiag)(self.loc, self.scale_diag) | ||
dist = MixtureSameFamily(mixture_distribution=mixture_dist, components_distribution=components_dist) | ||
sample = eqx.filter_jit(dist.sample)(self.key) | ||
self.assertIsInstance(sample, jnp.ndarray) |