From eea9652d39f8ffc4ff42acc26b0d376a699578dd Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 12 Jan 2025 21:45:09 +0100 Subject: [PATCH 1/2] Fixed a major source of bugs: ControlTerms no longer broadcast. --- diffrax/_term.py | 295 +++++++++++++++++++++++++++++------------ docs/api/terms.md | 4 +- pyproject.toml | 2 +- test/test_adjoint.py | 3 +- test/test_integrate.py | 63 ++++++--- test/test_sde2.py | 4 +- test/test_term.py | 8 +- test/test_typing.py | 5 - 8 files changed, 266 insertions(+), 118 deletions(-) diff --git a/diffrax/_term.py b/diffrax/_term.py index d13d430b..0ea97301 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -256,7 +256,7 @@ def _callable_to_path( x: Union[ AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control] ], -) -> AbstractPath[_Control]: +) -> AbstractPath: if isinstance(x, AbstractPath): return x else: @@ -270,55 +270,7 @@ def _prod(vf, control): return jnp.tensordot(jnp.conj(vf), control, axes=jnp.ndim(control)) -# This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we -# were writing things again today it would be folded into just `ControlTerm`. -class _AbstractControlTerm(AbstractTerm[_VF, _Control]): - vector_field: Callable[[RealScalarLike, Y, Args], _VF] - control: Union[ - AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control] - ] = eqx.field(converter=_callable_to_path) # pyright: ignore - - def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: - return self.vector_field(t, y, args) - - def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: - return self.control.evaluate(t0, t1, **kwargs) # pyright: ignore - - def to_ode(self) -> ODETerm: - r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$ - may be thought of as an ODE as - - $f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t$. - - This method converts this `ControlTerm` into the corresponding - [`diffrax.ODETerm`][] in this way. - """ - vector_field = _ControlToODE(self) - return ODETerm(vector_field=vector_field) - - -_AbstractControlTerm.__init__.__doc__ = """**Arguments:** - -- `vector_field`: A callable representing the vector field. This callable takes three - arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is - the evolving state of the system. `args` are any static arguments as passed to - [`diffrax.diffeqsolve`][]. This `vector_field` can either be - - 1. a function that returns a PyTree of JAX arrays, or - 2. it can return a - [Lineax linear operator](https://docs.kidger.site/lineax/api/operators), - as described above. - -- `control`: The control. Should either be - - 1. a [`diffrax.AbstractPath`][], in which case its `.evaluate(t0, t1)` method - will be used to give the increment of the control over a time interval - `[t0, t1]`, or - 2. a callable `(t0, t1) -> increment`, which returns the increment directly. -""" - - -class ControlTerm(_AbstractControlTerm[_VF, _Control]): +class ControlTerm(AbstractTerm[_VF, _Control]): r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in which the vector field ($f$) - control ($\mathrm{d}x$) interaction is a matrix-vector product. @@ -380,6 +332,7 @@ def vector_field(t, y, args): diffusion_term = ControlTerm(vector_field, control) diffeqsolve(terms=diffusion_term, y0=y0, ...) ``` + !!! Example In this example we consider an SDE with a one-dimensional state @@ -451,14 +404,182 @@ def vector_field(t, y, args): ``` """ # noqa: E501 + vector_field: Callable[[RealScalarLike, Y, Args], _VF] + control: AbstractPath[_Control] + + def __init__( + self, + vector_field: Callable[[RealScalarLike, Y, Args], _VF], + control: Union[ + AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control] + ], + ): + self.vector_field = vector_field + self.control = _callable_to_path(control) + + def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: + return self.vector_field(t, y, args) + + def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control: + return self.control.evaluate(t0, t1, **kwargs) + def prod(self, vf: _VF, control: _Control) -> Y: if isinstance(vf, lx.AbstractLinearOperator): return vf.mv(control) else: return jtu.tree_map(_prod, vf, control) + def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y: + vf = self.vf(t, y, args) + out = self.prod(vf, control) + + def _raise(): + # SDEs are a common special case; try to make the error message a little + # easier to understand in this case! + if isinstance(self.control, AbstractBrownianPath): + diffusion_word = "diffusion" + control_word = "Brownian motion" + diffusion_phrase = "diffusion matrix" + else: + diffusion_word = "vector field" + control_word = "control" + diffusion_phrase = "vector field in a control term" + if isinstance(vf, lx.AbstractLinearOperator): + dot_phrase = ( + f"combined with `{type(vf).__module__}.{type(vf).__qualname__}.mv`" + ) + else: + dot_phrase = "dotted together" + vf_str = eqx.tree_pformat(vf) + control_str = eqx.tree_pformat(control) + out_str = eqx.tree_pformat(out) + y_str = eqx.tree_pformat(y) + if "\n" in vf_str: + vf_str = f"\n```\n{vf_str}\n```\n" + else: + vf_str = f" `{vf_str}` " + if "\n" in control_str: + control_str = f"\n```\n{control_str}\n```\n" + else: + control_str = f" `{control_str}`, " + if "\n" in out_str: + out_str = f"\n```\n{out_str}\n```\n" + else: + out_str = f" `{out_str}`, " + if "\n" in y_str: + y_str = f"\n```\n{y_str}\n```\n" + else: + y_str = f" `{y_str}`.\n" + raise ValueError( + "The `ControlTerm` returned arrays whose output structure did not " + "match the structure of the evolving state `y`. Specifically, the " + f"{diffusion_word} had structure{vf_str}and the {control_word} " + f"had structure{control_str}which when {dot_phrase} produced an " + f"output of structure{out_str}which is different to the evolving " + f"state `y` which had structure{y_str}" + "\n" + "This became an error in Diffrax 0.7.0. In previous versions of " + "Diffrax then the output was broadcast to the shape of `y`. This " + "has been removed as it was a common source of bugs.\n" + "\n" + "To walk you through what is going on, here is a sample program " + "that now raises an error:\n" + "```\n" + "import diffrax as dfx\n" + "import jax.numpy as jnp\n" + "import jax.random as jr\n" + "\n" + "def drift(t, y, args):\n" + " return -y\n" + "\n" + "def diffusion(t, y, args):\n" + " return jnp.array([1., 0.5])\n" + "\n" + "key = jr.key(0)\n" + "bm = dfx.VirtualBrownianTree(t0=0, t1=1, tol=1e-3, shape=(2,), key=key)\n" # noqa: E501 + "terms = dfx.MultiTerm(dfx.ODETerm(drift), dfx.ControlTerm(diffusion, bm))\n" # noqa: E501 + "solver = dfx.Euler()\n" + "y0 = jnp.array([1., 1.])\n" + "dfx.diffeqsolve(terms, solver, t0=0, t1=1, dt0=0.1, y0=y0)\n" + "```\n" + "In this case, the diffusion returns an array of shape `(2,)` and " + "the Brownian motion is of shape `(2,)`. By the rules of " + "`ControlTerm`, they are then dotted together so that the " + "diffusion term returns a scalar. Under previous versions of " + "Diffrax, this would then be broadcast out to both elements of the " + "evolving state `y`, corresponding to the SDE:\n" + "```\n" + "dy₁(t) = -y₁(t) dt + dW₁ + 0.5 dW₂\n" + "dy₂(t) = -y₂(t) dt + dW₁ + 0.5 dW₂\n" + "```\n" + "or the equivalent in vector notation, with `y(t), W(t) ⋹ R²`\n" + "```\n" + "dy(t) = -y(t) dt + [[1, 0.5], [1, 0.5]] dW\n" + "```\n" + "Which may have been unexpected! Quite possibly what was actually " + "intended was an SDE with diagonal noise:\n" + "```\n" + "dy(t) = -y(t) dt + [[1, 0], [0, 0.5]] dW\n" + "```\n" + "\n" + "As of Diffrax 0.7.0, the recommended way to express the " + f"{diffusion_phrase} is to use a Lineax linear operator. " + "(https://docs.kidger.site/lineax/api/operators/) For example, to " + "represent diagonal noise in the example above:\n" + "```python\n" + "import lineax as lx\n" + "\n" + "def diffusion(t, y, args):\n" + " diagonal = jnp.array([1., 0.5])\n" + " return lx.DiagonalLinearOperator(diagonal)\n" + "```\n" + ) + + if jtu.tree_structure(y) != jtu.tree_structure(out): + _raise() + + def _check_shape(yi, out_i): + if jnp.shape(yi) != jnp.shape(out_i): + _raise() + + jtu.tree_map(_check_shape, y, out) + return out + + def to_ode(self) -> ODETerm: + r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$ + may be thought of as an ODE as + + $f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t$. + + This method converts this `ControlTerm` into the corresponding + [`diffrax.ODETerm`][] in this way. + """ + vector_field = _ControlToODE(self) + return ODETerm(vector_field=vector_field) + + +ControlTerm.__init__.__doc__ = """**Arguments:** + +- `vector_field`: A callable representing the vector field. This callable takes three + arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is + the evolving state of the system. `args` are any static arguments as passed to + [`diffrax.diffeqsolve`][]. This `vector_field` can either be + + 1. a function that returns a PyTree of JAX arrays, or + 2. it can return a + [Lineax linear operator](https://docs.kidger.site/lineax/api/operators), + as described above. + +- `control`: The control. Should either be + + 1. a [`diffrax.AbstractPath`][], in which case its `.evaluate(t0, t1)` method + will be used to give the increment of the control over a time interval + `[t0, t1]`, or + 2. a callable `(t0, t1) -> increment`, which returns the increment directly. +""" -class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]): + +def WeaklyDiagonalControlTerm(vector_field, control): r""" DEPRECATED. Prefer: @@ -469,6 +590,9 @@ def vector_field(t, y, args): diffrax.ControlTerm(vector_field, ...) ``` + The current implementation is a backward-compatible shim that returns something like + the code snippet the above. + --- A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in @@ -492,45 +616,46 @@ def vector_field(t, y, args): without the "weak". (This stronger property is useful in some SDE solvers.) """ - def __check_init__(self): - warnings.warn( - "`WeaklyDiagonalControlTerm` is now deprecated, in favour combining " - "`ControlTerm` with a `lineax.AbstractLinearOperator`. This offers a way " - "to define a vector field with any kind of structure -- diagonal or " - "otherwise.\n" - "For a diagonal linear operator, then this can be easily converted as " - "follows. What was previously:\n" - "```\n" - "def vector_field(t, y, args):\n" - " ...\n" - " return some_vector\n" - "\n" - "diffrax.WeaklyDiagonalControlTerm(vector_field)\n" - "```\n" - "is now:\n" - "```\n" - "import lineax\n" - "\n" - "def vector_field(t, y, args):\n" - " ...\n" - " return lineax.DiagonalLinearOperator(some_vector)\n" - "\n" - "diffrax.ControlTerm(vector_field)\n" - "```\n" - "Lineax is available at `https://github.com/patrick-kidger/lineax`.\n", - stacklevel=3, - ) - - def prod(self, vf: _VF, control: _Control) -> Y: - with jax.numpy_dtype_promotion("standard"): - return jtu.tree_map(operator.mul, vf, control) + warnings.warn( + "`WeaklyDiagonalControlTerm` is now deprecated, in favour combining " + "`ControlTerm` with a `lineax.AbstractLinearOperator`. This offers a way " + "to define a vector field with any kind of structure -- diagonal or " + "otherwise.\n" + "For a diagonal linear operator, then this can be easily converted as " + "follows. What was previously:\n" + "```\n" + "def vector_field(t, y, args):\n" + " ...\n" + " return some_vector\n" + "\n" + "diffrax.WeaklyDiagonalControlTerm(vector_field)\n" + "```\n" + "is now:\n" + "```\n" + "import lineax\n" + "\n" + "def vector_field(t, y, args):\n" + " ...\n" + " return lineax.DiagonalLinearOperator(some_vector)\n" + "\n" + "diffrax.ControlTerm(vector_field)\n" + "```\n" + "Lineax is available at `https://github.com/patrick-kidger/lineax`.\n", + stacklevel=2, + ) + + def new_vector_field(t, y, args): + vf = vector_field(t, y, args) + return lx.DiagonalLinearOperator(vf) + + return ControlTerm(new_vector_field, control) class _ControlToODE(eqx.Module): - control_term: _AbstractControlTerm + control_term: ControlTerm def __call__(self, t: RealScalarLike, y: Y, args: Args) -> Y: - control = self.control_term.control.derivative(t) # pyright: ignore + control = self.control_term.control.derivative(t) return self.control_term.vf_prod(t, y, args, control) diff --git a/docs/api/terms.md b/docs/api/terms.md index 0c72f9f6..6eecde1b 100644 --- a/docs/api/terms.md +++ b/docs/api/terms.md @@ -71,7 +71,7 @@ Some example term structures include: ??? note "Defining your own term types" - For advanced users: you can create your own terms if appropriate. For example if your diffusion is matrix, itself computed as a matrix-matrix product, then you may wish to define a custom term and specify its [`diffrax.AbstractTerm.vf_prod`][] method. By overriding this method you could express the contraction of the vector field - control as a matrix-(matix-vector) product, which is more efficient than the default (matrix-matrix)-vector product. + For advanced users, you can create your own terms if appropriate. See for example the [underdamped Langevin terms](#underdamped-langevin-terms), which have their own special set of solvers. --- @@ -113,7 +113,7 @@ $\gamma , u \in \mathbb{R}^{d \times d}$ are diagonal matrices governing the friction and the damping of the system. These terms enable the use of ULD-specific solvers which can be found -[here](./solvers/sde_solvers.md#underdamped-langevin-solvers). Note that these ULD solvers will only work if given +[here](./solvers/sde_solvers.md#underdamped-langevin-solvers). These ULD solvers expect terms with structure `MultiTerm(UnderdampedLangevinDriftTerm(gamma, u, grad_f), UnderdampedLangevinDiffusionTerm(gamma, u, bm))`, where `bm` is an [`diffrax.AbstractBrownianPath`][] and the same values of `gammma` and `u` are passed to both terms. diff --git a/pyproject.toml b/pyproject.toml index 01cacf52..30e2a0fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "diffrax" -version = "0.6.2" +version = "0.7.0" description = "GPU+autodiff-capable ODE/SDE/CDE solvers written in JAX." readme = "README.md" requires-python ="~=3.9" diff --git a/test/test_adjoint.py b/test/test_adjoint.py index c45c6286..9e17e535 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -391,7 +391,8 @@ def g_lx(t, y, args): bm = diffrax.VirtualBrownianTree(t0, t1, tol, shape, key=getkey()) drift = diffrax.ODETerm(f) if diffusion_fn == "weak": - diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm) + with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"): + diffusion = diffrax.WeaklyDiagonalControlTerm(g, bm) else: diffusion = diffrax.ControlTerm(g_lx, bm) terms = diffrax.MultiTerm(drift, diffusion) diff --git a/test/test_integrate.py b/test/test_integrate.py index 555d6ade..4d039ebc 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -603,31 +603,49 @@ def test_term_compatibility(): class TestControl(eqx.Module): dt: Float[ArrayLike, ""] - def __rmul__(self, other): - return other.__mul__(self.dt) - - def __mul__(self, other): - return self.dt * other - class TestSolver(diffrax.Euler): term_structure = diffrax.AbstractTerm[ - tuple[Float[Array, "n 3"]], tuple[TestControl] + lx.AbstractLinearOperator, tuple[TestControl] ] + class TestLinearOperator(lx.AbstractLinearOperator): + def mv(self, vector): + assert ( + type(vector) is tuple + and len(vector) == 1 + and type(vector[0]) is TestControl + ) + return (jnp.ones((2, 3)) * vector[0].dt,) + + def as_matrix(self): + assert False + + def transpose(self): + assert False + + def in_structure(self): + return (jax.eval_shape(lambda: TestControl(1.0)),) + + def out_structure(self): + return (jax.ShapeDtypeStruct((2, 3), jnp.float64),) + + @lx.is_symmetric.register(TestLinearOperator) + def _(operator): + del operator + return False + solver = TestSolver() - incompatible_vf = lambda t, y, args: jnp.ones((2, 1)) - compatible_vf = lambda t, y, args: (jnp.ones((2, 3)),) + incompatible_vf = lambda t, y, args: jnp.ones((2, 3)) + compatible_vf = lambda t, y, args: TestLinearOperator() incompatible_control = lambda t0, t1: t1 - t0 compatible_control = lambda t0, t1: (TestControl(t1 - t0),) incompatible_terms = [ - diffrax.WeaklyDiagonalControlTerm(incompatible_vf, incompatible_control), - diffrax.WeaklyDiagonalControlTerm(incompatible_vf, compatible_control), - diffrax.WeaklyDiagonalControlTerm(compatible_vf, incompatible_control), + diffrax.ControlTerm(incompatible_vf, incompatible_control), + diffrax.ControlTerm(incompatible_vf, compatible_control), + diffrax.ControlTerm(compatible_vf, incompatible_control), ] - compatible_term = diffrax.WeaklyDiagonalControlTerm( - compatible_vf, compatible_control - ) + compatible_term = diffrax.ControlTerm(compatible_vf, compatible_control) for term in incompatible_terms: with pytest.raises(ValueError, match=r"Terms are not compatible with solver!"): diffrax.diffeqsolve(term, solver, 0.0, 1.0, 0.1, (jnp.zeros((2, 1)),)) @@ -669,6 +687,10 @@ def _step(_term, _y): def func(self, terms, t0, y0, args): assert False + def weakly_diagonal(*a): + with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"): + return diffrax.WeaklyDiagonalControlTerm(*a) + ode_term = diffrax.ODETerm(lambda t, y, args: -y) solver = TestSolver() compatible_term = { @@ -678,8 +700,9 @@ def func(self, terms, t0, y0, args): "d": ode_term, "e": diffrax.MultiTerm( ode_term, - diffrax.WeaklyDiagonalControlTerm( - lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(5) + weakly_diagonal( + lambda t, y, args: -lx.DiagonalLinearOperator(y), + lambda t0, t1: jnp.array(t1 - t0).repeat(5), ), ), "f": diffrax.MultiTerm( @@ -707,7 +730,7 @@ def func(self, terms, t0, y0, args): "d": ode_term, "e": diffrax.MultiTerm( ode_term, - diffrax.WeaklyDiagonalControlTerm( + weakly_diagonal( lambda t, y, args: -y, lambda t0, t1: t1 - t0, # wrong control shape ), @@ -727,7 +750,7 @@ def func(self, terms, t0, y0, args): # Missing "d" piece "e": diffrax.MultiTerm( ode_term, - diffrax.WeaklyDiagonalControlTerm( + weakly_diagonal( lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3) ), ), @@ -745,7 +768,7 @@ def func(self, terms, t0, y0, args): "c": ode_term, "d": ode_term, # No MultiTerm for "e" - "e": diffrax.WeaklyDiagonalControlTerm( + "e": weakly_diagonal( lambda t, y, args: -y, lambda t0, t1: jnp.array(t1 - t0).repeat(3) ), "f": diffrax.MultiTerm( diff --git a/test/test_sde2.py b/test/test_sde2.py index 3b4a4628..077177b8 100644 --- a/test/test_sde2.py +++ b/test/test_sde2.py @@ -83,7 +83,9 @@ def _drift(t, y, args): 0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea ) - terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm)) + with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"): + diffusion = WeaklyDiagonalControlTerm(_diffusion, bm) + terms = MultiTerm(ODETerm(_drift), diffusion) saveat = diffrax.SaveAt(t1=True) solution = diffrax.diffeqsolve( terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat diff --git a/test/test_term.py b/test/test_term.py index 8e8bf8be..0c75fc78 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -4,6 +4,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import lineax as lx import pytest from jaxtyping import Array, PyTree, Shaped @@ -84,15 +85,16 @@ def derivative(self, t, left=True): return jr.normal(derivkey, (3,)) control = Control() - term = diffrax.WeaklyDiagonalControlTerm(vector_field, control) + with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"): + term = diffrax.WeaklyDiagonalControlTerm(vector_field, control) args = getkey() dx = term.contr(0, 1) y = jnp.array([1.0, 2.0, 3.0]) vf = term.vf(0, y, args) vf_prod = term.vf_prod(0, y, args, dx) - if isinstance(dx, jax.Array) and isinstance(vf, jax.Array): + if isinstance(dx, jax.Array) and isinstance(vf, lx.DiagonalLinearOperator): assert dx.shape == (3,) - assert vf.shape == (3,) + assert vf.diagonal.shape == (3,) else: raise TypeError("dx/vf is not an array") assert vf_prod.shape == (3,) diff --git a/test/test_typing.py b/test/test_typing.py index 4c4f3db1..705b0bcd 100644 --- a/test/test_typing.py +++ b/test/test_typing.py @@ -289,8 +289,3 @@ def test_ode_term(): def test_control_term(): assert _abstract_args(dfx.ControlTerm) == (Any, Any) assert _abstract_args(dfx.ControlTerm[int, str]) == (int, str) - - -def test_weakly_diagonal_control_term(): - assert _abstract_args(dfx.WeaklyDiagonalControlTerm) == (Any, Any) - assert _abstract_args(dfx.WeaklyDiagonalControlTerm[int, str]) == (int, str) From df1ba42eb9f4253c89f4731f6e644a88386ae107 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 12 Jan 2025 21:46:13 +0100 Subject: [PATCH 2/2] Now using jaxtyping.Real for prettier documentation. --- diffrax/_custom_types.py | 19 +++---------------- mkdocs.yml | 2 ++ 2 files changed, 5 insertions(+), 16 deletions(-) diff --git a/diffrax/_custom_types.py b/diffrax/_custom_types.py index 7e08aa1b..a16b4d61 100644 --- a/diffrax/_custom_types.py +++ b/diffrax/_custom_types.py @@ -1,4 +1,3 @@ -import typing from typing import Any, TYPE_CHECKING, Union import equinox as eqx @@ -13,6 +12,7 @@ Float, Int, PyTree, + Real, Shaped, ) @@ -21,27 +21,14 @@ BoolScalarLike = Union[bool, Array, np.ndarray] FloatScalarLike = Union[float, Array, np.ndarray] IntScalarLike = Union[int, Array, np.ndarray] -elif getattr(typing, "GENERATING_DOCUMENTATION", False): - # Skip the union with Array in docs. - BoolScalarLike = bool - FloatScalarLike = float - IntScalarLike = int - - # - # Because they appear in our docstrings, we also monkey-patch some non-Diffrax - # types that have similar defined-in-one-place, exported-in-another behaviour. - # - - jtu.Partial.__module__ = "jax.tree_util" - + RealScalarLike = Union[bool, int, float, Array, np.ndarray] else: BoolScalarLike = Bool[ArrayLike, ""] FloatScalarLike = Float[ArrayLike, ""] IntScalarLike = Int[ArrayLike, ""] + RealScalarLike = Real[ArrayLike, ""] -RealScalarLike = Union[FloatScalarLike, IntScalarLike] - Y = PyTree[Shaped[ArrayLike, "?*y"], "Y"] VF = PyTree[Shaped[ArrayLike, "?*vf"], "VF"] Control = PyTree[Shaped[ArrayLike, "?*control"], "C"] diff --git a/mkdocs.yml b/mkdocs.yml index b399fbd8..067cd458 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -75,6 +75,8 @@ plugins: setup_commands: - import pytkdocs_tweaks - pytkdocs_tweaks.main() + - import jax.tree_util + - jax.tree_util.Partial.__module__ = "jax.tree_util" selection: inherited_members: true # Allow looking up inherited methods