Skip to content

Commit

Permalink
Enforces strict modules (#20)
Browse files Browse the repository at this point in the history
* distribution part 1

* bijectors

* beartyping

* design
  • Loading branch information
lockwo authored Jun 12, 2024
1 parent 0cad139 commit 77b8ee2
Show file tree
Hide file tree
Showing 34 changed files with 918 additions and 376 deletions.
2 changes: 1 addition & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ These hooks use Black and isort to format the code, and flake8 to lint it.

**If you're making changes to the code:**

Now make your changes. Make sure to include additional tests if necessary.
Now make your changes. Make sure to include additional tests if necessary. Be sure to consider the abstract/final design pattern: https://docs.kidger.site/equinox/pattern/.

If you include a new features, there are 3 required classes of tests:
- Correctness: tests the are against analytic or known solutions that ensure the computation is correct
Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ from distreqx import
- No official support/interoperability with TFP
- The concept of a batch dimension is dropped. If you want to operate on a batch, use `vmap` (note, this can be used in construction as well, e.g. [vmaping the construction](https://docs.kidger.site/equinox/tricks/#ensembling) of a `ScalarAffine`)
- Broader pytree enablement
- Strict [abstract/final](https://docs.kidger.site/equinox/pattern/) design pattern

## Citation

Expand Down
4 changes: 2 additions & 2 deletions distreqx/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Tuple, Union
from typing import Union

import jax
from jaxtyping import (
PyTree,
)


EventT = Union[Tuple[int], PyTree[jax.ShapeDtypeStruct]]
EventT = Union[tuple[int], PyTree[jax.ShapeDtypeStruct]]
7 changes: 6 additions & 1 deletion distreqx/bijectors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ._bijector import AbstractBijector as AbstractBijector
from ._bijector import (
AbstractBijector as AbstractBijector,
AbstractFowardInverseBijector as AbstractFowardInverseBijector,
AbstractFwdLogDetJacBijector as AbstractFwdLogDetJacBijector,
AbstractInvLogDetJacBijector as AbstractInvLogDetJacBijector,
)
from ._linear import AbstractLinearBijector as AbstractLinearBijector
from .block import Block as Block
from .chain import Chain as Chain
Expand Down
94 changes: 47 additions & 47 deletions distreqx/bijectors/_bijector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from abc import abstractmethod
from typing import Optional, Tuple

import equinox as eqx
from jaxtyping import Array, PyTree
Expand Down Expand Up @@ -27,69 +26,38 @@ class AbstractBijector(eqx.Module, strict=True):
will assume these properties hold, and will make no attempt to verify them.
"""

_is_constant_jacobian: bool
_is_constant_log_det: bool

def __init__(
self,
is_constant_jacobian: bool = False,
is_constant_log_det: Optional[bool] = None,
):
"""Initializes a Bijector.
**Arguments:**
- `is_constant_jacobian`: Whether the Jacobian is promised to be constant
(which is the case if and only if the bijector is affine). A value of
False will be interpreted as "we don't know whether the Jacobian is
constant", rather than "the Jacobian is definitely not constant". Only
set to True if you're absolutely sure the Jacobian is constant; if
you're not sure, set to False.
- `is_constant_log_det`: Whether the Jacobian determinant is promised to be
constant (which is the case for, e.g., volume-preserving bijectors). If
None, it defaults to `is_constant_jacobian`. Note that the Jacobian
determinant can be constant without the Jacobian itself being constant.
Only set to True if you're absoltely sure the Jacobian determinant is
constant; if you're not sure, set to False.
"""
if is_constant_log_det is None:
is_constant_log_det = is_constant_jacobian
if is_constant_jacobian and not is_constant_log_det:
raise ValueError(
"The Jacobian is said to be constant, but its "
"determinant is said not to be, which is impossible."
)
self._is_constant_jacobian = is_constant_jacobian
self._is_constant_log_det = is_constant_log_det
_is_constant_jacobian: eqx.AbstractVar[bool]
_is_constant_log_det: eqx.AbstractVar[bool]

@abstractmethod
def forward(self, x: PyTree) -> PyTree:
R"""Computes $y = f(x)$."""
y, _ = self.forward_and_log_det(x)
return y
raise NotImplementedError

@abstractmethod
def inverse(self, y: PyTree) -> PyTree:
r"""Computes $x = f^{-1}(y)$."""
x, _ = self.inverse_and_log_det(y)
return x
raise NotImplementedError

@abstractmethod
def forward_log_det_jacobian(self, x: PyTree) -> PyTree:
r"""Computes $\log|\det J(f)(x)|$."""
_, logdet = self.forward_and_log_det(x)
return logdet
raise NotImplementedError

@abstractmethod
def inverse_log_det_jacobian(self, y: PyTree) -> PyTree:
r"""Computes $\log|\det J(f^{-1})(y)|$."""
_, logdet = self.inverse_and_log_det(y)
return logdet
raise NotImplementedError

@abstractmethod
def forward_and_log_det(self, x: PyTree) -> Tuple[PyTree, PyTree]:
def forward_and_log_det(self, x: PyTree) -> tuple[PyTree, PyTree]:
r"""Computes $y = f(x)$ and $\log|\det J(f)(x)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `forward_and_log_det`."
)

def inverse_and_log_det(self, y: Array) -> Tuple[PyTree, PyTree]:
@abstractmethod
def inverse_and_log_det(self, y: Array) -> tuple[PyTree, PyTree]:
r"""Computes $x = f^{-1}(y)$ and $\log|\det J(f^{-1})(y)|$."""
raise NotImplementedError(
f"Bijector {self.name} does not implement `inverse_and_log_det`."
Expand All @@ -110,7 +78,39 @@ def name(self) -> str:
"""Name of the bijector."""
return self.__class__.__name__

@abstractmethod
def same_as(self, other) -> bool:
"""Returns True if this bijector is guaranteed to be the same as `other`."""
del other
return False
raise NotImplementedError


class AbstractInvLogDetJacBijector(AbstractBijector, strict=True):
"""AbstractBijector + concrete `inverse_log_det_jacobian`."""

def inverse_log_det_jacobian(self, y: PyTree) -> PyTree:
r"""Computes $\log|\det J(f^{-1})(y)|$."""
_, logdet = self.inverse_and_log_det(y)
return logdet


class AbstractFwdLogDetJacBijector(AbstractBijector, strict=True):
"""AbstractBijector + concrete `forward_log_det_jacobian`."""

def forward_log_det_jacobian(self, x: PyTree) -> PyTree:
r"""Computes $\log|\det J(f)(x)|$."""
_, logdet = self.forward_and_log_det(x)
return logdet


class AbstractFowardInverseBijector(AbstractBijector, strict=True):
"""AbstractBijector + concrete `forward` and `reverse`."""

def forward(self, x: PyTree) -> PyTree:
R"""Computes $y = f(x)$."""
y, _ = self.forward_and_log_det(x)
return y

def inverse(self, y: PyTree) -> PyTree:
r"""Computes $x = f^{-1}(y)$."""
x, _ = self.inverse_and_log_det(y)
return x
15 changes: 3 additions & 12 deletions distreqx/bijectors/_linear.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,19 @@
"""Linear bijector."""

import equinox as eqx
from jaxtyping import Array

from ._bijector import AbstractBijector


class AbstractLinearBijector(AbstractBijector):
class AbstractLinearBijector(AbstractBijector, strict=True):
"""Base class for linear bijectors.
This class provides a base class for bijectors defined as `f(x) = Ax`,
where `A` is a `DxD` matrix and `x` is a `D`-dimensional vector.
"""

_event_dims: int

def __init__(self, event_dims: int):
"""Initializes a `Linear` bijector.
**Arguments:**
- `event_dims`: the dimensionality of the vector `D`
"""
super().__init__(is_constant_jacobian=True)
self._event_dims = event_dims
_event_dims: eqx.AbstractVar[int]

@property
def matrix(self) -> Array:
Expand Down
17 changes: 7 additions & 10 deletions distreqx/bijectors/block.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
"""Wrapper to turn independent Bijectors into block Bijectors."""

from typing import Tuple

from jaxtyping import Array

from ..utils import sum_last
from ._bijector import AbstractBijector


class Block(AbstractBijector):
class Block(AbstractBijector, strict=True):
"""A wrapper that promotes a bijector to a block bijector.
A block bijector applies a bijector to a k-dimensional array of events, but
Expand All @@ -33,6 +30,8 @@ class Block(AbstractBijector):

_ndims: int
_bijector: AbstractBijector
_is_constant_jacobian: bool
_is_constant_log_det: bool

def __init__(self, bijector: AbstractBijector, ndims: int):
"""Initializes a Block.
Expand All @@ -47,10 +46,8 @@ def __init__(self, bijector: AbstractBijector, ndims: int):
raise ValueError(f"`ndims` must be non-negative; got {ndims}.")
self._bijector = bijector
self._ndims = ndims
super().__init__(
is_constant_jacobian=self._bijector.is_constant_jacobian,
is_constant_log_det=self._bijector.is_constant_log_det,
)
self._is_constant_jacobian = self._bijector.is_constant_jacobian
self._is_constant_log_det = self._bijector.is_constant_log_det

@property
def bijector(self) -> AbstractBijector:
Expand Down Expand Up @@ -80,12 +77,12 @@ def inverse_log_det_jacobian(self, y: Array) -> Array:
log_det = self._bijector.inverse_log_det_jacobian(y)
return sum_last(log_det, self._ndims)

def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
def forward_and_log_det(self, x: Array) -> tuple[Array, Array]:
"""Computes y = f(x) and log|det J(f)(x)|."""
y, log_det = self._bijector.forward_and_log_det(x)
return y, sum_last(log_det, self._ndims)

def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]:
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
x, log_det = self._bijector.inverse_and_log_det(y)
return x, sum_last(log_det, self._ndims)
Expand Down
33 changes: 22 additions & 11 deletions distreqx/bijectors/chain.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
"""Chain Bijector for composing a sequence of Bijector transformations."""

from typing import List, Sequence, Tuple
from typing import Sequence

from jaxtyping import Array

from ._bijector import AbstractBijector
from ._bijector import (
AbstractBijector,
AbstractFwdLogDetJacBijector,
AbstractInvLogDetJacBijector,
)


class Chain(AbstractBijector):
class Chain(AbstractFwdLogDetJacBijector, AbstractInvLogDetJacBijector, strict=True):
"""Composition of a sequence of bijectors into a single bijector.
Bijectors are composable: if `f` and `g` are bijectors, then `g o f` is also
Expand All @@ -30,7 +34,9 @@ class Chain(AbstractBijector):
`y = f(g(x))`.
"""

_bijectors: List[AbstractBijector]
_bijectors: list[AbstractBijector]
_is_constant_jacobian: bool
_is_constant_log_det: bool

def __init__(self, bijectors: Sequence[AbstractBijector]):
"""Initializes a Chain bijector.
Expand All @@ -47,13 +53,18 @@ def __init__(self, bijectors: Sequence[AbstractBijector]):

is_constant_jacobian = all(b.is_constant_jacobian for b in self._bijectors)
is_constant_log_det = all(b.is_constant_log_det for b in self._bijectors)
super().__init__(
is_constant_jacobian=is_constant_jacobian,
is_constant_log_det=is_constant_log_det,
)
if is_constant_log_det is None:
is_constant_log_det = is_constant_jacobian
if is_constant_jacobian and not is_constant_log_det:
raise ValueError(
"The Jacobian is said to be constant, but its "
"determinant is said not to be, which is impossible."
)
self._is_constant_jacobian = is_constant_jacobian
self._is_constant_log_det = is_constant_log_det

@property
def bijectors(self) -> List[AbstractBijector]:
def bijectors(self) -> list[AbstractBijector]:
"""The list of bijectors in the chain."""
return self._bijectors

Expand All @@ -69,15 +80,15 @@ def inverse(self, y: Array) -> Array:
y = bijector.inverse(y)
return y

def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
def forward_and_log_det(self, x: Array) -> tuple[Array, Array]:
"""Computes y = f(x) and log|det J(f)(x)|."""
x, log_det = self._bijectors[-1].forward_and_log_det(x)
for bijector in reversed(self._bijectors[:-1]):
x, ld = bijector.forward_and_log_det(x)
log_det += ld
return x, log_det

def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]:
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
y, log_det = self._bijectors[0].inverse_and_log_det(y)
for bijector in self._bijectors[1:]:
Expand Down
15 changes: 9 additions & 6 deletions distreqx/bijectors/diag_linear.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
"""Diagonal linear bijector."""

from typing import Tuple

import jax.numpy as jnp
from jaxtyping import Array

Expand All @@ -11,7 +9,7 @@
from .scalar_affine import ScalarAffine


class DiagLinear(AbstractLinearBijector):
class DiagLinear(AbstractLinearBijector, strict=True):
"""Linear bijector with a diagonal weight matrix.
The bijector is defined as `f(x) = Ax` where `A` is a `DxD` diagonal matrix.
Expand All @@ -29,6 +27,9 @@ class DiagLinear(AbstractLinearBijector):

_diag: Array
_bijector: AbstractBijector
_is_constant_jacobian: bool
_is_constant_log_det: bool
_event_dims: int

def __init__(self, diag: Array):
"""Initializes the bijector.
Expand All @@ -42,8 +43,10 @@ def __init__(self, diag: Array):
self._bijector = Block(
ScalarAffine(shift=jnp.zeros_like(diag), scale=diag), ndims=diag.ndim
)
super().__init__(event_dims=diag.shape[-1])
self._event_dims = diag.shape[-1]
self._diag = diag
self._is_constant_jacobian = True
self._is_constant_log_det = True

def forward(self, x: Array) -> Array:
"""Computes y = f(x)."""
Expand All @@ -61,11 +64,11 @@ def inverse_log_det_jacobian(self, y: Array) -> Array:
"""Computes log|det J(f^{-1})(y)|."""
return self._bijector.inverse_log_det_jacobian(y)

def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
def inverse_and_log_det(self, y: Array) -> tuple[Array, Array]:
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
return self._bijector.inverse_and_log_det(y)

def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
def forward_and_log_det(self, x: Array) -> tuple[Array, Array]:
"""Computes y = f(x) and log|det J(f)(x)|."""
return self._bijector.forward_and_log_det(x)

Expand Down
Loading

0 comments on commit 77b8ee2

Please sign in to comment.