diff --git a/diffrax/__init__.py b/diffrax/__init__.py index 5fe2ae59..71d6b503 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -126,7 +126,7 @@ from ._term import ( AbstractTerm as AbstractTerm, ControlTerm as ControlTerm, - LangevinTerm as LangevinTerm, + make_langevin_term as make_langevin_term, MultiTerm as MultiTerm, ODETerm as ODETerm, WeaklyDiagonalControlTerm as WeaklyDiagonalControlTerm, diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index a904f13d..615eafe9 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -63,7 +63,7 @@ ConstantStepSize, StepTo, ) -from ._term import AbstractTerm, LangevinTerm, MultiTerm, ODETerm, WrapTerm +from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm from ._typing import better_isinstance, get_args_of, get_origin_no_specials @@ -162,15 +162,12 @@ def _check(term_cls, term, term_contr_kwargs, yi): pass elif n_term_args == 2: vf_type_expected, control_type_expected = term_args - if not isinstance(term, LangevinTerm): - # TODO: The line below causes problems with LangevinTerm - # Please help me fix this - vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args) - vf_type_compatible = eqx.filter_eval_shape( - better_isinstance, vf_type, vf_type_expected - ) - if not vf_type_compatible: - raise ValueError + vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args) + vf_type_compatible = eqx.filter_eval_shape( + better_isinstance, vf_type, vf_type_expected + ) + if not vf_type_compatible: + raise ValueError contr = ft.partial(term.contr, **term_contr_kwargs) # Work around https://github.com/google/jax/issues/21825 @@ -1008,10 +1005,6 @@ def _promote(yi): y0 = jtu.tree_map(_promote, y0) del timelikes - # Langevin terms must be unwrapped unless `term_structure=LangevinTerm - if isinstance(terms, LangevinTerm) and solver.term_structure != LangevinTerm: - terms = terms.term - # Backward compatibility if isinstance( solver, (EulerHeun, ItoMilstein, StratonovichMilstein) @@ -1070,14 +1063,13 @@ def _promote(yi): def _wrap(term): assert isinstance(term, AbstractTerm) - assert not isinstance(term, (MultiTerm, LangevinTerm)) + assert not isinstance(term, MultiTerm) return WrapTerm(term, direction) terms = jtu.tree_map( _wrap, terms, - is_leaf=lambda x: isinstance(x, AbstractTerm) - and not isinstance(x, (MultiTerm, LangevinTerm)), + is_leaf=lambda x: isinstance(x, AbstractTerm) and not isinstance(x, MultiTerm), ) if isinstance(solver, AbstractImplicitSolver): diff --git a/diffrax/_solver/align.py b/diffrax/_solver/align.py index 1be6defb..53427bb8 100644 --- a/diffrax/_solver/align.py +++ b/diffrax/_solver/align.py @@ -10,8 +10,13 @@ RealScalarLike, ) from .._local_interpolation import LocalLinearInterpolation -from .._term import _LangevinArgs, LangevinTerm, LangevinTuple, LangevinX -from .langevin_srk import AbstractCoeffs, AbstractLangevinSRK, SolverState +from .._term import _LangevinTuple, _LangevinX +from .langevin_srk import ( + _LangevinArgs, + AbstractCoeffs, + AbstractLangevinSRK, + SolverState, +) # UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, @@ -36,7 +41,7 @@ def __init__(self, beta, a1, b1, aa, chh): self.dtype = jnp.result_type(*all_leaves) -_ErrorEstimate = LangevinTuple +_ErrorEstimate = _LangevinTuple class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): @@ -52,7 +57,6 @@ class ALIGN(AbstractLangevinSRK[_ALIGNCoeffs, _ErrorEstimate]): $W$ is a Brownian motion. """ - term_structure = LangevinTerm interpolation_cls = LocalLinearInterpolation taylor_threshold: RealScalarLike = eqx.field(static=True) _coeffs_structure = jtu.tree_structure( @@ -131,15 +135,15 @@ def _tay_coeffs_single(self, c: Array) -> _ALIGNCoeffs: def _compute_step( h: RealScalarLike, levy: AbstractSpaceTimeLevyArea, - x0: LangevinX, - v0: LangevinX, + x0: _LangevinX, + v0: _LangevinX, langevin_args: _LangevinArgs, coeffs: _ALIGNCoeffs, st: SolverState, - ) -> tuple[LangevinX, LangevinX, LangevinX, LangevinTuple]: + ) -> tuple[_LangevinX, _LangevinX, _LangevinX, _LangevinTuple]: dtypes = jtu.tree_map(jnp.dtype, x0) - w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) - hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) + w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) + hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) gamma, u, f = langevin_args diff --git a/diffrax/_solver/euler_heun.py b/diffrax/_solver/euler_heun.py index 88b99776..c8338c88 100644 --- a/diffrax/_solver/euler_heun.py +++ b/diffrax/_solver/euler_heun.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import ClassVar +from typing import Any, ClassVar from typing_extensions import TypeAlias from equinox.internal import ω @@ -7,7 +7,7 @@ from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS -from .._term import AbstractTerm, MultiTerm, ODETerm +from .._term import AbstractTerm, MultiTerm from .base import AbstractStratonovichSolver @@ -26,7 +26,9 @@ class EulerHeun(AbstractStratonovichSolver): Used to solve SDEs, and converges to the Stratonovich solution. """ - term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]] + term_structure: ClassVar = MultiTerm[ + tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] + ] interpolation_cls: ClassVar[ Callable[..., LocalLinearInterpolation] ] = LocalLinearInterpolation @@ -39,7 +41,7 @@ def strong_order(self, terms): def init( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -49,7 +51,7 @@ def init( def step( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, diff --git a/diffrax/_solver/langevin_srk.py b/diffrax/_solver/langevin_srk.py index 52022274..526c2ebb 100644 --- a/diffrax/_solver/langevin_srk.py +++ b/diffrax/_solver/langevin_srk.py @@ -1,5 +1,5 @@ import abc -from typing import Generic, TypeVar +from typing import Any, Callable, Generic, TypeVar import equinox as eqx import equinox.internal as eqxi @@ -19,11 +19,42 @@ ) from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS -from .._term import _LangevinArgs, LangevinTerm, LangevinTuple, LangevinX +from .._term import ( + _LangevinDiffusionTerm, + _LangevinDriftTerm, + _LangevinTuple, + _LangevinX, + AbstractTerm, + MultiTerm, + WrapTerm, +) from .base import AbstractItoSolver, AbstractStratonovichSolver -_ErrorEstimate = TypeVar("_ErrorEstimate", None, LangevinTuple) +_ErrorEstimate = TypeVar("_ErrorEstimate", None, _LangevinTuple) +_LangevinArgs = tuple[_LangevinX, _LangevinX, Callable[[_LangevinX], _LangevinX]] + + +def get_args_from_terms( + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], +) -> _LangevinArgs: + drift, diffusion = terms.terms + if isinstance(drift, WrapTerm): + unwrapped_drift = drift.term + assert isinstance(diffusion, WrapTerm) + unwrapped_diffusion = diffusion.term + assert isinstance(unwrapped_drift, _LangevinDriftTerm) + assert isinstance(unwrapped_diffusion, _LangevinDiffusionTerm) + gamma = unwrapped_drift.gamma + u = unwrapped_drift.u + f = unwrapped_drift.grad_f + else: + assert isinstance(drift, _LangevinDriftTerm) + assert isinstance(diffusion, _LangevinDiffusionTerm) + gamma = drift.gamma + u = drift.u + f = drift.grad_f + return gamma, u, f class AbstractCoeffs(eqx.Module): @@ -38,10 +69,10 @@ class AbstractCoeffs(eqx.Module): # How should I work around this? class SolverState(eqx.Module, Generic[_Coeffs]): h: RealScalarLike - taylor_coeffs: PyTree[_Coeffs, "LangevinX"] + taylor_coeffs: PyTree[_Coeffs, "_LangevinX"] coeffs: _Coeffs - rho: LangevinX - prev_f: LangevinX + rho: _LangevinX + prev_f: _LangevinX # CONCERNING COEFFICIENTS: @@ -77,7 +108,7 @@ class AbstractLangevinSRK( QUIC_SORT. """ - term_structure = LangevinTerm + term_structure = MultiTerm[tuple[_LangevinDriftTerm, _LangevinDiffusionTerm]] interpolation_cls = LocalLinearInterpolation taylor_threshold: RealScalarLike = eqx.field(static=True) _coeffs_structure: eqx.AbstractClassVar[PyTreeDef] @@ -119,7 +150,7 @@ def _eval_taylor(h, tay_coeffs: _Coeffs) -> _Coeffs: ) def _recompute_coeffs( - self, h, gamma: LangevinX, tay_coeffs: PyTree[_Coeffs], state_h + self, h, gamma: _LangevinX, tay_coeffs: PyTree[_Coeffs], state_h ) -> _Coeffs: def recompute_coeffs_leaf(c: ArrayLike, _tay_coeffs: _Coeffs): # Used when the step-size h changes and coefficients need to be recomputed @@ -167,10 +198,10 @@ def _choose(tay_leaf, direct_leaf): def init( self, - terms: LangevinTerm, + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, - y0: LangevinTuple, + y0: _LangevinTuple, args: PyTree, ) -> SolverState: """Precompute _SolverState which carries the Taylor coefficients and the @@ -178,19 +209,17 @@ def init( Some solvers of this type are FSAL, so _SolverState also carries the previous evaluation of grad_f. """ - assert isinstance(terms, LangevinTerm) - gamma, u, f = terms.args # f is in fact grad(f) + drift, diffusion = terms.terms + gamma, u, f = get_args_from_terms(terms) + h = drift.contr(t0, t1) x0, v0 = y0 def _check_shapes(_c, _u, _x, _v): - # assert _x.ndim in [0, 1] assert _c.shape == _u.shape == _x.shape == _v.shape assert jtu.tree_all(jtu.tree_map(_check_shapes, gamma, u, x0, v0)) - h = t1 - t0 - tay_coeffs = jtu.tree_map(self._comp_taylor_coeffs_leaf, gamma) # tay_coeffs have the same tree structure as gamma, with each leaf being a # _Coeffs and the arrays have an extra trailing dimension of 6 @@ -213,30 +242,30 @@ def _check_shapes(_c, _u, _x, _v): def _compute_step( h: RealScalarLike, levy, - x0: LangevinX, - v0: LangevinX, + x0: _LangevinX, + v0: _LangevinX, langevin_args: _LangevinArgs, coeffs: _Coeffs, st: SolverState, - ) -> tuple[LangevinX, LangevinX, LangevinX, _ErrorEstimate]: + ) -> tuple[_LangevinX, _LangevinX, _LangevinX, _ErrorEstimate]: raise NotImplementedError def step( self, - terms: LangevinTerm, + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, - y0: LangevinTuple, + y0: _LangevinTuple, args: PyTree, solver_state: SolverState, made_jump: BoolScalarLike, - ) -> tuple[LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]: + ) -> tuple[_LangevinTuple, _ErrorEstimate, DenseInfo, SolverState, RESULTS]: del made_jump, args st = solver_state - h = t1 - t0 - assert isinstance(terms, LangevinTerm) - gamma, u, f = terms.args + drift, diffusion = terms.terms + gamma, u, f = get_args_from_terms(terms) + h = drift.contr(t0, t1) h_state = st.h tay: PyTree[_Coeffs] = st.taylor_coeffs coeffs: _Coeffs = st.coeffs @@ -250,7 +279,6 @@ def step( coeffs, ) - drift, diffusion = terms.term.terms # compute the Brownian increment and space-time Levy area levy = diffusion.contr(t0, t1, use_levy=True) assert isinstance(levy, self.minimal_levy_area), ( @@ -260,7 +288,7 @@ def step( x0, v0 = y0 x_out, v_out, f_fsal, error = self._compute_step( - h, levy, x0, v0, terms.args, coeffs, st + h, levy, x0, v0, (gamma, u, f), coeffs, st ) def check_shapes_dtypes(_x, _v, _f, _x0): @@ -289,9 +317,9 @@ def check_shapes_dtypes(_x, _v, _f, _x0): def func( self, - terms: LangevinTerm, + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, - y0: LangevinTuple, + y0: _LangevinTuple, args: PyTree, ): return terms.vf(t0, y0, args) diff --git a/diffrax/_solver/milstein.py b/diffrax/_solver/milstein.py index 3d14e343..5ab037d2 100644 --- a/diffrax/_solver/milstein.py +++ b/diffrax/_solver/milstein.py @@ -1,5 +1,5 @@ from collections.abc import Callable -from typing import ClassVar +from typing import Any, ClassVar from typing_extensions import TypeAlias import jax @@ -10,7 +10,7 @@ from .._custom_types import Args, BoolScalarLike, DenseInfo, RealScalarLike, VF, Y from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS -from .._term import AbstractTerm, MultiTerm, ODETerm +from .._term import AbstractTerm, MultiTerm from .base import AbstractItoSolver, AbstractStratonovichSolver @@ -42,7 +42,9 @@ class StratonovichMilstein(AbstractStratonovichSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]] + term_structure: ClassVar = MultiTerm[ + tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] + ] interpolation_cls: ClassVar[ Callable[..., LocalLinearInterpolation] ] = LocalLinearInterpolation @@ -55,7 +57,7 @@ def strong_order(self, terms): def init( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -65,7 +67,7 @@ def init( def step( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -116,7 +118,9 @@ class ItoMilstein(AbstractItoSolver): Note that this commutativity condition is not checked. """ # noqa: E501 - term_structure: ClassVar = MultiTerm[tuple[ODETerm, AbstractTerm]] + term_structure: ClassVar = MultiTerm[ + tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm] + ] interpolation_cls: ClassVar[ Callable[..., LocalLinearInterpolation] ] = LocalLinearInterpolation @@ -129,7 +133,7 @@ def strong_order(self, terms): def init( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -139,7 +143,7 @@ def init( def step( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -365,7 +369,7 @@ def _dot(_, _v0): def func( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm]], + terms: MultiTerm[tuple[AbstractTerm[Any, RealScalarLike], AbstractTerm]], t0: RealScalarLike, y0: Y, args: Args, diff --git a/diffrax/_solver/quicsort.py b/diffrax/_solver/quicsort.py index b7c88d74..1dfe2649 100644 --- a/diffrax/_solver/quicsort.py +++ b/diffrax/_solver/quicsort.py @@ -13,8 +13,13 @@ RealScalarLike, ) from .._local_interpolation import LocalLinearInterpolation -from .._term import _LangevinArgs, LangevinTerm, LangevinX -from .langevin_srk import AbstractCoeffs, AbstractLangevinSRK, SolverState +from .._term import _LangevinX +from .langevin_srk import ( + _LangevinArgs, + AbstractCoeffs, + AbstractLangevinSRK, + SolverState, +) # UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, @@ -71,7 +76,6 @@ class QUICSORT(AbstractLangevinSRK[_QUICSORTCoeffs, None]): $W$ is a Brownian motion. """ - term_structure = LangevinTerm interpolation_cls = LocalLinearInterpolation taylor_threshold: RealScalarLike = eqx.field(static=True) _coeffs_structure = jtu.tree_structure( @@ -182,16 +186,16 @@ def _tay_coeffs_single(self, c: Array) -> _QUICSORTCoeffs: def _compute_step( h: RealScalarLike, levy: AbstractSpaceTimeTimeLevyArea, - x0: LangevinX, - v0: LangevinX, + x0: _LangevinX, + v0: _LangevinX, langevin_args: _LangevinArgs, coeffs: _QUICSORTCoeffs, st: SolverState, - ) -> tuple[LangevinX, LangevinX, LangevinX, None]: + ) -> tuple[_LangevinX, _LangevinX, _LangevinX, None]: dtypes = jtu.tree_map(jnp.dtype, x0) - w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) - hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) - kk: LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) + w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) + hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) + kk: _LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) gamma, u, f = langevin_args diff --git a/diffrax/_solver/should.py b/diffrax/_solver/should.py index 5c1d9199..0940a58f 100644 --- a/diffrax/_solver/should.py +++ b/diffrax/_solver/should.py @@ -10,8 +10,13 @@ RealScalarLike, ) from .._local_interpolation import LocalLinearInterpolation -from .._term import _LangevinArgs, LangevinTerm, LangevinX -from .langevin_srk import AbstractCoeffs, AbstractLangevinSRK, SolverState +from .._term import _LangevinX +from .langevin_srk import ( + _LangevinArgs, + AbstractCoeffs, + AbstractLangevinSRK, + SolverState, +) # UBU evaluates at l = (3 -sqrt(3))/6, at r = (3 + sqrt(3))/6 and at 1, @@ -69,7 +74,6 @@ class ShOULD(AbstractLangevinSRK[_ShOULDCoeffs, None]): $W$ is a Brownian motion. """ - term_structure = LangevinTerm interpolation_cls = LocalLinearInterpolation taylor_threshold: RealScalarLike = eqx.field(static=True) _coeffs_structure = jtu.tree_structure( @@ -180,16 +184,16 @@ def _tay_coeffs_single(self, c: Array) -> _ShOULDCoeffs: def _compute_step( h: RealScalarLike, levy: AbstractSpaceTimeTimeLevyArea, - x0: LangevinX, - v0: LangevinX, + x0: _LangevinX, + v0: _LangevinX, langevin_args: _LangevinArgs, coeffs: _ShOULDCoeffs, st: SolverState, - ) -> tuple[LangevinX, LangevinX, LangevinX, None]: + ) -> tuple[_LangevinX, _LangevinX, _LangevinX, None]: dtypes = jtu.tree_map(jnp.dtype, x0) - w: LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) - hh: LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) - kk: LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) + w: _LangevinX = jtu.tree_map(jnp.asarray, levy.W, dtypes) + hh: _LangevinX = jtu.tree_map(jnp.asarray, levy.H, dtypes) + kk: _LangevinX = jtu.tree_map(jnp.asarray, levy.K, dtypes) gamma, u, f = langevin_args diff --git a/diffrax/_solver/srk.py b/diffrax/_solver/srk.py index 8ac53715..508b4d71 100644 --- a/diffrax/_solver/srk.py +++ b/diffrax/_solver/srk.py @@ -28,7 +28,7 @@ ) from .._local_interpolation import LocalLinearInterpolation from .._solution import RESULTS -from .._term import AbstractTerm, MultiTerm, ODETerm +from .._term import AbstractTerm, MultiTerm from .base import AbstractSolver @@ -276,11 +276,21 @@ def minimal_levy_area(self) -> type[AbstractBrownianIncrement]: @property def term_structure(self): - return MultiTerm[tuple[ODETerm, AbstractTerm[Any, self.minimal_levy_area]]] + return MultiTerm[ + tuple[ + AbstractTerm[Any, RealScalarLike], + AbstractTerm[Any, self.minimal_levy_area], + ] + ] def init( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + terms: MultiTerm[ + tuple[ + AbstractTerm[Any, RealScalarLike], + AbstractTerm[Any, AbstractBrownianIncrement], + ] + ], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -314,7 +324,12 @@ def _embed_a_lower(self, _a, dtype): def step( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + terms: MultiTerm[ + tuple[ + AbstractTerm[Any, RealScalarLike], + AbstractTerm[Any, AbstractBrownianIncrement], + ] + ], t0: RealScalarLike, t1: RealScalarLike, y0: Y, @@ -643,7 +658,12 @@ def compute_and_insert_kg_j(_w_kgs_in, _levylist_kgs_in): def func( self, - terms: MultiTerm[tuple[ODETerm, AbstractTerm[Any, AbstractBrownianIncrement]]], + terms: MultiTerm[ + tuple[ + AbstractTerm[Any, RealScalarLike], + AbstractTerm[Any, AbstractBrownianIncrement], + ] + ], t0: RealScalarLike, y0: Y, args: PyTree, diff --git a/diffrax/_term.py b/diffrax/_term.py index 2b6272a0..86467c07 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -2,7 +2,7 @@ import operator import warnings from collections.abc import Callable -from typing import Any, cast, Generic, Optional, TypeVar, Union +from typing import cast, Generic, Optional, TypeVar, Union import equinox as eqx import jax @@ -791,12 +791,12 @@ def _to_vjp(_y, _diff_args, _diff_term): # `x` and the velocity `v`. Both of these have the same shape. So, by LangevinX we # denote the shape of the x component, and by LangevinTuple we denote the shape of # the tuple (x, v). -LangevinX = PyTree[Shaped[Array, "?*langevin"], "LangevinX"] -LangevinTuple = tuple[LangevinX, LangevinX] -_LangevinBM = TypeVar("_LangevinBM", bound=Union[LangevinX, AbstractBrownianIncrement]) +_LangevinX = PyTree[Shaped[Array, "?*langevin"], "LangevinX"] +_LangevinTuple = tuple[_LangevinX, _LangevinX] +_LangevinBM = TypeVar("_LangevinBM", bound=Union[_LangevinX, AbstractBrownianIncrement]) -class _LangevinDiffusionTerm(AbstractTerm[LangevinX, _LangevinBM]): +class _LangevinDiffusionTerm(AbstractTerm[_LangevinX, _LangevinBM]): r"""Represents the diffusion term in the Langevin SDE: $d \mathbf{x}_t = \mathbf{v}_t dt$ @@ -805,16 +805,11 @@ class _LangevinDiffusionTerm(AbstractTerm[LangevinX, _LangevinBM]): \nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$ """ - gamma: LangevinX - u: LangevinX + gamma: _LangevinX + u: _LangevinX control: AbstractBrownianPath - def __init__(self, gamma, u, control: AbstractBrownianPath): - self.gamma = gamma - self.u = u - self.control = control - - def vf(self, t: RealScalarLike, y: LangevinTuple, args: Args) -> LangevinX: + def vf(self, t: RealScalarLike, y: _LangevinTuple, args: Args) -> _LangevinX: x, v = y def _fun(_gamma, _u, _v): @@ -826,7 +821,7 @@ def _fun(_gamma, _u, _v): def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _LangevinBM: return self.control.evaluate(t0, t1, **kwargs) - def prod(self, vf: LangevinX, control: LangevinX) -> LangevinTuple: + def prod(self, vf: _LangevinX, control: _LangevinX) -> _LangevinTuple: if isinstance(control, AbstractBrownianIncrement): control = control.W @@ -837,25 +832,6 @@ def prod(self, vf: LangevinX, control: LangevinX) -> LangevinTuple: return x_out, v_out -_LangevinArgs = tuple[LangevinX, LangevinX, Callable[[LangevinX], LangevinX]] - - -def _langevin_drift(t, y: LangevinTuple, args: _LangevinArgs) -> LangevinTuple: - gamma, u, grad_f = args - x, v = y - f_x = grad_f(x) - d_x = v - d_v = jtu.tree_map( - lambda _gamma, _u, _v, _f_x: -_gamma * _v - _u * _f_x, - gamma, - u, - v, - f_x, - ) - d_y = (d_x, d_v) - return d_y - - def _broadcast_pytree(source, target_tree_shape): # Requires that source is a prefix tree of target_tree_shape def _inner_broadcast(_src_arr, _inner_tree_shape): @@ -867,57 +843,78 @@ def _inner_broadcast(_src_arr, _inner_tree_shape): return jtu.tree_map(_inner_broadcast, source, target_tree_shape) -class LangevinTerm(AbstractTerm): - r"""Used to represent the Langevin SDE, given by: +class _LangevinDriftTerm(AbstractTerm): + gamma: _LangevinX + u: _LangevinX + grad_f: Callable[[_LangevinX], _LangevinX] - $d \mathbf{x}_t = \mathbf{v}_t dt$ + def vf(self, t: RealScalarLike, y: _LangevinTuple, args: Args) -> _LangevinTuple: + x, v = y + f_x = self.grad_f(x) + d_x = v + d_v = jtu.tree_map( + lambda _gamma, _u, _v, _f_x: -_gamma * _v - _u * _f_x, + self.gamma, + self.u, + v, + f_x, + ) + d_y = (d_x, d_v) + return d_y + + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> RealScalarLike: + return t1 - t0 + + def prod(self, vf: _LangevinTuple, control: RealScalarLike) -> _LangevinTuple: + return jtu.tree_map(lambda _vf: control * _vf, vf) + +def make_langevin_term( + gamma: PyTree[ArrayLike], + u: PyTree[ArrayLike], + grad_f: Callable, + bm: AbstractBrownianPath, + x0: _LangevinX, +) -> MultiTerm[tuple[_LangevinDriftTerm, _LangevinDiffusionTerm]]: + r"""Creates a term that represents the Underdamped Langevin Diffusion, given by: + $d \mathbf{x}_t = \mathbf{v}_t dt$ $d \mathbf{v}_t = - \gamma \mathbf{v}_t dt - u \nabla f( \mathbf{x}_t ) dt + \sqrt{2 \gamma u} d W_t.$ - where $\mathbf{x}_t, \mathbf{v}_t \in \mathbb{R}^d$ represent the position and velocity, $W$ is a Brownian motion in $\mathbb{R}^d$, $f: \mathbb{R}^d \rightarrow \mathbb{R}$ is a potential function, and - $ \gamma,u\in\mathbb{R}^{d\times d}$ are diagonal matrices representing - friction and a dampening parameter. - """ + $ \gamma,u\in\mathbb{R}^{d\times d}$ are diagonal matrices governing + the friction and the dampening of the system. - args: _LangevinArgs = eqx.field(static=True) - term: MultiTerm[tuple[ODETerm, _LangevinDiffusionTerm]] + **Arguments:** - def __init__(self, args, bm: AbstractBrownianPath, x0: LangevinX): - r"""**Arguments:** + - `gamma`: A vector containing the diagonal entries of the friction matrix; + a PyTree of the same shape as `x0`. + - `u`: A vector containing the diagonal entries of the dampening matrix; + a PyTree of the same shape as `x0`. + - `grad_f`: A callable representing the gradient of the potential function $f$. + This callable should take a PyTree of the same shape as `x0` and + return a PyTree of the same shape. + - `bm`: A Brownian path representing the Brownian motion $W$. + - `x0`: The initial state of the system (just the position without velocity); + only needed to check the PyTree structure of the other arguments. - - `args`: a tuple of the form $(\gamma, u, \nabla f)$ - - `bm`: a Brownian path - - `x0`: a point in the state space of the process (position only), - needed to determine the PyTree structure and shape of the process. - """ - g1, u1, grad_f = args - # the PyTree structure of g1 and u1 must be a prefix of the PyTree - # structure of x0, and the shapes must be broadcastable to the shape of - # each leaf of x0. - gamma = _broadcast_pytree(g1, x0) - u = _broadcast_pytree(u1, x0) - grad_f_shape = jax.eval_shape(grad_f, x0) - - def _shape_check_fun(_x, _g, _u, _fx): - return _x.shape == _g.shape == _u.shape == _fx.shape - - assert jtu.tree_all( - jtu.tree_map(_shape_check_fun, x0, gamma, u, grad_f_shape) - ), "The shapes of gamma, u, and grad_f(x0) must be the same as x0." - - self.args = (gamma, u, grad_f) - drift = ODETerm(lambda t, y, _: _langevin_drift(t, y, self.args)) - diffusion = _LangevinDiffusionTerm(gamma, u, bm) - self.term = MultiTerm(drift, diffusion) - - def vf(self, t: RealScalarLike, y: LangevinTuple, args: Args) -> tuple[Any, ...]: - return self.term.vf(t, y, args) + **Returns:** - def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> Any: - return self.term.contr(t0, t1, **kwargs) + A `MultiTerm` representing the Langevin SDE. + """ + + gamma = _broadcast_pytree(gamma, x0) + u = _broadcast_pytree(u, x0) + grad_f_shape = jax.eval_shape(grad_f, x0) + + def _shape_check_fun(_x, _g, _u, _fx): + return _x.shape == _g.shape == _u.shape == _fx.shape + + assert jtu.tree_all( + jtu.tree_map(_shape_check_fun, x0, gamma, u, grad_f_shape) + ), "The shapes of gamma, u, and grad_f(x0) must be the same as x0." - def prod(self, vf: tuple[Any, ...], control: Any) -> LangevinTuple: - return self.term.prod(vf, control) + drift = _LangevinDriftTerm(gamma, u, grad_f) + diffusion = _LangevinDiffusionTerm(gamma, u, bm) + return MultiTerm(drift, diffusion) diff --git a/test/helpers.py b/test/helpers.py index ac7af0a1..d27e465c 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -12,7 +12,7 @@ AbstractBrownianPath, AbstractTerm, ControlTerm, - LangevinTerm, + make_langevin_term, MultiTerm, ODETerm, VirtualBrownianTree, @@ -503,12 +503,12 @@ def get_terms(bm): def get_bqp(t0=0.3, t1=15.0, dtype=jnp.float32): grad_f_bqp = lambda x: 4 * x * (jnp.square(x) - 1) - args_bqp = (dtype(0.8), dtype(0.2), grad_f_bqp) + gamma, u = dtype(0.8), dtype(0.2) y0_bqp = (dtype(0), dtype(0)) w_shape_bqp = () def get_terms_bqp(bm): - return LangevinTerm(args_bqp, bm, x0=y0_bqp[0]) + return make_langevin_term(gamma, u, grad_f_bqp, bm, y0_bqp[0]) return SDE(get_terms_bqp, None, y0_bqp, t0, t1, w_shape_bqp) @@ -516,14 +516,13 @@ def get_terms_bqp(bm): def get_harmonic_oscillator(t0=0.3, t1=15.0, dtype=jnp.float32): gamma_hosc = jnp.array([2, 0.5], dtype=dtype) u_hosc = jnp.array([0.5, 2], dtype=dtype) - args_hosc = (gamma_hosc, u_hosc, lambda x: 2 * x) x0 = jnp.zeros((2,), dtype=dtype) v0 = jnp.zeros((2,), dtype=dtype) y0_hosc = (x0, v0) w_shape_hosc = (2,) def get_terms_hosc(bm): - return LangevinTerm(args_hosc, bm, x0) + return make_langevin_term(gamma_hosc, u_hosc, lambda x: 2 * x, bm, x0) return SDE(get_terms_hosc, None, y0_hosc, t0, t1, w_shape_hosc) @@ -538,14 +537,13 @@ def log_p(x): gamma = 2.0 u = 1.0 - args_neal = (gamma, u, grad_log_p) x0 = jnp.zeros((10,), dtype=dtype) v0 = jnp.zeros((10,), dtype=dtype) y0_neal = (x0, v0) w_shape_neal = (10,) def get_terms_neal(bm): - return LangevinTerm(args_neal, bm, x0) + return make_langevin_term(gamma, u, grad_log_p, bm, x0) return SDE(get_terms_neal, None, y0_neal, t0, t1, w_shape_neal) @@ -593,13 +591,12 @@ def grad_f(x): u = 1.0 gamma = 2.0 - args = (u, gamma, grad_f) x0 = jnp.array([-1, 0, 1, 1, 0, -1, 1, 0, -1], dtype=dtype) v0 = jnp.zeros((9,), dtype=dtype) y0_uld3 = (x0, v0) w_shape_uld3 = (9,) def get_terms_uld3(bm): - return LangevinTerm(args, bm, x0) + return make_langevin_term(u, gamma, grad_f, bm, x0) return SDE(get_terms_uld3, None, y0_uld3, t0, t1, w_shape_uld3) diff --git a/test/test_langevin.py b/test/test_langevin.py index 6690ca8b..8625d7e6 100644 --- a/test/test_langevin.py +++ b/test/test_langevin.py @@ -4,7 +4,7 @@ import jax.random as jr import jax.tree_util as jtu import pytest -from diffrax import diffeqsolve, LangevinTerm, SaveAt +from diffrax import diffeqsolve, make_langevin_term, SaveAt from .helpers import ( get_bqp, @@ -26,6 +26,7 @@ def _solvers_and_orders(): yield diffrax.ALIGN(0.1), 2.0 yield diffrax.ShOULD(0.1), 3.0 yield diffrax.QUICSORT(0.1), 3.0 + yield diffrax.ShARK(), 2.0 def get_pytree_langevin(t0=0.3, t1=1.0, dtype=jnp.float32): @@ -63,11 +64,10 @@ def grad_f(x): xb = x["qq"] return {"rr": jtu.tree_map(lambda _x: 0.2 * _x, xa), "qq": xb} - args = g1, u1, grad_f w_shape = jtu.tree_map(lambda _x: jax.ShapeDtypeStruct(_x.shape, _x.dtype), x0) def get_terms(bm): - return LangevinTerm(args, bm, x0) + return make_langevin_term(g1, u1, grad_f, bm, x0) return SDE(get_terms, None, y0, t0, t1, w_shape)