Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parametric control types #364

Merged
merged 34 commits into from
Feb 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
87b5201
Added parametric control types
tttc3 Jan 30, 2024
c05054c
Demo of AbstractTerm change
tttc3 Jan 30, 2024
61e97db
Correct initial parametric control type implementation.
tttc3 Jan 30, 2024
8a504b8
Parametric AbstractTerm initial implementation.
tttc3 Feb 1, 2024
1e20f33
Update tests and fix hinting
tttc3 Feb 3, 2024
cd23a0b
Implement review comments.
tttc3 Feb 6, 2024
6dbda14
Add parametric control check to integrator.
tttc3 Feb 9, 2024
813e85d
Update and test parametric control check
tttc3 Feb 9, 2024
267a41a
Introduce new LevyArea types
tttc3 Feb 11, 2024
9577d43
Updated Brownian path LevyArea types
tttc3 Feb 11, 2024
2e4fc83
Replace Union types in isinstance checks
tttc3 Feb 12, 2024
a744b9d
Remove rogue comment
tttc3 Feb 12, 2024
589f26a
Revert _brownian_arch to single assignment
tttc3 Feb 12, 2024
9e3021b
Revert _evaluate_leaf key splitting
tttc3 Feb 12, 2024
f1aceea
Rename variables in test_term
tttc3 Feb 12, 2024
d9e874a
Update isinstance and issubclass checks
tttc3 Feb 12, 2024
2661b35
Safer handling in _denormalise_bm_inc
tttc3 Feb 12, 2024
fb572d6
Fix style in integrate control type check
tttc3 Feb 13, 2024
e5f15e6
Add draft vector_field typing
tttc3 Feb 16, 2024
787ceee
Add draft vector_field typing
tttc3 Feb 16, 2024
3249ff9
Fix term test
tttc3 Feb 16, 2024
3c834f1
Revert extemporaneous modifications in _tree
tttc3 Feb 16, 2024
f607ebf
Rename TimeLevyArea to BrownianIncrement and simplify diff
tttc3 Feb 16, 2024
80fd358
Rename AbstractLevyReturn to AbstractBrownianReturn
tttc3 Feb 16, 2024
796b40f
Rename _LevyArea to _BrownianReturn
tttc3 Feb 16, 2024
ed073d3
Enhance _term_compatiblity checks
tttc3 Feb 17, 2024
3fd7f3f
Merge branch 'dev' into parametric-types-variant
tttc3 Feb 17, 2024
50f3363
Fix merge issues
tttc3 Feb 18, 2024
5ef24fa
Bump pre-commit and fix type hints
tttc3 Feb 18, 2024
b9fc23f
Clean up from self-review
tttc3 Feb 18, 2024
06fd6a9
Explicitly add typeguard to deps
tttc3 Feb 18, 2024
b0b18e4
Bump ruff config to new syntax
tttc3 Feb 19, 2024
3791d6d
Parameterised terms: fixed term compatibility + spurious pyright errors
patrick-kidger Feb 19, 2024
c08202c
Merge pull request #1 from patrick-kidger/parametric-tweaks
tttc3 Feb 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
6 changes: 5 additions & 1 deletion diffrax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,11 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._custom_types import (
AbstractBrownianReturn as AbstractBrownianReturn,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
Expand Down
7 changes: 5 additions & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
Expand All @@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

Expand Down
13 changes: 8 additions & 5 deletions diffrax/_brownian/base.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import abc
from typing import Optional, Union
from typing import Optional, TypeVar, Union

from equinox.internal import AbstractVar
from jaxtyping import Array, PyTree

from .._custom_types import LevyArea, LevyVal, RealScalarLike
from .._custom_types import AbstractBrownianReturn, RealScalarLike
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianReturn])


class AbstractBrownianPath(AbstractPath[_Control]):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
levy_area: AbstractVar[type[AbstractBrownianReturn]]

@abc.abstractmethod
def evaluate(
Expand All @@ -20,7 +23,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> _Control:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
Expand Down
44 changes: 24 additions & 20 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,12 @@
import lineax.internal as lxi
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
from .._custom_types import (
BrownianIncrement,
levy_tree_transpose,
RealScalarLike,
SpaceTimeLevyArea,
)
from .._misc import (
force_bitcast_convert_type,
is_tuple_of_ints,
Expand Down Expand Up @@ -42,26 +47,24 @@ class UnsafeBrownianPath(AbstractBrownianPath):
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: LevyArea = eqx.field(static=True)
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]] = eqx.field(
static=True
)
key: PRNGKeyArray

def __init__(
self,
shape: Union[tuple[int, ...], PyTree[jax.ShapeDtypeStruct]],
key: PRNGKeyArray,
levy_area: LevyArea = "",
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
):
self.shape = (
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
if is_tuple_of_ints(shape)
else shape
)
self.key = key
if levy_area not in ["", "space-time"]:
raise ValueError(
f"levy_area must be one of '', 'space-time', but got {levy_area}."
)
self.levy_area = levy_area
self.levy_area = levy_area # pyright: ignore[reportIncompatibleVariableOverride]

if any(
not jnp.issubdtype(x.dtype, jnp.inexact)
Expand All @@ -70,11 +73,11 @@ def __init__(
raise ValueError("UnsafeBrownianPath dtypes all have to be floating-point.")

@property
def t0(self):
def t0(self): # pyright: ignore[reportIncompatibleVariableOverride]
return -jnp.inf

@property
def t1(self):
def t1(self): # pyright: ignore[reportIncompatibleVariableOverride]
return jnp.inf

@eqx.filter_jit
Expand All @@ -84,7 +87,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> Union[PyTree[Array], BrownianIncrement, SpaceTimeLevyArea]:
del left
if t1 is None:
dtype = jnp.result_type(t0)
Expand Down Expand Up @@ -112,7 +115,7 @@ def evaluate(
)
if use_levy:
out = levy_tree_transpose(self.shape, out)
assert isinstance(out, LevyVal)
assert isinstance(out, (BrownianIncrement, SpaceTimeLevyArea))
return out

@staticmethod
Expand All @@ -121,25 +124,26 @@ def _evaluate_leaf(
t1: RealScalarLike,
key,
shape: jax.ShapeDtypeStruct,
levy_area: str,
levy_area: type[Union[BrownianIncrement, SpaceTimeLevyArea]],
use_levy: bool,
):
w_std = jnp.sqrt(t1 - t0).astype(shape.dtype)
w = jr.normal(key, shape.shape, shape.dtype) * w_std
dt = t1 - t0

if levy_area == "space-time":
if levy_area is SpaceTimeLevyArea:
key, key_hh = jr.split(key, 2)
hh_std = w_std / math.sqrt(12)
hh = jr.normal(key_hh, shape.shape, shape.dtype) * hh_std
elif levy_area == "":
hh = None
levy_val = SpaceTimeLevyArea(dt=dt, W=w, H=hh, K=None)
elif levy_area is BrownianIncrement:
levy_val = BrownianIncrement(dt=dt, W=w)
else:
assert False
w = jr.normal(key, shape.shape, shape.dtype) * w_std

if use_levy:
return LevyVal(dt=t1 - t0, W=w, H=hh, K=None)
else:
return w
return levy_val
return w


UnsafeBrownianPath.__init__.__doc__ = """
Expand Down
Loading
Loading