Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pr branch #331

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .idea/.gitignore

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions .idea/diffrax_STLA.iml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/inspectionProfiles/profiles_settings.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 7 additions & 0 deletions .idea/misc.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 8 additions & 0 deletions .idea/modules.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions .idea/vcs.xml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 12 additions & 3 deletions diffrax/brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,21 @@
import abc
from typing import Optional, Union

from ..custom_types import Array, PyTree, Scalar
from ..custom_types import Array, LevyVal, PyTree, Scalar
from ..path import AbstractPath


class AbstractBrownianPath(AbstractPath):
"Abstract base class for all Brownian paths."
"""Abstract base class for all Brownian paths."""

@abc.abstractmethod
def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
def evaluate(
self,
t0: Scalar,
t1: Optional[Scalar] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
Expand All @@ -20,6 +27,8 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
- `left`: Ignored. (This determines whether to treat the path as
left-continuous or right-continuous at any jump points, but Brownian
motion has no jump points.)
- `use_levy`: If True, the return type will be a `LevyVal`, which contains
PyTrees of Brownian increments and their Levy areas.

**Returns:**

Expand Down
65 changes: 56 additions & 9 deletions diffrax/brownian/path.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Tuple, Union
from typing import Literal, Tuple, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -7,7 +7,7 @@
import jax.random as jrandom
import jax.tree_util as jtu

from ..custom_types import Array, PyTree, Scalar
from ..custom_types import Array, levy_tree_transpose, LevyVal, PyTree, Scalar
from ..misc import force_bitcast_convert_type, is_tuple_of_ints, split_by_tree
from .base import AbstractBrownianPath

Expand All @@ -30,9 +30,14 @@ class UnsafeBrownianPath(AbstractBrownianPath):
interval, ignoring the correlation between samples exhibited in true Brownian
motion. Hence the restrictions above. (They describe the general case for which the
correlation structure isn't needed.)

Depending on the `levy_area` argument, this can also be used to generate Levy area.
`levy_area` can be "" or "space-time".

"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: Literal["", "space-time"] = eqx.field(static=True)
# Handled as a string because PRNGKey is actually a function, not a class, which
# makes it appearly badly in autogenerated documentation.
key: "jax.random.PRNGKey" # noqa: F821
Expand All @@ -41,13 +46,20 @@ def __init__(
self,
shape: Union[Tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: "jax.random.PRNGKey",
levy_area: Literal["", "space-time"] = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, jax.dtypes.canonicalize_dtype(None))
if is_tuple_of_ints(shape)
else shape
)
self.key = key
if levy_area not in ["", "space-time"]:
raise ValueError(
andyElking marked this conversation as resolved.
Show resolved Hide resolved
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area

if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
for x in jtu.tree_leaves(self.shape)
Expand All @@ -63,7 +75,13 @@ def t1(self):
return None

@eqx.filter_jit
def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
def evaluate(
self,
t0: Scalar,
t1: Scalar,
left: bool = True,
use_levy: bool = False,
) -> PyTree[Array]:
del left
t0 = eqxi.nondifferentiable(t0, name="t0")
t1 = eqxi.nondifferentiable(t1, name="t1")
Expand All @@ -72,14 +90,42 @@ def evaluate(self, t0: Scalar, t1: Scalar, left: bool = True) -> PyTree[Array]:
key = jrandom.fold_in(self.key, t0_)
key = jrandom.fold_in(key, t1_)
key = split_by_tree(key, self.shape)
return jtu.tree_map(
lambda key, shape: self._evaluate_leaf(t0, t1, key, shape), key, self.shape
out = jtu.tree_map(
lambda key, shape: self._evaluate_leaf(
t0, t1, key, shape, self.levy_area, use_levy
),
key,
self.shape,
)
if use_levy:
out = levy_tree_transpose(self.shape, self.levy_area, out)
assert isinstance(out, LevyVal)
return out

@staticmethod
def _evaluate_leaf(
t0: Scalar,
t1: Scalar,
key,
shape: jax.ShapeDtypeStruct,
levy_area: str,
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)

def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruct):
return jrandom.normal(key, shape.shape, shape.dtype) * jnp.sqrt(t1 - t0).astype(
shape.dtype
)
if levy_area == "space-time":
key_w, key_hh = jrandom.split(key, 2)
w = jrandom.normal(key_w, shape.shape, shape.dtype) * w_std
hh_std = jnp.sqrt((t1 - t0) / 12).astype(shape.dtype)
hh = jrandom.normal(key_hh, shape.shape, shape.dtype) * hh_std
else:
hh = None
w = jrandom.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, bar_H=None, K=None, bar_K=None)
else:
return w


UnsafeBrownianPath.__init__.__doc__ = """
Expand All @@ -89,5 +135,6 @@ def _evaluate_leaf(self, t0: Scalar, t1: Scalar, key, shape: jax.ShapeDtypeStruc
dtype, and PyTree structure of the output. For simplicity, `shape` can also just
be a tuple of integers, describing the shape of a single JAX array. In that case
the dtype is chosen to be the default floating-point dtype.

- `key`: A random key.
"""
Loading
Loading