v0.5.1 #367

Merged: 15 commits, May 19, 2024
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -1,14 +1,14 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.7
rev: v0.2.2
hooks:
- id: ruff # linter
types_or: [ python, pyi, jupyter ]
args: [ --fix ]
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.316
rev: v1.1.350
hooks:
- id: pyright
additional_dependencies: [equinox, jax, jaxtyping, optax, optimistix, lineax, pytest, typing_extensions]
11 changes: 6 additions & 5 deletions benchmarks/brownian_tree_times.py
@@ -10,6 +10,7 @@
from typing import cast, Optional, Union
from typing_extensions import TypeAlias

import diffrax
import equinox as eqx
import equinox.internal as eqxi
import jax
@@ -50,9 +51,9 @@ def __init__(
tol: RealScalarLike,
shape: tuple[int, ...],
key: PRNGKeyArray,
levy_area: str,
levy_area: type[diffrax.AbstractBrownianIncrement] = diffrax.BrownianIncrement,
):
assert levy_area == ""
assert levy_area == diffrax.BrownianIncrement
self.t0 = t0
self.t1 = t1
self.tol = tol
@@ -187,13 +188,13 @@ def run(_ts):
)


for levy_area in ("", "space-time"):
for levy_area in (diffrax.BrownianIncrement, diffrax.SpaceTimeLevyArea):
print(f"- {levy_area=}")
for tol in (2**-3, 2**-12):
print(f"-- {tol=}")
for num_ts in (1, 100):
for num_ts in (1, 10000):
print(f"--- {num_ts=}")
if levy_area == "":
if levy_area == diffrax.BrownianIncrement:
print(f"Old: {time_tree(OldVBT, num_ts, tol, levy_area):.5f}")
print(f"new: {time_tree(VirtualBrownianTree, num_ts, tol, levy_area):.5f}")
print("")
12 changes: 6 additions & 6 deletions benchmarks/small_neural_ode.py
@@ -20,7 +20,7 @@
class FuncTorch(torch.nn.Module):
def __init__(self):
super().__init__()
self.func = torch.jit.script( # pyright: ignore
self.func = torch.jit.script(
torch.nn.Sequential(
torch.nn.Linear(4, 32),
torch.nn.Softplus(),
@@ -30,7 +30,7 @@ def __init__(self):
)

def forward(self, t, y):
return self.func(y) # pyright: ignore
return self.func(y)


class FuncJax(eqx.Module):
@@ -177,10 +177,10 @@ def run(multiple, grad, batch_size=64, t1=100):
with torch.no_grad():
func_jax = neural_ode_diffrax.func.func
func_torch = neural_ode_torch.func.func
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight))) # pyright: ignore
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias))) # pyright: ignore
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight))) # pyright: ignore
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias))) # pyright: ignore
func_torch[0].weight.copy_(torch.tensor(np.asarray(func_jax.layers[0].weight)))
func_torch[0].bias.copy_(torch.tensor(np.asarray(func_jax.layers[0].bias)))
func_torch[2].weight.copy_(torch.tensor(np.asarray(func_jax.layers[1].weight)))
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

y0_jax = jr.normal(jr.PRNGKey(1), (batch_size, 4))
y0_torch = torch.tensor(np.asarray(y0_jax))
23 changes: 22 additions & 1 deletion diffrax/__init__.py
@@ -13,7 +13,14 @@
UnsafeBrownianPath as UnsafeBrownianPath,
VirtualBrownianTree as VirtualBrownianTree,
)
from ._custom_types import LevyVal as LevyVal
from ._custom_types import (
AbstractBrownianIncrement as AbstractBrownianIncrement,
AbstractSpaceTimeLevyArea as AbstractSpaceTimeLevyArea,
AbstractSpaceTimeTimeLevyArea as AbstractSpaceTimeTimeLevyArea,
BrownianIncrement as BrownianIncrement,
SpaceTimeLevyArea as SpaceTimeLevyArea,
SpaceTimeTimeLevyArea as SpaceTimeTimeLevyArea,
)
from ._event import (
AbstractDiscreteTerminatingEvent as AbstractDiscreteTerminatingEvent,
DiscreteTerminatingEvent as DiscreteTerminatingEvent,
@@ -37,6 +44,12 @@
)
from ._misc import adjoint_rms_seminorm as adjoint_rms_seminorm
from ._path import AbstractPath as AbstractPath
from ._progress_meter import (
AbstractProgressMeter as AbstractProgressMeter,
NoProgressMeter as NoProgressMeter,
TextProgressMeter as TextProgressMeter,
TqdmProgressMeter as TqdmProgressMeter,
)
from ._root_finder import (
VeryChord as VeryChord,
with_stepsize_controller_tols as with_stepsize_controller_tols,
@@ -59,6 +72,7 @@
AbstractRungeKutta as AbstractRungeKutta,
AbstractSDIRK as AbstractSDIRK,
AbstractSolver as AbstractSolver,
AbstractSRK as AbstractSRK,
AbstractStratonovichSolver as AbstractStratonovichSolver,
AbstractWrappedSolver as AbstractWrappedSolver,
Bosh3 as Bosh3,
@@ -68,6 +82,7 @@
Dopri8 as Dopri8,
Euler as Euler,
EulerHeun as EulerHeun,
GeneralShARK as GeneralShARK,
HalfSolver as HalfSolver,
Heun as Heun,
ImplicitEuler as ImplicitEuler,
@@ -83,8 +98,14 @@
MultiButcherTableau as MultiButcherTableau,
Ralston as Ralston,
ReversibleHeun as ReversibleHeun,
SEA as SEA,
SemiImplicitEuler as SemiImplicitEuler,
ShARK as ShARK,
Sil3 as Sil3,
SlowRK as SlowRK,
SPaRK as SPaRK,
SRA1 as SRA1,
StochasticButcherTableau as StochasticButcherTableau,
StratonovichMilstein as StratonovichMilstein,
Tsit5 as Tsit5,
)
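
The new public names fall into three groups: the Brownian increment/Lévy-area classes, the progress meters, and the stochastic Runge-Kutta (SRK) solvers (ShARK, SRA1, GeneralShARK, SlowRK, SPaRK, SEA). A hedged sketch of how they compose, assuming `diffeqsolve` accepts a `progress_meter` keyword in this release and that ShARK takes the usual `MultiTerm(ODETerm, ControlTerm)` SDE term structure with a space-time-Lévy-area Brownian path:

```python
import diffrax
import jax.random as jr

# Additive-noise scalar SDE: dY = -Y dt + 0.1 dW on [0, 1].
drift = diffrax.ODETerm(lambda t, y, args: -y)
bm = diffrax.VirtualBrownianTree(
    0.0, 1.0, tol=1e-3, shape=(), key=jr.PRNGKey(0),
    levy_area=diffrax.SpaceTimeLevyArea,  # required by the SRK solvers
)
diffusion = diffrax.ControlTerm(lambda t, y, args: 0.1, bm)
terms = diffrax.MultiTerm(drift, diffusion)

sol = diffrax.diffeqsolve(
    terms,
    diffrax.ShARK(),  # one of the newly exported SRK solvers
    t0=0.0,
    t1=1.0,
    dt0=0.01,
    y0=1.0,
    saveat=diffrax.SaveAt(t1=True),
    progress_meter=diffrax.TqdmProgressMeter(),  # assumption: diffeqsolve takes this kwarg
)
```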
11 changes: 8 additions & 3 deletions diffrax/_adjoint.py
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
@@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

@@ -128,6 +131,7 @@ def loop(
init_state,
passed_solver_state,
passed_controller_state,
progress_meter,
) -> Any:
"""Runs the main solve loop. Subclasses can override this to provide custom
backpropagation behaviour; see for example the implementation of
@@ -559,6 +563,7 @@ def _loop_backsolve_bwd(
max_steps,
throw,
init_state,
progress_meter,
):
assert discrete_terminating_event is None

@@ -567,7 +572,7 @@
# using them later.
#

del perturbed, init_state, t1
del perturbed, init_state, t1, progress_meter
ts, ys = residuals
del residuals
grad_final_state, _ = grad_final_state__aux_stats
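
The `ω = cast(Callable, ω)` line added above is a typing workaround for the newer pyright: it re-annotates equinox's ω helper as a plain callable and is a no-op at runtime. A small standalone sketch of the pattern (the PyTrees below are made up for illustration):

```python
from collections.abc import Callable
from typing import cast

import jax.numpy as jnp
from equinox.internal import ω

# No-op at runtime; only loosens the static type so pyright accepts ω(...) calls.
ω = cast(Callable, ω)

x = {"a": jnp.ones(3), "b": jnp.zeros(2)}
y = {"a": jnp.full(3, 2.0), "b": jnp.ones(2)}

# ω wraps a PyTree so arithmetic applies leaf-by-leaf; .ω unwraps it again.
z = (ω(x) + ω(y)).ω
```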
25 changes: 24 additions & 1 deletion diffrax/_autocitation.py
@@ -16,14 +16,23 @@
from ._saveat import SubSaveAt
from ._solver import (
AbstractImplicitSolver,
AbstractItoSolver,
AbstractSRK,
AbstractStratonovichSolver,
Dopri5,
Dopri8,
GeneralShARK,
Kvaerno3,
Kvaerno4,
Kvaerno5,
LeapfrogMidpoint,
ReversibleHeun,
SEA,
SemiImplicitEuler,
ShARK,
SlowRK,
SPaRK,
SRA1,
Tsit5,
)
from ._step_size_controller import PIDController
@@ -374,7 +383,15 @@ def _backsolve_rms_norm(adjoint):

@citation_rules.append
def _explicit_solver(solver, terms=None):
if not isinstance(solver, AbstractImplicitSolver) and not is_sde(terms):
if not isinstance(
solver,
(
AbstractImplicitSolver,
AbstractSRK,
AbstractItoSolver,
AbstractStratonovichSolver,
),
) and not is_sde(terms):
return r"""
% You are using an explicit solver, and may wish to cite the standard textbook:
@book{hairer2008solving-i,
@@ -467,6 +484,12 @@ def _solvers(solver, saveat=None):
Kvaerno5,
ReversibleHeun,
LeapfrogMidpoint,
ShARK,
SRA1,
SlowRK,
GeneralShARK,
SPaRK,
SEA,
):
return (
r"""
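
These additions teach the autocitation machinery about the SRK solvers: they are excluded from the generic explicit-solver textbook citation and picked up by the per-solver rule instead. As a usage sketch (assumption: `diffrax.citation` mirrors `diffeqsolve`'s arguments, as in earlier releases):

```python
import diffrax

# Print suggested BibTeX entries for the pieces used in a solve.
term = diffrax.ODETerm(lambda t, y, args: -y)
diffrax.citation(term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=1.0)
```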
18 changes: 13 additions & 5 deletions diffrax/_brownian/base.py
@@ -1,17 +1,25 @@
import abc
from typing import Optional, Union
from typing import Optional, TypeVar, Union

from equinox.internal import AbstractVar
from jaxtyping import Array, PyTree

from .._custom_types import LevyArea, LevyVal, RealScalarLike
from .._custom_types import (
AbstractBrownianIncrement,
BrownianIncrement,
RealScalarLike,
SpaceTimeLevyArea,
)
from .._path import AbstractPath


class AbstractBrownianPath(AbstractPath):
_Control = TypeVar("_Control", bound=Union[PyTree[Array], AbstractBrownianIncrement])


class AbstractBrownianPath(AbstractPath[_Control]):
"""Abstract base class for all Brownian paths."""

levy_area: AbstractVar[LevyArea]
levy_area: AbstractVar[type[Union[BrownianIncrement, SpaceTimeLevyArea]]]

@abc.abstractmethod
def evaluate(
@@ -20,7 +28,7 @@ def evaluate(
t1: Optional[RealScalarLike] = None,
left: bool = True,
use_levy: bool = False,
) -> Union[PyTree[Array], LevyVal]:
) -> _Control:
r"""Samples a Brownian increment $w(t_1) - w(t_0)$.

Each increment has distribution $\mathcal{N}(0, t_1 - t_0)$.
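
With the generic `_Control` parameter, `evaluate`'s return type now tracks the path's `levy_area`: an array (or PyTree of arrays) by default, and a structured increment when `use_levy=True`. A sketch, where the `.W`/`.H` field names are an assumption about the Lévy-area classes:

```python
import diffrax
import jax.random as jr

bm = diffrax.VirtualBrownianTree(
    0.0, 1.0, tol=2**-6, shape=(), key=jr.PRNGKey(0),
    levy_area=diffrax.SpaceTimeLevyArea,
)

# Default: just the Brownian increment w(t1) - w(t0), as an array.
dw = bm.evaluate(0.2, 0.7)

# use_levy=True: a SpaceTimeLevyArea carrying the increment and its space-time Levy area.
levy = bm.evaluate(0.2, 0.7, use_levy=True)
w, h = levy.W, levy.H  # assumption: field names W and H
```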