From 77b8ee2ccd42cbeeb55c478032e3ab7b66c48076 Mon Sep 17 00:00:00 2001 From: Owen Lockwood <42878312+lockwo@users.noreply.github.com> Date: Tue, 11 Jun 2024 22:55:10 -0700 Subject: [PATCH] Enforces strict modules (#20) * distribution part 1 * bijectors * beartyping * design --- CONTRIBUTING.md | 2 +- README.md | 1 + distreqx/_custom_types.py | 4 +- distreqx/bijectors/__init__.py | 7 +- distreqx/bijectors/_bijector.py | 94 ++++----- distreqx/bijectors/_linear.py | 15 +- distreqx/bijectors/block.py | 17 +- distreqx/bijectors/chain.py | 33 ++- distreqx/bijectors/diag_linear.py | 15 +- distreqx/bijectors/scalar_affine.py | 109 ++++++---- distreqx/bijectors/shift.py | 21 +- distreqx/bijectors/sigmoid.py | 20 +- distreqx/bijectors/tanh.py | 25 ++- distreqx/bijectors/triangular_linear.py | 10 +- distreqx/distributions/__init__.py | 11 +- distreqx/distributions/_distribution.py | 144 ++++++++++++- distreqx/distributions/bernoulli.py | 28 ++- distreqx/distributions/independent.py | 17 +- distreqx/distributions/mvn_diag.py | 41 +++- distreqx/distributions/mvn_from_bijector.py | 135 +++++++----- distreqx/distributions/mvn_tri.py | 47 ++++- distreqx/distributions/normal.py | 9 +- distreqx/distributions/transformed.py | 223 +++++++++++++------- distreqx/utils/math.py | 10 +- docs/api/bijectors/_bijector.md | 24 ++- docs/api/bijectors/_linear.md | 8 - docs/api/distributions/_distribution.md | 53 ++++- docs/api/distributions/mvn_from_bijector.md | 1 - docs/api/distributions/transformed.md | 6 +- mkdocs.yml | 11 +- tests/abstractbijector_test.py | 29 +-- tests/abstractdistribution_test.py | 57 ++++- tests/abstractlinear_test.py | 33 ++- tests/mvn_from_bijector_test.py | 34 ++- 34 files changed, 918 insertions(+), 376 deletions(-) delete mode 100644 docs/api/bijectors/_linear.md diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 713d4a7..54d7e65 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -29,7 +29,7 @@ These hooks use Black and isort to format the code, and flake8 to 
lint it. **If you're making changes to the code:** -Now make your changes. Make sure to include additional tests if necessary. +Now make your changes. Make sure to include additional tests if necessary. Be sure to consider the abstract/final design pattern: https://docs.kidger.site/equinox/pattern/. If you include a new features, there are 3 required classes of tests: - Correctness: tests the are against analytic or known solutions that ensure the computation is correct diff --git a/README.md b/README.md index 57fbf19..e18b7f1 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,7 @@ from distreqx import - No official support/interoperability with TFP - The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can be used in construction as well, e.g. [vmaping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`) - Broader pytree enablement +- Strict [abstract/final](https://docs.kidger.site/equinox/pattern/) design pattern ## Citation diff --git a/distreqx/_custom_types.py b/distreqx/_custom_types.py index f81bc1e..eb4a457 100644 --- a/distreqx/_custom_types.py +++ b/distreqx/_custom_types.py @@ -1,4 +1,4 @@ -from typing import Tuple, Union +from typing import Union import jax from jaxtyping import ( @@ -6,4 +6,4 @@ ) -EventT = Union[Tuple[int], PyTree[jax.ShapeDtypeStruct]] +EventT = Union[tuple[int], PyTree[jax.ShapeDtypeStruct]] diff --git a/distreqx/bijectors/__init__.py b/distreqx/bijectors/__init__.py index 8ac908f..80ea071 100644 --- a/distreqx/bijectors/__init__.py +++ b/distreqx/bijectors/__init__.py @@ -1,4 +1,9 @@ -from ._bijector import AbstractBijector as AbstractBijector +from ._bijector import ( + AbstractBijector as AbstractBijector, + AbstractFowardInverseBijector as AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector as AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector as AbstractInvLogDetJacBijector, +) from ._linear import 
AbstractLinearBijector as AbstractLinearBijector from .block import Block as Block from .chain import Chain as Chain diff --git a/distreqx/bijectors/_bijector.py b/distreqx/bijectors/_bijector.py index 3cbdf26..73923d7 100644 --- a/distreqx/bijectors/_bijector.py +++ b/distreqx/bijectors/_bijector.py @@ -1,5 +1,4 @@ from abc import abstractmethod -from typing import Optional, Tuple import equinox as eqx from jaxtyping import Array, PyTree @@ -27,69 +26,38 @@ class AbstractBijector(eqx.Module, strict=True): will assume these properties hold, and will make no attempt to verify them. """ - _is_constant_jacobian: bool - _is_constant_log_det: bool - - def __init__( - self, - is_constant_jacobian: bool = False, - is_constant_log_det: Optional[bool] = None, - ): - """Initializes a Bijector. - - **Arguments:** - - - `is_constant_jacobian`: Whether the Jacobian is promised to be constant - (which is the case if and only if the bijector is affine). A value of - False will be interpreted as "we don't know whether the Jacobian is - constant", rather than "the Jacobian is definitely not constant". Only - set to True if you're absolutely sure the Jacobian is constant; if - you're not sure, set to False. - - `is_constant_log_det`: Whether the Jacobian determinant is promised to be - constant (which is the case for, e.g., volume-preserving bijectors). If - None, it defaults to `is_constant_jacobian`. Note that the Jacobian - determinant can be constant without the Jacobian itself being constant. - Only set to True if you're absoltely sure the Jacobian determinant is - constant; if you're not sure, set to False. - """ - if is_constant_log_det is None: - is_constant_log_det = is_constant_jacobian - if is_constant_jacobian and not is_constant_log_det: - raise ValueError( - "The Jacobian is said to be constant, but its " - "determinant is said not to be, which is impossible." 
- ) - self._is_constant_jacobian = is_constant_jacobian - self._is_constant_log_det = is_constant_log_det + _is_constant_jacobian: eqx.AbstractVar[bool] + _is_constant_log_det: eqx.AbstractVar[bool] + @abstractmethod def forward(self, x: PyTree) -> PyTree: R"""Computes $y = f(x)$.""" - y, _ = self.forward_and_log_det(x) - return y + raise NotImplementedError + @abstractmethod def inverse(self, y: PyTree) -> PyTree: r"""Computes $x = f^{-1}(y)$.""" - x, _ = self.inverse_and_log_det(y) - return x + raise NotImplementedError + @abstractmethod def forward_log_det_jacobian(self, x: PyTree) -> PyTree: r"""Computes $\log|\det J(f)(x)|$.""" - _, logdet = self.forward_and_log_det(x) - return logdet + raise NotImplementedError + @abstractmethod def inverse_log_det_jacobian(self, y: PyTree) -> PyTree: r"""Computes $\log|\det J(f^{-1})(y)|$.""" - _, logdet = self.inverse_and_log_det(y) - return logdet + raise NotImplementedError @abstractmethod - def forward_and_log_det(self, x: PyTree) -> Tuple[PyTree, PyTree]: + def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]: r"""Computes $y = f(x)$ and $\log|\det J(f)(x)|$.""" raise NotImplementedError( f"Bijector {self.name} does not implement `forward_and_log_det`." ) - def inverse_and_log_det(self, y: Array) -> Tuple[PyTree, PyTree]: + @abstractmethod + def inverse_and_log_det(self, y: Array) -> tuple[PyTree, PyTree]: r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$.""" raise NotImplementedError( f"Bijector {self.name} does not implement `inverse_and_log_det`." 
@@ -110,7 +78,39 @@ def name(self) -> str: """Name of the bijector.""" return self.__class__.__name__ + @abstractmethod def same_as(self, other) -> bool: """Returns True if this bijector is guaranteed to be the same as `other`.""" - del other - return False + raise NotImplementedError + + +class AbstractInvLogDetJacBijector(AbstractBijector, strict=True): + """AbstractBijector + concrete `inverse_log_det_jacobian`.""" + + def inverse_log_det_jacobian(self, y: PyTree) -> PyTree: + r"""Computes $\log|\det J(f^{-1})(y)|$.""" + _, logdet = self.inverse_and_log_det(y) + return logdet + + +class AbstractFwdLogDetJacBijector(AbstractBijector, strict=True): + """AbstractBijector + concrete `forward_log_det_jacobian`.""" + + def forward_log_det_jacobian(self, x: PyTree) -> PyTree: + r"""Computes $\log|\det J(f)(x)|$.""" + _, logdet = self.forward_and_log_det(x) + return logdet + + +class AbstractFowardInverseBijector(AbstractBijector, strict=True): + """AbstractBijector + concrete `forward` and `reverse`.""" + + def forward(self, x: PyTree) -> PyTree: + R"""Computes $y = f(x)$.""" + y, _ = self.forward_and_log_det(x) + return y + + def inverse(self, y: PyTree) -> PyTree: + r"""Computes $x = f^{-1}(y)$.""" + x, _ = self.inverse_and_log_det(y) + return x diff --git a/distreqx/bijectors/_linear.py b/distreqx/bijectors/_linear.py index f746ffe..6799a11 100644 --- a/distreqx/bijectors/_linear.py +++ b/distreqx/bijectors/_linear.py @@ -1,28 +1,19 @@ """Linear bijector.""" +import equinox as eqx from jaxtyping import Array from ._bijector import AbstractBijector -class AbstractLinearBijector(AbstractBijector): +class AbstractLinearBijector(AbstractBijector, strict=True): """Base class for linear bijectors. This class provides a base class for bijectors defined as `f(x) = Ax`, where `A` is a `DxD` matrix and `x` is a `D`-dimensional vector. """ - _event_dims: int - - def __init__(self, event_dims: int): - """Initializes a `Linear` bijector. 
- - **Arguments:** - - - `event_dims`: the dimensionality of the vector `D` - """ - super().__init__(is_constant_jacobian=True) - self._event_dims = event_dims + _event_dims: eqx.AbstractVar[int] @property def matrix(self) -> Array: diff --git a/distreqx/bijectors/block.py b/distreqx/bijectors/block.py index bbc4749..9e12022 100644 --- a/distreqx/bijectors/block.py +++ b/distreqx/bijectors/block.py @@ -1,14 +1,11 @@ """Wrapper to turn independent Bijectors into block Bijectors.""" - -from typing import Tuple - from jaxtyping import Array from ..utils import sum_last from ._bijector import AbstractBijector -class Block(AbstractBijector): +class Block(AbstractBijector, strict=True): """A wrapper that promotes a bijector to a block bijector. A block bijector applies a bijector to a k-dimensional array of events, but @@ -33,6 +30,8 @@ class Block(AbstractBijector): _ndims: int _bijector: AbstractBijector + _is_constant_jacobian: bool + _is_constant_log_det: bool def __init__(self, bijector: AbstractBijector, ndims: int): """Initializes a Block. 
@@ -47,10 +46,8 @@ def __init__(self, bijector: AbstractBijector, ndims: int): raise ValueError(f"`ndims` must be non-negative; got {ndims}.") self._bijector = bijector self._ndims = ndims - super().__init__( - is_constant_jacobian=self._bijector.is_constant_jacobian, - is_constant_log_det=self._bijector.is_constant_log_det, - ) + self._is_constant_jacobian = self._bijector.is_constant_jacobian + self._is_constant_log_det = self._bijector.is_constant_log_det @property def bijector(self) -> AbstractBijector: @@ -80,12 +77,12 @@ def inverse_log_det_jacobian(self, y: Array) -> Array: log_det = self._bijector.inverse_log_det_jacobian(y) return sum_last(log_det, self._ndims) - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: """Computes y = f(x) and log|det J(f)(x)|.""" y, log_det = self._bijector.forward_and_log_det(x) return y, sum_last(log_det, self._ndims) - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" x, log_det = self._bijector.inverse_and_log_det(y) return x, sum_last(log_det, self._ndims) diff --git a/distreqx/bijectors/chain.py b/distreqx/bijectors/chain.py index 3d2a98f..78b547c 100644 --- a/distreqx/bijectors/chain.py +++ b/distreqx/bijectors/chain.py @@ -1,13 +1,17 @@ """Chain Bijector for composing a sequence of Bijector transformations.""" -from typing import List, Sequence, Tuple +from typing import Sequence from jaxtyping import Array -from ._bijector import AbstractBijector +from ._bijector import ( + AbstractBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, +) -class Chain(AbstractBijector): +class Chain(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True): """Composition of a sequence of bijectors into a single bijector. 
Bijectors are composable: if `f` and `g` are bijectors, then `g o f` is also @@ -30,7 +34,9 @@ class Chain(AbstractBijector): `y = f(g(x))`. """ - _bijectors: List[AbstractBijector] + _bijectors: list[AbstractBijector] + _is_constant_jacobian: bool + _is_constant_log_det: bool def __init__(self, bijectors: Sequence[AbstractBijector]): """Initializes a Chain bijector. @@ -47,13 +53,18 @@ def __init__(self, bijectors: Sequence[AbstractBijector]): is_constant_jacobian = all(b.is_constant_jacobian for b in self._bijectors) is_constant_log_det = all(b.is_constant_log_det for b in self._bijectors) - super().__init__( - is_constant_jacobian=is_constant_jacobian, - is_constant_log_det=is_constant_log_det, - ) + if is_constant_log_det is None: + is_constant_log_det = is_constant_jacobian + if is_constant_jacobian and not is_constant_log_det: + raise ValueError( + "The Jacobian is said to be constant, but its " + "determinant is said not to be, which is impossible." + ) + self._is_constant_jacobian = is_constant_jacobian + self._is_constant_log_det = is_constant_log_det @property - def bijectors(self) -> List[AbstractBijector]: + def bijectors(self) -> list[AbstractBijector]: """The list of bijectors in the chain.""" return self._bijectors @@ -69,7 +80,7 @@ def inverse(self, y: Array) -> Array: y = bijector.inverse(y) return y - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: """Computes y = f(x) and log|det J(f)(x)|.""" x, log_det = self._bijectors[-1].forward_and_log_det(x) for bijector in reversed(self._bijectors[:-1]): @@ -77,7 +88,7 @@ def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: log_det += ld return x, log_det - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" y, log_det = self._bijectors[0].inverse_and_log_det(y) for bijector in 
self._bijectors[1:]: diff --git a/distreqx/bijectors/diag_linear.py b/distreqx/bijectors/diag_linear.py index 3dc7a35..ece7b38 100644 --- a/distreqx/bijectors/diag_linear.py +++ b/distreqx/bijectors/diag_linear.py @@ -1,7 +1,5 @@ """Diagonal linear bijector.""" -from typing import Tuple - import jax.numpy as jnp from jaxtyping import Array @@ -11,7 +9,7 @@ from .scalar_affine import ScalarAffine -class DiagLinear(AbstractLinearBijector): +class DiagLinear(AbstractLinearBijector, strict=True): """Linear bijector with a diagonal weight matrix. The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix. @@ -29,6 +27,9 @@ class DiagLinear(AbstractLinearBijector): _diag: Array _bijector: AbstractBijector + _is_constant_jacobian: bool + _is_constant_log_det: bool + _event_dims: int def __init__(self, diag: Array): """Initializes the bijector. @@ -42,8 +43,10 @@ def __init__(self, diag: Array): self._bijector = Block( ScalarAffine(shift=jnp.zeros_like(diag), scale=diag), ndims=diag.ndim ) - super().__init__(event_dims=diag.shape[-1]) + self._event_dims = diag.shape[-1] self._diag = diag + self._is_constant_jacobian = True + self._is_constant_log_det = True def forward(self, x: Array) -> Array: """Computes y = f(x).""" @@ -61,11 +64,11 @@ def inverse_log_det_jacobian(self, y: Array) -> Array: """Computes log|det J(f^{-1})(y)|.""" return self._bijector.inverse_log_det_jacobian(y) - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" return self._bijector.inverse_and_log_det(y) - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: """Computes y = f(x) and log|det J(f)(x)|.""" return self._bijector.forward_and_log_det(x) diff --git a/distreqx/bijectors/scalar_affine.py b/distreqx/bijectors/scalar_affine.py index 244df40..e29326e 100644 --- 
a/distreqx/bijectors/scalar_affine.py +++ b/distreqx/bijectors/scalar_affine.py @@ -1,14 +1,74 @@ """Scalar affine bijector.""" -from typing import Optional, Tuple +from typing import Optional +import equinox as eqx import jax.numpy as jnp from jaxtyping import Array from ._bijector import AbstractBijector -class ScalarAffine(AbstractBijector): +class AbstractScalarAffine(AbstractBijector, strict=True): + """An affine bijector that acts elementwise. + + The bijector is defined as follows: + + - Forward: `y = scale * x + shift` + - Forward Jacobian determinant: `log|det J(x)| = log|scale|` + - Inverse: `x = (y - shift) / scale` + - Inverse Jacobian determinant: `log|det J(y)| = -log|scale|` + + where `scale` and `shift` are the bijector's parameters. + """ + + _shift: eqx.AbstractVar[Array] + _scale: eqx.AbstractVar[Array] + _inv_scale: eqx.AbstractVar[Array] + _log_scale: eqx.AbstractVar[Array] + + @property + def shift(self) -> Array: + """The bijector's shift.""" + return self._shift + + @property + def log_scale(self) -> Array: + """The log of the bijector's scale.""" + return self._log_scale + + @property + def scale(self) -> Array: + """The bijector's scale.""" + assert self._scale is not None # By construction. 
+ return self._scale + + def forward(self, x: Array) -> Array: + """Computes y = f(x).""" + return self._scale * x + self._shift + + def forward_log_det_jacobian(self, x: Array) -> Array: + """Computes log|det J(f)(x)|.""" + return self._log_scale + + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: + """Computes y = f(x) and log|det J(f)(x)|.""" + return self.forward(x), self.forward_log_det_jacobian(x) + + def inverse(self, y: Array) -> Array: + """Computes x = f^{-1}(y).""" + return self._inv_scale * (y - self._shift) + + def inverse_log_det_jacobian(self, y: Array) -> Array: + """Computes log|det J(f^{-1})(y)|.""" + return jnp.negative(self._log_scale) + + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: + """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" + return self.inverse(y), self.inverse_log_det_jacobian(y) + + +class ScalarAffine(AbstractScalarAffine, strict=True): """An affine bijector that acts elementwise. The bijector is defined as follows: @@ -25,6 +85,8 @@ class ScalarAffine(AbstractBijector): _scale: Array _inv_scale: Array _log_scale: Array + _is_constant_jacobian: bool + _is_constant_log_det: bool def __init__( self, @@ -51,7 +113,8 @@ def __init__( - `ValueError`: if both `scale` and `log_scale` are not None. """ - super().__init__(is_constant_jacobian=True) + self._is_constant_jacobian = True + self._is_constant_log_det = True self._shift = shift if scale is None and log_scale is None: self._scale = jnp.ones_like(shift) @@ -70,46 +133,6 @@ def __init__( "Only one of `scale` and `log_scale` can be specified, not both." ) - @property - def shift(self) -> Array: - """The bijector's shift.""" - return self._shift - - @property - def log_scale(self) -> Array: - """The log of the bijector's scale.""" - return self._log_scale - - @property - def scale(self) -> Array: - """The bijector's scale.""" - assert self._scale is not None # By construction. 
- return self._scale - - def forward(self, x: Array) -> Array: - """Computes y = f(x).""" - return self._scale * x + self._shift - - def forward_log_det_jacobian(self, x: Array) -> Array: - """Computes log|det J(f)(x)|.""" - return self._log_scale - - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: - """Computes y = f(x) and log|det J(f)(x)|.""" - return self.forward(x), self.forward_log_det_jacobian(x) - - def inverse(self, y: Array) -> Array: - """Computes x = f^{-1}(y).""" - return self._inv_scale * (y - self._shift) - - def inverse_log_det_jacobian(self, y: Array) -> Array: - """Computes log|det J(f^{-1})(y)|.""" - return jnp.negative(self._log_scale) - - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: - """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" - return self.inverse(y), self.inverse_log_det_jacobian(y) - def same_as(self, other: AbstractBijector) -> bool: """Returns True if this bijector is guaranteed to be the same as `other`.""" if type(other) is ScalarAffine: diff --git a/distreqx/bijectors/shift.py b/distreqx/bijectors/shift.py index d0d68d9..7693b11 100644 --- a/distreqx/bijectors/shift.py +++ b/distreqx/bijectors/shift.py @@ -1,12 +1,13 @@ """Shift bijector.""" +from jax import numpy as jnp from jaxtyping import Array from ._bijector import AbstractBijector -from .scalar_affine import ScalarAffine +from .scalar_affine import AbstractScalarAffine -class Shift(ScalarAffine): +class Shift(AbstractScalarAffine, strict=True): """Bijector that translates its input elementwise. The bijector is defined as follows: @@ -19,6 +20,13 @@ class Shift(ScalarAffine): where `shift` parameterizes the bijector. """ + _shift: Array + _scale: Array + _inv_scale: Array + _log_scale: Array + _is_constant_jacobian: bool + _is_constant_log_det: bool + def __init__(self, shift: Array): """Initializes a `Shift` bijector. @@ -26,10 +34,15 @@ def __init__(self, shift: Array): - `shift`: the bijector's shift parameter. 
""" - super().__init__(shift=shift) + self._is_constant_jacobian = True + self._is_constant_log_det = True + self._shift = shift + self._scale = jnp.ones_like(shift) + self._inv_scale = jnp.ones_like(shift) + self._log_scale = jnp.zeros_like(shift) def same_as(self, other: AbstractBijector) -> bool: """Returns True if this bijector is guaranteed to be the same as `other`.""" - if type(other) is Shift: # pylint: disable=unidiomatic-typecheck + if type(other) is Shift: return self.shift is other.shift return False diff --git a/distreqx/bijectors/sigmoid.py b/distreqx/bijectors/sigmoid.py index 554c5aa..f2c7eb9 100644 --- a/distreqx/bijectors/sigmoid.py +++ b/distreqx/bijectors/sigmoid.py @@ -1,15 +1,17 @@ """Sigmoid bijector.""" -from typing import Tuple - import jax import jax.numpy as jnp from jaxtyping import Array -from ._bijector import AbstractBijector +from ._bijector import ( + AbstractBijector, + AbstractFowardInverseBijector, + AbstractInvLogDetJacBijector, +) -class Sigmoid(AbstractBijector): +class Sigmoid(AbstractFowardInverseBijector, AbstractInvLogDetJacBijector, strict=True): """A bijector that computes the logistic sigmoid. The log-determinant implementation in this bijector is more numerically stable @@ -34,18 +36,22 @@ class Sigmoid(AbstractBijector): instead of `sample` followed by `log_prob`. 
""" + _is_constant_log_det: bool + _is_constant_jacobian: bool + def __init__(self) -> None: - super().__init__() + self._is_constant_jacobian = False + self._is_constant_log_det = False def forward_log_det_jacobian(self, x: Array) -> Array: """Computes log|det J(f)(x)|.""" return -_more_stable_softplus(-x) - _more_stable_softplus(x) - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: """Computes y = f(x) and log|det J(f)(x)|.""" return _more_stable_sigmoid(x), self.forward_log_det_jacobian(x) - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" x = jnp.log(y) - jnp.log1p(-y) return x, -self.forward_log_det_jacobian(x) diff --git a/distreqx/bijectors/tanh.py b/distreqx/bijectors/tanh.py index 4292e70..3e384c4 100644 --- a/distreqx/bijectors/tanh.py +++ b/distreqx/bijectors/tanh.py @@ -1,15 +1,16 @@ """Tanh bijector.""" - -from typing import Tuple - import jax import jax.numpy as jnp from jaxtyping import Array -from ._bijector import AbstractBijector +from ._bijector import ( + AbstractBijector, + AbstractFowardInverseBijector, + AbstractInvLogDetJacBijector, +) -class Tanh(AbstractBijector): +class Tanh(AbstractFowardInverseBijector, AbstractInvLogDetJacBijector, strict=True): """A bijector that computes the hyperbolic tangent. The log-determinant implementation in this bijector is more numerically stable @@ -29,22 +30,26 @@ class Tanh(AbstractBijector): instead of `sample` followed by `log_prob`. 
""" - def __init__(self): - super().__init__() + _is_constant_log_det: bool + _is_constant_jacobian: bool + + def __init__(self) -> None: + self._is_constant_jacobian = False + self._is_constant_log_det = False def forward_log_det_jacobian(self, x: Array) -> Array: """Computes log|det J(f)(x)|.""" return 2 * (jnp.log(2) - x - jax.nn.softplus(-2 * x)) - def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]: + def forward_and_log_det(self, x: Array) -> tuple[Array, Array]: """Computes y = f(x) and log|det J(f)(x)|.""" return jnp.tanh(x), self.forward_log_det_jacobian(x) - def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]: + def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]: """Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|.""" x = jnp.arctanh(y) return x, -self.forward_log_det_jacobian(x) def same_as(self, other: AbstractBijector) -> bool: """Returns True if this bijector is guaranteed to be the same as `other`.""" - return type(other) is Tanh # pylint: disable=unidiomatic-typecheck + return type(other) is Tanh diff --git a/distreqx/bijectors/triangular_linear.py b/distreqx/bijectors/triangular_linear.py index 983b381..c3be121 100644 --- a/distreqx/bijectors/triangular_linear.py +++ b/distreqx/bijectors/triangular_linear.py @@ -13,7 +13,7 @@ def _triangular_logdet(matrix: Array) -> Array: return jnp.sum(jnp.log(jnp.abs(jnp.diag(matrix)))) -class TriangularLinear(AbstractLinearBijector): +class TriangularLinear(AbstractLinearBijector, strict=True): """A linear bijector whose weight matrix is triangular. The bijector is defined as `f(x) = Ax` where `A` is a DxD triangular matrix. @@ -32,6 +32,9 @@ class TriangularLinear(AbstractLinearBijector): _matrix: Array _is_lower: bool + _is_constant_jacobian: bool + _is_constant_log_det: bool + _event_dims: int def __init__(self, matrix: Array, is_lower: bool = True): """Initializes a `TriangularLinear` bijector. 
@@ -44,6 +47,8 @@ def __init__(self, matrix: Array, is_lower: bool = True): - `is_lower`: if True, `A` is set to the lower triangular part of `matrix`. If False, `A` is set to the upper triangular part of `matrix`. """ + self._is_constant_jacobian = True + self._is_constant_log_det = True if matrix.ndim < 2: raise ValueError( f"`matrix` must have at least 2 dimensions, got {matrix.ndim}." @@ -52,8 +57,7 @@ def __init__(self, matrix: Array, is_lower: bool = True): raise ValueError( f"`matrix` must be square; instead, it has shape {matrix.shape[-2:]}." ) - super().__init__(event_dims=matrix.shape[-1]) - + self._event_dims = matrix.shape[-1] self._matrix = jnp.tril(matrix) if is_lower else jnp.triu(matrix) self._is_lower = is_lower diff --git a/distreqx/distributions/__init__.py b/distreqx/distributions/__init__.py index 8fdd09f..99ced36 100644 --- a/distreqx/distributions/__init__.py +++ b/distreqx/distributions/__init__.py @@ -1,12 +1,21 @@ from ._distribution import ( + AbstractCDFDistribution as AbstractCDFDistribution, AbstractDistribution as AbstractDistribution, + AbstractProbDistribution as AbstractProbDistribution, + AbstractSampleLogProbDistribution as AbstractSampleLogProbDistribution, + AbstractSTDDistribution as AbstractSTDDistribution, + AbstractSurivialDistribution as AbstractSurivialDistribution, ) from .bernoulli import Bernoulli as Bernoulli from .independent import Independent as Independent from .mvn_diag import MultivariateNormalDiag as MultivariateNormalDiag from .mvn_from_bijector import ( + AbstractMultivariateNormalFromBijector as AbstractMultivariateNormalFromBijector, MultivariateNormalFromBijector as MultivariateNormalFromBijector, ) from .mvn_tri import MultivariateNormalTri as MultivariateNormalTri from .normal import Normal as Normal -from .transformed import Transformed as Transformed +from .transformed import ( + AbstractTransformed as AbstractTransformed, + Transformed as Transformed, +) diff --git 
a/distreqx/distributions/_distribution.py b/distreqx/distributions/_distribution.py index 9041c62..49daae7 100644 --- a/distreqx/distributions/_distribution.py +++ b/distreqx/distributions/_distribution.py @@ -1,7 +1,6 @@ """Base class for distributions.""" from abc import abstractmethod -from typing import Tuple import equinox as eqx import jax @@ -14,10 +13,11 @@ class AbstractDistribution(eqx.Module, strict=True): """Base class for all distreqx distributions.""" + @abstractmethod def sample_and_log_prob( self, key: PRNGKeyArray, - ) -> Tuple[PyTree[Array], PyTree[Array]]: + ) -> tuple[PyTree[Array], PyTree[Array]]: """Returns sample and its log prob. By default, it just calls `log_prob` on the generated samples. However, for @@ -34,9 +34,7 @@ def sample_and_log_prob( - A tuple of a sample and their log probs. """ - samples = self.sample(key) - log_prob = self.log_prob(samples) - return samples, log_prob + raise NotImplementedError @abstractmethod def log_prob(self, value: PyTree[Array]) -> PyTree[Array]: @@ -69,6 +67,7 @@ def name(self) -> str: """Distribution name.""" return type(self).__name__ + @abstractmethod def prob(self, value: PyTree[Array]) -> PyTree[Array]: """Calculates the probability of an event. @@ -80,19 +79,21 @@ def prob(self, value: PyTree[Array]) -> PyTree[Array]: - The probability P(value). """ - return jnp.exp(self.log_prob(value)) + raise NotImplementedError @abstractmethod def sample(self, key: PRNGKeyArray) -> PyTree[Array]: """Samples an event.""" raise NotImplementedError + @abstractmethod def entropy(self) -> PyTree[Array]: """Calculates the Shannon entropy (in nats).""" raise NotImplementedError( f"Distribution `{self.name}` does not implement `entropy`." ) + @abstractmethod def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: """Evaluates the log cumulative distribution function at `value` i.e. 
log P[X <= value].""" @@ -100,6 +101,7 @@ def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: f"Distribution `{self.name}` does not implement `log_cdf`." ) + @abstractmethod def cdf(self, value: PyTree[Array]) -> PyTree[Array]: """Evaluates the cumulative distribution function at `value`. @@ -111,8 +113,9 @@ def cdf(self, value: PyTree[Array]) -> PyTree[Array]: - The CDF evaluated at value, i.e. P[X <= value]. """ - return jnp.exp(self.log_cdf(value)) + raise NotImplementedError + @abstractmethod def survival_function(self, value: PyTree[Array]) -> PyTree[Array]: """Evaluates the survival function at `value`. @@ -129,8 +132,9 @@ def survival_function(self, value: PyTree[Array]) -> PyTree[Array]: - The survival function evaluated at `value`, i.e. P[X > value] """ - return 1.0 - self.cdf(value) + raise NotImplementedError + @abstractmethod def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]: """Evaluates the log of the survival function at `value`. @@ -148,36 +152,44 @@ def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]: - The log of the survival function evaluated at `value`, i.e. log P[X > value] """ - return jnp.log1p(-self.cdf(value)) + raise NotImplementedError + @abstractmethod def mean(self) -> PyTree[Array]: """Calculates the mean.""" raise NotImplementedError( f"Distribution `{self.name}` does not implement `mean`." ) + @abstractmethod def median(self) -> PyTree[Array]: """Calculates the median.""" raise NotImplementedError( f"Distribution `{self.name}` does not implement `median`." ) + @abstractmethod def variance(self) -> PyTree[Array]: """Calculates the variance.""" raise NotImplementedError( f"Distribution `{self.name}` does not implement `variance`." ) + @abstractmethod def stddev(self) -> PyTree[Array]: """Calculates the standard deviation.""" - return jnp.sqrt(self.variance()) + raise NotImplementedError( + f"Distribution `{self.name}` does not implement `stddev`." 
+ ) + @abstractmethod def mode(self) -> PyTree[Array]: """Calculates the mode.""" raise NotImplementedError( f"Distribution `{self.name}` does not implement `mode`." ) + @abstractmethod def kl_divergence(self, other_dist, **kwargs) -> PyTree[Array]: """Calculates the KL divergence to another distribution. @@ -207,3 +219,115 @@ def cross_entropy(self, other_dist, **kwargs) -> Array: - The cross entropy `H(self || other_dist)`. """ return self.kl_divergence(other_dist, **kwargs) + self.entropy() + + +class AbstractSampleLogProbDistribution(AbstractDistribution, strict=True): + """Abstract distribution + concrete `sample_and_log_prob`.""" + + def sample_and_log_prob( + self, + key: PRNGKeyArray, + ) -> tuple[PyTree[Array], PyTree[Array]]: + """Returns sample and its log prob. + + By default, it just calls `log_prob` on the generated samples. However, for + many distributions it's more efficient to compute the log prob of samples + than of arbitrary events (for example, there's no need to check that a + sample is within the distribution's domain). If that's the case, a subclass + may override this method with a more efficient implementation. + + **Arguments:** + + - `key`: PRNG key. + + **Returns:** + + - A tuple of a sample and their log probs. + """ + samples = self.sample(key) + log_prob = self.log_prob(samples) + return samples, log_prob + + +class AbstractProbDistribution(AbstractDistribution, strict=True): + """Abstract distribution + concrete `prob`.""" + + def prob(self, value: PyTree[Array]) -> PyTree[Array]: + """Calculates the probability of an event. + + **Arguments:** + + - `value`: An event. + + **Returns:** + + - The probability P(value). + """ + return jnp.exp(self.log_prob(value)) + + +class AbstractCDFDistribution(AbstractDistribution, strict=True): + """Abstract distribution + concrete `cdf`.""" + + def cdf(self, value: PyTree[Array]) -> PyTree[Array]: + """Evaluates the cumulative distribution function at `value`. 
+ + **Arguments:** + + - `value`: An event. + + **Returns:** + + - The CDF evaluated at value, i.e. P[X <= value]. + """ + return jnp.exp(self.log_cdf(value)) + + +class AbstractSTDDistribution(AbstractDistribution, strict=True): + """Abstract distribution + concrete `stddev`.""" + + def stddev(self) -> PyTree[Array]: + """Calculate the standard deviation.""" + return jnp.sqrt(self.variance()) + + +class AbstractSurivialDistribution(AbstractDistribution, strict=True): + """Abstract distribution + concrete `survival_function` and + `log_survival_function`.""" + + def survival_function(self, value: PyTree[Array]) -> PyTree[Array]: + """Evaluates the survival function at `value`. + + Note that by default we use a numerically not necessarily stable definition + of the survival function in terms of the CDF. + More stable definitions should be implemented in subclasses for + distributions for which they exist. + + **Arguments:** + + - `value`: An event. + + **Returns:** + + - The survival function evaluated at `value`, i.e. P[X > value] + """ + return 1.0 - self.cdf(value) + + def log_survival_function(self, value: PyTree[Array]) -> PyTree[Array]: + """Evaluates the log of the survival function at `value`. + + Note that by default we use a numerically not necessarily stable definition + of the log of the survival function in terms of the CDF. + More stable definitions should be implemented in subclasses for + distributions for which they exist. + + **Arguments:** + + - `value`: An event. + + **Returns:** + + - The log of the survival function evaluated at `value`, i.e. 
+ log P[X > value] + """ + return jnp.log1p(-self.cdf(value)) diff --git a/distreqx/distributions/bernoulli.py b/distreqx/distributions/bernoulli.py index fff8f73..4d12f70 100644 --- a/distreqx/distributions/bernoulli.py +++ b/distreqx/distributions/bernoulli.py @@ -1,16 +1,25 @@ """Bernoulli distribution.""" -from typing import Optional, Tuple, Union +from typing import Optional, Union import jax import jax.numpy as jnp from jaxtyping import Array, PRNGKeyArray from ..utils.math import multiply_no_nan -from ._distribution import AbstractDistribution - - -class Bernoulli(AbstractDistribution): +from ._distribution import ( + AbstractSampleLogProbDistribution, + AbstractSTDDistribution, + AbstractSurivialDistribution, +) + + +class Bernoulli( + AbstractSampleLogProbDistribution, + AbstractSTDDistribution, + AbstractSurivialDistribution, + strict=True, +): """Bernoulli distribution of shape dims. Bernoulli distribution with parameter `probs`, the probability of outcome `1`. @@ -65,10 +74,10 @@ def probs(self) -> Array: return jax.nn.sigmoid(self._logits) @property - def event_shape(self) -> Tuple[int]: + def event_shape(self) -> tuple[int]: return self.prob.shape - def _log_probs_parameter(self) -> Tuple[Array, Array]: + def _log_probs_parameter(self) -> tuple[Array, Array]: if self._logits is None: if self._probs is None: raise ValueError("_probs is None!") @@ -121,6 +130,9 @@ def mean(self) -> Array: """See `Distribution.mean`.""" return self.probs + def median(self) -> None: + raise NotImplementedError + def variance(self) -> Array: """See `Distribution.variance`.""" return (1 - self.probs) * self.probs @@ -146,7 +158,7 @@ def kl_divergence(self, other_dist, **kwargs) -> Array: def _probs_and_log_probs( dist: Bernoulli, -) -> Tuple[ +) -> tuple[ Array, Array, Array, diff --git a/distreqx/distributions/independent.py b/distreqx/distributions/independent.py index 0b7b230..06e388f 100644 --- a/distreqx/distributions/independent.py +++ 
b/distreqx/distributions/independent.py @@ -1,14 +1,18 @@ """Independent distribution.""" import operator -from typing import Tuple import jax.numpy as jnp import jax.tree_util as jtu from jaxtyping import Array, PRNGKeyArray, PyTree from .._custom_types import EventT -from ._distribution import AbstractDistribution +from ._distribution import ( + AbstractCDFDistribution, + AbstractDistribution, + AbstractProbDistribution, + AbstractSurivialDistribution, +) def _reduce_helper(pytree: PyTree) -> Array: @@ -16,7 +20,12 @@ def _reduce_helper(pytree: PyTree) -> Array: return jtu.tree_reduce(operator.add, sum_over_leaves) -class Independent(AbstractDistribution): +class Independent( + AbstractProbDistribution, + AbstractCDFDistribution, + AbstractSurivialDistribution, + strict=True, +): """Independent distribution obtained from child distributions.""" _distribution: AbstractDistribution @@ -46,7 +55,7 @@ def sample(self, key: PRNGKeyArray) -> Array: """See `Distribution.sample`.""" return self._distribution.sample(key) - def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]: + def sample_and_log_prob(self, key: PRNGKeyArray) -> tuple[Array, Array]: """See `Distribution.sample_and_log_prob`.""" samples, log_prob = self._distribution.sample_and_log_prob(key) log_prob = _reduce_helper(log_prob) diff --git a/distreqx/distributions/mvn_diag.py b/distreqx/distributions/mvn_diag.py index 2291094..3362988 100644 --- a/distreqx/distributions/mvn_diag.py +++ b/distreqx/distributions/mvn_diag.py @@ -2,12 +2,26 @@ from typing import Optional +import equinox as eqx import jax import jax.numpy as jnp from jaxtyping import Array -from ..bijectors import DiagLinear -from .mvn_from_bijector import MultivariateNormalFromBijector +from ..bijectors import ( + AbstractBijector, + AbstractLinearBijector, + Block, + Chain, + DiagLinear, + Shift, +) +from ._distribution import AbstractDistribution +from .independent import Independent +from .mvn_from_bijector import ( + 
_check_input_parameters_are_valid, + AbstractMultivariateNormalFromBijector, +) +from .normal import Normal def _check_parameters(loc: Optional[Array], scale_diag: Optional[Array]) -> None: @@ -32,9 +46,14 @@ def _check_parameters(loc: Optional[Array], scale_diag: Optional[Array]) -> None ) -class MultivariateNormalDiag(MultivariateNormalFromBijector): +class MultivariateNormalDiag(AbstractMultivariateNormalFromBijector, strict=True): """Multivariate normal distribution on `R^k` with diagonal covariance.""" + _loc: Array + _scale: AbstractLinearBijector + _event_shape: tuple[int] + _distribution: AbstractDistribution + _bijector: AbstractBijector _scale_diag: Array def __init__(self, loc: Optional[Array] = None, scale_diag: Optional[Array] = None): @@ -62,7 +81,21 @@ def __init__(self, loc: Optional[Array] = None, scale_diag: Optional[Array] = No raise ValueError("scale_diag must be a vector!") scale = DiagLinear(scale_diag) - super().__init__(loc=loc, scale=scale) + _check_input_parameters_are_valid(scale, loc) + + # Build a standard multivariate Gaussian. + std_mvn_dist = Independent( + distribution=eqx.filter_vmap(Normal)( + jnp.zeros_like(loc), jnp.ones_like(loc) + ), + ) + # Form the bijector `f(x) = Ax + b`. 
+ bijector = Chain([Block(Shift(loc), ndims=loc.ndim), scale]) + self._distribution = std_mvn_dist + self._bijector = bijector + self._scale = scale + self._loc = loc + self._event_shape = loc.shape[-1:] self._scale_diag = scale_diag @property diff --git a/distreqx/distributions/mvn_from_bijector.py b/distreqx/distributions/mvn_from_bijector.py index 233ff52..2d5ccc5 100644 --- a/distreqx/distributions/mvn_from_bijector.py +++ b/distreqx/distributions/mvn_from_bijector.py @@ -1,16 +1,24 @@ """MultivariateNormalFromBijector distribution.""" -from typing import Callable, Tuple +from typing import Callable import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array - -from ..bijectors import AbstractLinearBijector, Block, Chain, DiagLinear, Shift +from jaxtyping import Array, PyTree + +from ..bijectors import ( + AbstractBijector, + AbstractLinearBijector, + Block, + Chain, + DiagLinear, + Shift, +) +from ._distribution import AbstractDistribution from .independent import Independent from .normal import Normal -from .transformed import Transformed +from .transformed import AbstractTransformed def _check_input_parameters_are_valid( @@ -27,48 +35,12 @@ def _check_input_parameters_are_valid( ) -class MultivariateNormalFromBijector(Transformed): - """Multivariate normal distribution on `R^k`. - - The multivariate normal over `x` is characterized by an invertible affine - transformation `x = f(z) = A @ z + b`, where `z` is a random variable that - follows a standard multivariate normal on `R^k`, i.e., `p(z) = N(0, I_k)`, - `A` is a `k x k` transformation matrix, and `b` is a `k`-dimensional vector. - - The resulting PDF on `x` is a multivariate normal, `p(x) = N(b, C)`, where - `C = A @ A.T` is the covariance matrix. - - The transformation `x = f(z)` must be specified by a linear scale bijector - implementing the operation `A @ z` and a shift (or location) term `b`. 
- """ - - _loc: Array - _scale: AbstractLinearBijector - _event_shape: Tuple[int] - - def __init__(self, loc: Array, scale: AbstractLinearBijector): - """Initializes the distribution. - - **Arguments:** - - - `loc`: The term `b`, i.e., the mean of the multivariate normal distribution. - - `scale`: The bijector specifying the linear transformation `A @ z`, as - described in the class docstring. - """ - _check_input_parameters_are_valid(scale, loc) - - # Build a standard multivariate Gaussian. - std_mvn_dist = Independent( - distribution=eqx.filter_vmap(Normal)( - jnp.zeros_like(loc), jnp.ones_like(loc) - ), - ) - # Form the bijector `f(x) = Ax + b`. - bijector = Chain([Block(Shift(loc), ndims=loc.ndim), scale]) - super().__init__(distribution=std_mvn_dist, bijector=bijector) - self._scale = scale - self._loc = loc - self._event_shape = loc.shape[-1:] +class AbstractMultivariateNormalFromBijector(AbstractTransformed, strict=True): + _loc: eqx.AbstractVar[Array] + _scale: eqx.AbstractVar[AbstractLinearBijector] + _event_shape: eqx.AbstractVar[tuple[int]] + _distribution: eqx.AbstractVar[AbstractDistribution] + _bijector: eqx.AbstractVar[AbstractBijector] @property def scale(self) -> AbstractLinearBijector: @@ -139,29 +111,84 @@ def kl_divergence(self, other_dist, **kwargs) -> Array: return _kl_divergence_mvn_mvn(self, other_dist) +class MultivariateNormalFromBijector(AbstractMultivariateNormalFromBijector): + """Multivariate normal distribution on `R^k`. + + The multivariate normal over `x` is characterized by an invertible affine + transformation `x = f(z) = A @ z + b`, where `z` is a random variable that + follows a standard multivariate normal on `R^k`, i.e., `p(z) = N(0, I_k)`, + `A` is a `k x k` transformation matrix, and `b` is a `k`-dimensional vector. + + The resulting PDF on `x` is a multivariate normal, `p(x) = N(b, C)`, where + `C = A @ A.T` is the covariance matrix. 
+ + The transformation `x = f(z)` must be specified by a linear scale bijector + implementing the operation `A @ z` and a shift (or location) term `b`. + """ + + _loc: Array + _scale: AbstractLinearBijector + _event_shape: tuple[int] + _distribution: AbstractDistribution + _bijector: AbstractBijector + + def __init__(self, loc: Array, scale: AbstractLinearBijector): + """Initializes the distribution. + + **Arguments:** + + - `loc`: The term `b`, i.e., the mean of the multivariate normal distribution. + - `scale`: The bijector specifying the linear transformation `A @ z`, as + described in the class docstring. + """ + _check_input_parameters_are_valid(scale, loc) + + # Build a standard multivariate Gaussian. + std_mvn_dist = Independent( + distribution=eqx.filter_vmap(Normal)( + jnp.zeros_like(loc), jnp.ones_like(loc) + ), + ) + # Form the bijector `f(x) = Ax + b`. + bijector = Chain([Block(Shift(loc), ndims=loc.ndim), scale]) + self._distribution = std_mvn_dist + self._bijector = bijector + self._scale = scale + self._loc = loc + self._event_shape = loc.shape[-1:] + + def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError + + def cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError + + def _squared_frobenius_norm(x: Array) -> Array: """Computes the squared Frobenius norm of a matrix.""" return jnp.sum(jnp.square(x), axis=[-2, -1]) -def _log_abs_determinant(d: MultivariateNormalFromBijector) -> Array: +def _log_abs_determinant(d: AbstractMultivariateNormalFromBijector) -> Array: """Obtains `log|det(A)|`.""" return d.scale.forward_log_det_jacobian(jnp.zeros(d.event_shape, dtype=d.dtype)) -def _inv_scale_operator(d: MultivariateNormalFromBijector) -> Callable[[Array], Array]: +def _inv_scale_operator( + d: AbstractMultivariateNormalFromBijector, +) -> Callable[[Array], Array]: """Gets the operator that performs `A^-1 * x`.""" return jax.vmap(d.scale.inverse, in_axes=-1, out_axes=-1) -def _scale_matrix(d: 
MultivariateNormalFromBijector) -> Array: +def _scale_matrix(d: AbstractMultivariateNormalFromBijector) -> Array: """Gets the full scale matrix `A`.""" return d.scale.matrix -def _has_diagonal_scale(d: MultivariateNormalFromBijector) -> bool: +def _has_diagonal_scale(d: AbstractMultivariateNormalFromBijector) -> bool: """Determines if the scale matrix `A` is diagonal.""" - if isinstance(d, MultivariateNormalFromBijector) and isinstance( + if isinstance(d, AbstractMultivariateNormalFromBijector) and isinstance( d.scale, DiagLinear ): return True @@ -169,8 +196,8 @@ def _has_diagonal_scale(d: MultivariateNormalFromBijector) -> bool: def _kl_divergence_mvn_mvn( - dist1: MultivariateNormalFromBijector, - dist2: MultivariateNormalFromBijector, + dist1: AbstractMultivariateNormalFromBijector, + dist2: AbstractMultivariateNormalFromBijector, *unused_args, **unused_kwargs, ) -> Array: diff --git a/distreqx/distributions/mvn_tri.py b/distreqx/distributions/mvn_tri.py index 43e5c3c..2bbcd25 100644 --- a/distreqx/distributions/mvn_tri.py +++ b/distreqx/distributions/mvn_tri.py @@ -2,11 +2,23 @@ from typing import Optional +import equinox as eqx import jax.numpy as jnp -from jaxtyping import Array - -from ..bijectors import DiagLinear, TriangularLinear -from .mvn_from_bijector import MultivariateNormalFromBijector +from jaxtyping import Array, PyTree + +from ..bijectors import ( + AbstractBijector, + AbstractLinearBijector, + Block, + Chain, + DiagLinear, + Shift, + TriangularLinear, +) +from ._distribution import AbstractDistribution +from .independent import Independent +from .mvn_from_bijector import AbstractMultivariateNormalFromBijector +from .normal import Normal def _check_parameters(loc: Optional[Array], scale_tri: Optional[Array]) -> None: @@ -38,7 +50,7 @@ def _check_parameters(loc: Optional[Array], scale_tri: Optional[Array]) -> None: ) -class MultivariateNormalTri(MultivariateNormalFromBijector): +class MultivariateNormalTri(AbstractMultivariateNormalFromBijector, 
strict=True): """Multivariate normal distribution on `R^k`. The `MultivariateNormalTri` distribution is parameterized by a `k`-length @@ -48,6 +60,11 @@ class MultivariateNormalTri(MultivariateNormalFromBijector): _scale_tri: Array _is_lower: bool + _loc: Array + _scale: AbstractLinearBijector + _event_shape: tuple[int] + _distribution: AbstractDistribution + _bijector: AbstractBijector def __init__( self, @@ -94,7 +111,19 @@ def __init__( scale = TriangularLinear(matrix=self._scale_tri, is_lower=is_lower) self._is_lower = is_lower - super().__init__(loc=loc, scale=scale) + # Build a standard multivariate Gaussian. + std_mvn_dist = Independent( + distribution=eqx.filter_vmap(Normal)( + jnp.zeros_like(loc), jnp.ones_like(loc) + ), + ) + # Form the bijector `f(x) = Ax + b`. + bijector = Chain([Block(Shift(loc), ndims=loc.ndim), scale]) + self._distribution = std_mvn_dist + self._bijector = bijector + self._scale = scale + self._loc = loc + self._event_shape = loc.shape[-1:] @property def scale_tri(self) -> Array: @@ -105,3 +134,9 @@ def scale_tri(self) -> Array: def is_lower(self) -> bool: """Whether the `scale_tri` matrix is lower triangular.""" return self._is_lower + + def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError + + def cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError diff --git a/distreqx/distributions/normal.py b/distreqx/distributions/normal.py index b44862c..169384d 100644 --- a/distreqx/distributions/normal.py +++ b/distreqx/distributions/normal.py @@ -1,19 +1,18 @@ """Normal distribution.""" import math -from typing import Tuple import jax import jax.numpy as jnp from jaxtyping import Array, PRNGKeyArray -from ._distribution import AbstractDistribution +from ._distribution import AbstractProbDistribution _half_log2pi = 0.5 * math.log(2 * math.pi) -class Normal(AbstractDistribution): +class Normal(AbstractProbDistribution, strict=True): """Normal distribution with location `loc` and 
`scale` parameters.""" _loc: Array @@ -31,7 +30,7 @@ def __init__(self, loc: Array, scale: Array): self._scale = jnp.array(scale) @property - def event_shape(self) -> Tuple[int, ...]: + def event_shape(self) -> tuple[int, ...]: """Shape of event of distribution samples.""" return self._loc.shape @@ -54,7 +53,7 @@ def sample(self, key: PRNGKeyArray) -> Array: rnd = self._sample_from_std_normal(key) return self._scale * rnd + self._loc - def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]: + def sample_and_log_prob(self, key: PRNGKeyArray) -> tuple[Array, Array]: """See `Distribution.sample_and_log_prob`.""" rnd = self._sample_from_std_normal(key) samples = self._scale * rnd + self._loc diff --git a/distreqx/distributions/transformed.py b/distreqx/distributions/transformed.py index a63bc14..e3b8c84 100644 --- a/distreqx/distributions/transformed.py +++ b/distreqx/distributions/transformed.py @@ -1,16 +1,26 @@ """Distribution representing a Bijector applied to a Distribution.""" -from typing import Optional, Tuple +from typing import Optional +import equinox as eqx import jax import jax.numpy as jnp -from jaxtyping import Array, PRNGKeyArray +from jaxtyping import Array, PRNGKeyArray, PyTree from ..bijectors import AbstractBijector -from ._distribution import AbstractDistribution - - -class Transformed(AbstractDistribution): +from ._distribution import ( + AbstractDistribution, + AbstractProbDistribution, + AbstractSTDDistribution, + AbstractSurivialDistribution, +) + + +class AbstractTransformed( + AbstractSurivialDistribution, + AbstractProbDistribution, + strict=True, +): """Distribution of a random variable transformed by a bijective function. Let `X` be a continuous random variable and `Y = f(X)` be a random variable @@ -44,19 +54,8 @@ class Transformed(AbstractDistribution): can be correctly obtained. 
""" - _distribution: AbstractDistribution - _bijector: AbstractBijector - - def __init__(self, distribution: AbstractDistribution, bijector: AbstractBijector): - """Initializes a Transformed distribution. - - **Arguments:** - - `distribution`: the base distribution. - - `bijector`: a differentiable bijective transformation. Can be a bijector or - a callable to be wrapped by `Lambda` bijector. - """ - self._distribution = distribution - self._bijector = bijector + _distribution: eqx.AbstractVar[AbstractDistribution] + _bijector: eqx.AbstractVar[AbstractBijector] @property def distribution(self): @@ -81,7 +80,7 @@ def dtype(self) -> jnp.dtype: return self._infer_shapes_and_dtype()[1] @property - def event_shape(self) -> Tuple[int, ...]: + def event_shape(self) -> tuple[int, ...]: """See `Distribution.event_shape`.""" return self._infer_shapes_and_dtype()[0] @@ -93,12 +92,12 @@ def log_prob(self, value: Array) -> Array: return lp_y def sample(self, key: PRNGKeyArray) -> Array: - """Return asamples.""" + """Return a sample.""" x = self.distribution.sample(key) y = self.bijector.forward(x) return y - def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]: + def sample_and_log_prob(self, key: PRNGKeyArray) -> tuple[Array, Array]: """Return a sample and log prob. This function is more efficient than calling `sample` and `log_prob` @@ -118,27 +117,6 @@ def sample_and_log_prob(self, key: PRNGKeyArray) -> Tuple[Array, Array]: lp_y = jnp.subtract(lp_x, fldj) return y, lp_y - def mean(self) -> Array: - """Calculates the mean.""" - if self.bijector.is_constant_jacobian: - return self.bijector.forward(self.distribution.mean()) - else: - raise NotImplementedError( - "`mean` is not implemented for this transformed distribution, " - "because its bijector's Jacobian is not known to be constant." 
- ) - - def mode(self) -> Array: - """Calculates the mode.""" - if self.bijector.is_constant_log_det: - return self.bijector.forward(self.distribution.mode()) - else: - raise NotImplementedError( - "`mode` is not implemented for this transformed distribution, " - "because its bijector's Jacobian determinant is not known to be " - "constant." - ) - def entropy(self, input_hint: Optional[Array] = None) -> Array: """Calculates the Shannon entropy (in Nats). @@ -174,6 +152,124 @@ def entropy(self, input_hint: Optional[Array] = None) -> Array: "constant." ) + +def _kl_divergence_transformed_transformed( + dist1: AbstractTransformed, + dist2: AbstractTransformed, + *unused_args, + input_hint: Optional[Array] = None, + **unused_kwargs, +) -> Array: + if dist1.distribution.event_shape != dist2.distribution.event_shape: + raise ValueError( + f"The two base distributions do not have the same event shape: " + f"{dist1.distribution.event_shape} and " + f"{dist2.distribution.event_shape}." + ) + + bij1 = dist1.bijector + bij2 = dist2.bijector + + # Check if the bijectors are different. + if bij1 != bij2 and not bij1.same_as(bij2): + if input_hint is None: + input_hint = jnp.zeros( + dist1.distribution.event_shape, dtype=dist1.distribution.dtype + ) + jaxpr_bij1 = jax.make_jaxpr(bij1.forward)(input_hint).jaxpr + jaxpr_bij2 = jax.make_jaxpr(bij2.forward)(input_hint).jaxpr + if str(jaxpr_bij1) != str(jaxpr_bij2): + raise NotImplementedError( + f"The KL divergence cannot be obtained because it is not possible to " + f"guarantee that the bijectors {dist1.bijector.name} and " + f"{dist2.bijector.name} of the Transformed distributions are " + f"equal. If possible, use the same instance of a distreqx bijector." + ) + + return dist1.distribution.kl_divergence(dist2.distribution) + + +class Transformed(AbstractTransformed, AbstractSTDDistribution, strict=True): + """Distribution of a random variable transformed by a bijective function. 
+ + Let `X` be a continuous random variable and `Y = f(X)` be a random variable + transformed by a differentiable bijection `f` (a "bijector"). Given the + distribution of `X` (the "base distribution") and the bijector `f`, this class + implements the distribution of `Y` (also known as the pushforward of the base + distribution through `f`). + + The probability density of `Y` can be computed by: + + `log p(y) = log p(x) - log|det J(f)(x)|` + + where `p(x)` is the probability density of `X` (the "base density") and + `J(f)(x)` is the Jacobian matrix of `f`, both evaluated at `x = f^{-1}(y)`. + + Sampling from a Transformed distribution involves two steps: sampling from the + base distribution `x ~ p(x)` and then evaluating `y = f(x)`. For example: + + ```python + dist = distreqx.Normal(loc=0., scale=1.) + bij = distreqx.ScalarAffine(shift=jnp.asarray([3., 3., 3.])) + transformed_dist = distreqx.Transformed(distribution=dist, bijector=bij) + samples = transformed_dist.sample(jax.random.PRNGKey(0)) + print(samples) # [2.7941577, 2.7941577, 2.7941577] + ``` + + This assumes that the `forward` function of the bijector is traceable; that is, + it is a pure function that does not contain run-time branching. Functions that + do not strictly meet this requirement can still be used, but we cannot guarantee + that the shapes, dtype, and KL computations involving the transformed distribution + can be correctly obtained. + """ + + _distribution: AbstractDistribution + _bijector: AbstractBijector + + def __init__(self, distribution: AbstractDistribution, bijector: AbstractBijector): + """Initializes a Transformed distribution. + + **Arguments:** + - `distribution`: the base distribution. + - `bijector`: a differentiable bijective transformation. Can be a bijector or + a callable to be wrapped by `Lambda` bijector.
+ """ + self._distribution = distribution + self._bijector = bijector + + def mean(self) -> Array: + """Calculates the mean.""" + if self.bijector.is_constant_jacobian: + return self.bijector.forward(self.distribution.mean()) + else: + raise NotImplementedError( + "`mean` is not implemented for this transformed distribution, " + "because its bijector's Jacobian is not known to be constant." + ) + + def mode(self) -> Array: + """Calculates the mode.""" + if self.bijector.is_constant_log_det: + return self.bijector.forward(self.distribution.mode()) + else: + raise NotImplementedError( + "`mode` is not implemented for this transformed distribution, " + "because its bijector's Jacobian determinant is not known to be " + "constant." + ) + + def log_cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError + + def median(self) -> PyTree[Array]: + raise NotImplementedError + + def variance(self) -> PyTree[Array]: + raise NotImplementedError + + def cdf(self, value: PyTree[Array]) -> PyTree[Array]: + raise NotImplementedError + def kl_divergence(self, other_dist, **kwargs) -> Array: """Obtains the KL divergence between two Transformed distributions. @@ -181,16 +277,21 @@ def kl_divergence(self, other_dist, **kwargs) -> Array: same bijector. If the two Transformed distributions do not have the same bijector, an error is raised. To determine if the bijectors are equal, this method proceeds as follows: + - If both bijectors are the same instance of a distreqx bijector, then they are declared equal. + - If not the same instance, we check if they are equal according to their `same_as` predicate. + - Otherwise, the string representation of the Jaxpr of the `forward` method of each bijector is compared. If both string representations are equal, the bijectors are declared equal. + - Otherwise, the bijectors cannot be guaranteed to be equal and an error is raised. + **Arguments:** - `other_dist`: A Transformed distribution. 
@@ -208,39 +309,3 @@ def kl_divergence(self, other_dist, **kwargs) -> Array: - `ValueError`: If the base distributions do not have the same `event_shape`. """ return _kl_divergence_transformed_transformed(self, other_dist, **kwargs) - - -def _kl_divergence_transformed_transformed( - dist1: Transformed, - dist2: Transformed, - *unused_args, - input_hint: Optional[Array] = None, - **unused_kwargs, -) -> Array: - if dist1.distribution.event_shape != dist2.distribution.event_shape: - raise ValueError( - f"The two base distributions do not have the same event shape: " - f"{dist1.distribution.event_shape} and " - f"{dist2.distribution.event_shape}." - ) - - bij1 = dist1.bijector - bij2 = dist2.bijector - - # Check if the bijectors are different. - if bij1 != bij2 and not bij1.same_as(bij2): - if input_hint is None: - input_hint = jnp.zeros( - dist1.distribution.event_shape, dtype=dist1.distribution.dtype - ) - jaxpr_bij1 = jax.make_jaxpr(bij1.forward)(input_hint).jaxpr - jaxpr_bij2 = jax.make_jaxpr(bij2.forward)(input_hint).jaxpr - if str(jaxpr_bij1) != str(jaxpr_bij2): - raise NotImplementedError( - f"The KL divergence cannot be obtained because it is not possible to " - f"guarantee that the bijectors {dist1.bijector.name} and " - f"{dist2.bijector.name} of the Transformed distributions are " - f"equal. If possible, use the same instance of a distreqx bijector." 
- ) - - return dist1.distribution.kl_divergence(dist2.distribution) diff --git a/distreqx/utils/math.py b/distreqx/utils/math.py index 5b752ad..2c53d67 100644 --- a/distreqx/utils/math.py +++ b/distreqx/utils/math.py @@ -1,6 +1,6 @@ """Utility math functions.""" -from typing import Optional, Tuple +from typing import Optional import jax import jax.numpy as jnp @@ -31,8 +31,8 @@ def multiply_no_nan(x: Array, y: Array) -> Array: @multiply_no_nan.defjvp def multiply_no_nan_jvp( - primals: Tuple[Array, Array], tangents: Tuple[Array, Array] -) -> Tuple[Array, Array]: + primals: tuple[Array, Array], tangents: tuple[Array, Array] +) -> tuple[Array, Array]: """Custom gradient computation for `multiply_no_nan`. **Arguments:** @@ -71,8 +71,8 @@ def power_no_nan(x: Array, y: Array) -> Array: @power_no_nan.defjvp def power_no_nan_jvp( - primals: Tuple[Array, Array], tangents: Tuple[Array, Array] -) -> Tuple[Array, Array]: + primals: tuple[Array, Array], tangents: tuple[Array, Array] +) -> tuple[Array, Array]: """Custom gradient computation for `power_no_nan`. 
**Arguments:** diff --git a/docs/api/bijectors/_bijector.md b/docs/api/bijectors/_bijector.md index eb34a7b..b81809b 100644 --- a/docs/api/bijectors/_bijector.md +++ b/docs/api/bijectors/_bijector.md @@ -1,4 +1,4 @@ -# Base Bijector +# Abstract Bijectors ::: distreqx.bijectors._bijector.AbstractBijector selection: @@ -11,4 +11,26 @@ - forward_and_log_det - inverse_and_log_det - same_as + +::: distreqx.bijectors._bijector.AbstractInvLogDetJacBijector + selection: + members: + - inverse_log_det_jacobian + +::: distreqx.bijectors._bijector.AbstractFwdLogDetJacBijector + selection: + members: + - forward_log_det_jacobian + +::: distreqx.bijectors._bijector.AbstractFowardInverseBijector + selection: + members: + - forward + - inverse + +::: distreqx.bijectors._linear.AbstractLinearBijector + selection: + members: + - __init__ + - matrix --- \ No newline at end of file diff --git a/docs/api/bijectors/_linear.md b/docs/api/bijectors/_linear.md deleted file mode 100644 index 9a42931..0000000 --- a/docs/api/bijectors/_linear.md +++ /dev/null @@ -1,8 +0,0 @@ -# Base Linear Bijector - -::: distreqx.bijectors._linear.AbstractLinearBijector - selection: - members: - - __init__ - - matrix ---- \ No newline at end of file diff --git a/docs/api/distributions/_distribution.md b/docs/api/distributions/_distribution.md index 06c5dc9..4b2b85a 100644 --- a/docs/api/distributions/_distribution.md +++ b/docs/api/distributions/_distribution.md @@ -1,4 +1,4 @@ -# Base Distribution +# Abstract Distributions ::: distreqx.distributions._distribution.AbstractDistribution selection: @@ -9,4 +9,55 @@ - cdf - survival_function - log_survival_function + - kl_divergence + - cross_entropy + +::: distreqx.distributions._distribution.AbstractSampleLogProbDistribution + selection: + members: + - _sample_n_and_log_prob + +::: distreqx.distributions._distribution.AbstractProbDistribution + selection: + members: + - prob + +::: distreqx.distributions._distribution.AbstractCDFDistribution + selection: + 
members: + - cdf + +::: distreqx.distributions._distribution.AbstractSTDDistribution + selection: + members: + - stddev + +::: distreqx.distributions._distribution.AbstractSurivialDistribution + selection: + members: + - survival_function + - log_survival_function + +::: distreqx.distributions.transformed.AbstractTransformed + selection: + members: + - distribution + - bijector + - dtype + - event_shape + - log_prob + - sample + - sample_and_log_prob + - entropy + +::: distreqx.distributions.mvn_from_bijector.AbstractMultivariateNormalFromBijector + selection: + members: + - scale + - loc + - covariance + - variance + - stddev + - kl_divergence + --- \ No newline at end of file diff --git a/docs/api/distributions/mvn_from_bijector.md b/docs/api/distributions/mvn_from_bijector.md index bb7985d..61b62f2 100644 --- a/docs/api/distributions/mvn_from_bijector.md +++ b/docs/api/distributions/mvn_from_bijector.md @@ -4,5 +4,4 @@ selection: members: - __init__ - - covariance --- \ No newline at end of file diff --git a/docs/api/distributions/transformed.md b/docs/api/distributions/transformed.md index 20a5d0c..a73434a 100644 --- a/docs/api/distributions/transformed.md +++ b/docs/api/distributions/transformed.md @@ -1,10 +1,10 @@ -# Base Transformed +# Transformed ::: distreqx.distributions.transformed.Transformed selection: members: - __init__ - - sample_and_log_prob - - entropy + - mean + - mode - kl_divergence --- \ No newline at end of file diff --git a/mkdocs.yml b/mkdocs.yml index d30b3cb..aa78f76 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -94,18 +94,16 @@ nav: - Binary MNIST VAE: 'examples/01_vae.ipynb' - API: - Distributions: - - 'api/distributions/_distribution.md' - - 'api/distributions/independent.md' - - 'api/distributions/bernoulli.md' - - 'api/distributions/transformed.md' - Gaussians: - 'api/distributions/normal.md' - 'api/distributions/mvn_diag.md' - 'api/distributions/mvn_from_bijector.md' - 'api/distributions/mvn_tri.md' + - 
'api/distributions/bernoulli.md' + - 'api/distributions/independent.md' + - 'api/distributions/transformed.md' + - 'api/distributions/_distribution.md' - Bijectors: - - 'api/bijectors/_bijector.md' - - 'api/bijectors/_linear.md' - 'api/bijectors/block.md' - 'api/bijectors/chain.md' - 'api/bijectors/diag_linear.md' @@ -114,6 +112,7 @@ nav: - 'api/bijectors/sigmoid.md' - 'api/bijectors/tanh.md' - 'api/bijectors/triangular_linear.md' + - 'api/bijectors/_bijector.md' - Utilities: - 'api/utils/math.md' - Further Details: diff --git a/tests/abstractbijector_test.py b/tests/abstractbijector_test.py index a972037..d8a527c 100644 --- a/tests/abstractbijector_test.py +++ b/tests/abstractbijector_test.py @@ -5,29 +5,34 @@ import jax import jax.numpy as jnp import numpy as np -from parameterized import parameterized # type: ignore -from distreqx.bijectors import AbstractBijector +from distreqx.bijectors import ( + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, +) -class DummyBijector(AbstractBijector): +class DummyBijector( + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, + strict=True, +): + _is_constant_jacobian: bool + _is_constant_log_det: bool + def forward_and_log_det(self, x): return x, jnp.zeros(x.shape[:-1], jnp.float_) def inverse_and_log_det(self, y): return y, jnp.zeros(y.shape[:-1], jnp.float_) + def same_as(self, other): + raise NotImplementedError -class BijectorTest(TestCase): - @parameterized.expand( - [ - ("non-consistent constant properties", True, False), - ] - ) - def test_invalid_parameters(self, name, cnst_jac, cnst_logdet): - with self.assertRaises(ValueError): - DummyBijector(cnst_jac, cnst_logdet) +class BijectorTest(TestCase): def test_jittable(self): @jax.jit def forward(bij, x): diff --git a/tests/abstractdistribution_test.py b/tests/abstractdistribution_test.py index 65d8a61..d61a8c1 100644 --- a/tests/abstractdistribution_test.py +++ 
b/tests/abstractdistribution_test.py @@ -8,7 +8,18 @@ from distreqx.distributions import _distribution -class DummyUnivariateDist(_distribution.AbstractDistribution): +class AbstractAll( + _distribution.AbstractSurivialDistribution, + _distribution.AbstractSTDDistribution, + _distribution.AbstractSampleLogProbDistribution, + _distribution.AbstractCDFDistribution, + _distribution.AbstractProbDistribution, + strict=True, +): + pass + + +class DummyUnivariateDist(AbstractAll): """Dummy univariate distribution for testing.""" def sample(self, key): @@ -21,8 +32,29 @@ def log_prob(self, value): def event_shape(self): return (1,) + def entropy(self): + raise NotImplementedError -class DummyMultivariateDist(_distribution.AbstractDistribution): + def kl_divergence(self, other_dist, **kwargs): + raise NotImplementedError + + def log_cdf(self, value): + raise NotImplementedError + + def mean(self): + raise NotImplementedError + + def median(self): + raise NotImplementedError + + def mode(self): + raise NotImplementedError + + def variance(self): + raise NotImplementedError + + +class DummyMultivariateDist(AbstractAll): """Dummy multivariate distribution for testing.""" _dimension: tuple @@ -37,6 +69,27 @@ def log_prob(self, value): def event_shape(self): return self._dimension + def entropy(self): + raise NotImplementedError + + def kl_divergence(self, other_dist, **kwargs): + raise NotImplementedError + + def log_cdf(self, value): + raise NotImplementedError + + def mean(self): + raise NotImplementedError + + def median(self): + raise NotImplementedError + + def mode(self): + raise NotImplementedError + + def variance(self): + raise NotImplementedError + class DistributionTest(TestCase): def setUp(self): diff --git a/tests/abstractlinear_test.py b/tests/abstractlinear_test.py index 6c8e1f9..06ca9e8 100644 --- a/tests/abstractlinear_test.py +++ b/tests/abstractlinear_test.py @@ -4,16 +4,41 @@ from parameterized import parameterized # type: ignore -from distreqx.bijectors 
import AbstractLinearBijector +from distreqx.bijectors import ( + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, + AbstractLinearBijector, +) + + +class MockLinear( + AbstractLinearBijector, + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, + strict=True, +): + _is_constant_jacobian: bool + _is_constant_log_det: bool + _event_dims: int - -class MockLinear(AbstractLinearBijector): def __init__(self, dims): - super().__init__(dims) + self._event_dims = dims + self._is_constant_jacobian = True + self._is_constant_log_det = True def forward_and_log_det(self, x): raise Exception + def same_as(self, other): + raise NotImplementedError + + def inverse_and_log_det(self, y): + raise NotImplementedError( + f"Bijector {self.name} does not implement `inverse_and_log_det`." + ) + class LinearTest(TestCase): @parameterized.expand( diff --git a/tests/mvn_from_bijector_test.py b/tests/mvn_from_bijector_test.py index 6a1cc9d..7de4ba3 100644 --- a/tests/mvn_from_bijector_test.py +++ b/tests/mvn_from_bijector_test.py @@ -8,20 +8,44 @@ import numpy as np from parameterized import parameterized # type: ignore -from distreqx.bijectors import AbstractLinearBijector, DiagLinear +from distreqx.bijectors import ( + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, + AbstractLinearBijector, + DiagLinear, +) from distreqx.distributions import MultivariateNormalFromBijector -class MockLinear(AbstractLinearBijector): - """A mock linear bijector.""" +class MockLinear( + AbstractLinearBijector, + AbstractFowardInverseBijector, + AbstractFwdLogDetJacBijector, + AbstractInvLogDetJacBijector, + strict=True, +): + _is_constant_jacobian: bool + _is_constant_log_det: bool + _event_dims: int - def __init__(self, event_dims: int): - super().__init__(event_dims) + def __init__(self, dims): + self._event_dims = dims + self._is_constant_jacobian = True + 
self._is_constant_log_det = True def forward_and_log_det(self, x): """Computes y = f(x) and log|det J(f)(x)|.""" return x, jnp.zeros_like(x)[:-1] + def same_as(self, other): + raise NotImplementedError + + def inverse_and_log_det(self, y): + raise NotImplementedError( + f"Bijector {self.name} does not implement `inverse_and_log_det`." + ) + class MultivariateNormalFromBijectorTest(TestCase): @parameterized.expand(