Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed a major source of bugs: ControlTerms no longer broadcast. #565

Open
wants to merge 2 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 3 additions & 16 deletions diffrax/_custom_types.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import typing
from typing import Any, TYPE_CHECKING, Union

import equinox as eqx
Expand All @@ -13,6 +12,7 @@
Float,
Int,
PyTree,
Real,
Shaped,
)

Expand All @@ -21,27 +21,14 @@
BoolScalarLike = Union[bool, Array, np.ndarray]
FloatScalarLike = Union[float, Array, np.ndarray]
IntScalarLike = Union[int, Array, np.ndarray]
elif getattr(typing, "GENERATING_DOCUMENTATION", False):
# Skip the union with Array in docs.
BoolScalarLike = bool
FloatScalarLike = float
IntScalarLike = int

#
# Because they appear in our docstrings, we also monkey-patch some non-Diffrax
# types that have similar defined-in-one-place, exported-in-another behaviour.
#

jtu.Partial.__module__ = "jax.tree_util"

RealScalarLike = Union[bool, int, float, Array, np.ndarray]
else:
BoolScalarLike = Bool[ArrayLike, ""]
FloatScalarLike = Float[ArrayLike, ""]
IntScalarLike = Int[ArrayLike, ""]
RealScalarLike = Real[ArrayLike, ""]


RealScalarLike = Union[FloatScalarLike, IntScalarLike]

Y = PyTree[Shaped[ArrayLike, "?*y"], "Y"]
VF = PyTree[Shaped[ArrayLike, "?*vf"], "VF"]
Control = PyTree[Shaped[ArrayLike, "?*control"], "C"]
Expand Down
295 changes: 210 additions & 85 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _callable_to_path(
x: Union[
AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control]
],
) -> AbstractPath[_Control]:
) -> AbstractPath:
if isinstance(x, AbstractPath):
return x
else:
Expand All @@ -270,55 +270,7 @@ def _prod(vf, control):
return jnp.tensordot(jnp.conj(vf), control, axes=jnp.ndim(control))


# This class exists for backward compatibility with `WeaklyDiagonalControlTerm`. If we
# were writing things again today it would be folded into just `ControlTerm`.
class _AbstractControlTerm(AbstractTerm[_VF, _Control]):
vector_field: Callable[[RealScalarLike, Y, Args], _VF]
control: Union[
AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control]
] = eqx.field(converter=_callable_to_path) # pyright: ignore

def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF:
return self.vector_field(t, y, args)

def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control:
return self.control.evaluate(t0, t1, **kwargs) # pyright: ignore

def to_ode(self) -> ODETerm:
r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$
may be thought of as an ODE as

$f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t$.

This method converts this `ControlTerm` into the corresponding
[`diffrax.ODETerm`][] in this way.
"""
vector_field = _ControlToODE(self)
return ODETerm(vector_field=vector_field)


_AbstractControlTerm.__init__.__doc__ = """**Arguments:**

- `vector_field`: A callable representing the vector field. This callable takes three
arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is
the evolving state of the system. `args` are any static arguments as passed to
[`diffrax.diffeqsolve`][]. This `vector_field` can either be

1. a function that returns a PyTree of JAX arrays, or
2. it can return a
[Lineax linear operator](https://docs.kidger.site/lineax/api/operators),
as described above.

- `control`: The control. Should either be

1. a [`diffrax.AbstractPath`][], in which case its `.evaluate(t0, t1)` method
will be used to give the increment of the control over a time interval
`[t0, t1]`, or
2. a callable `(t0, t1) -> increment`, which returns the increment directly.
"""


class ControlTerm(_AbstractControlTerm[_VF, _Control]):
class ControlTerm(AbstractTerm[_VF, _Control]):
r"""A term representing the general case of $f(t, y(t), args) \mathrm{d}x(t)$, in
which the vector field ($f$) - control ($\mathrm{d}x$) interaction is a
matrix-vector product.
Expand Down Expand Up @@ -380,6 +332,7 @@ def vector_field(t, y, args):
diffusion_term = ControlTerm(vector_field, control)
diffeqsolve(terms=diffusion_term, y0=y0, ...)
```

!!! Example

In this example we consider an SDE with a one-dimensional state
Expand Down Expand Up @@ -451,14 +404,182 @@ def vector_field(t, y, args):
```
""" # noqa: E501

vector_field: Callable[[RealScalarLike, Y, Args], _VF]
control: AbstractPath[_Control]

def __init__(
self,
vector_field: Callable[[RealScalarLike, Y, Args], _VF],
control: Union[
AbstractPath[_Control], Callable[[RealScalarLike, RealScalarLike], _Control]
],
):
self.vector_field = vector_field
self.control = _callable_to_path(control)

def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF:
return self.vector_field(t, y, args)

def contr(self, t0: RealScalarLike, t1: RealScalarLike, **kwargs) -> _Control:
return self.control.evaluate(t0, t1, **kwargs)

def prod(self, vf: _VF, control: _Control) -> Y:
if isinstance(vf, lx.AbstractLinearOperator):
return vf.mv(control)
else:
return jtu.tree_map(_prod, vf, control)

def vf_prod(self, t: RealScalarLike, y: Y, args: Args, control: _Control) -> Y:
vf = self.vf(t, y, args)
out = self.prod(vf, control)

def _raise():
# SDEs are a common special case; try to make the error message a little
# easier to understand in this case!
if isinstance(self.control, AbstractBrownianPath):
diffusion_word = "diffusion"
control_word = "Brownian motion"
diffusion_phrase = "diffusion matrix"
else:
diffusion_word = "vector field"
control_word = "control"
diffusion_phrase = "vector field in a control term"
if isinstance(vf, lx.AbstractLinearOperator):
dot_phrase = (
f"combined with `{type(vf).__module__}.{type(vf).__qualname__}.mv`"
)
else:
dot_phrase = "dotted together"
vf_str = eqx.tree_pformat(vf)
control_str = eqx.tree_pformat(control)
out_str = eqx.tree_pformat(out)
y_str = eqx.tree_pformat(y)
if "\n" in vf_str:
vf_str = f"\n```\n{vf_str}\n```\n"
else:
vf_str = f" `{vf_str}` "
if "\n" in control_str:
control_str = f"\n```\n{control_str}\n```\n"
else:
control_str = f" `{control_str}`, "
if "\n" in out_str:
out_str = f"\n```\n{out_str}\n```\n"
else:
out_str = f" `{out_str}`, "
if "\n" in y_str:
y_str = f"\n```\n{y_str}\n```\n"
else:
y_str = f" `{y_str}`.\n"
raise ValueError(
"The `ControlTerm` returned arrays whose output structure did not "
"match the structure of the evolving state `y`. Specifically, the "
f"{diffusion_word} had structure{vf_str}and the {control_word} "
f"had structure{control_str}which when {dot_phrase} produced an "
f"output of structure{out_str}which is different to the evolving "
f"state `y` which had structure{y_str}"
"\n"
"This became an error in Diffrax 0.7.0. In previous versions of "
"Diffrax then the output was broadcast to the shape of `y`. This "
"has been removed as it was a common source of bugs.\n"
"\n"
"To walk you through what is going on, here is a sample program "
"that now raises an error:\n"
"```\n"
"import diffrax as dfx\n"
"import jax.numpy as jnp\n"
"import jax.random as jr\n"
"\n"
"def drift(t, y, args):\n"
" return -y\n"
"\n"
"def diffusion(t, y, args):\n"
" return jnp.array([1., 0.5])\n"
"\n"
"key = jr.key(0)\n"
"bm = dfx.VirtualBrownianTree(t0=0, t1=1, tol=1e-3, shape=(2,), key=key)\n" # noqa: E501
"terms = dfx.MultiTerm(dfx.ODETerm(drift), dfx.ControlTerm(diffusion, bm))\n" # noqa: E501
"solver = dfx.Euler()\n"
"y0 = jnp.array([1., 1.])\n"
"dfx.diffeqsolve(terms, solver, t0=0, t1=1, dt0=0.1, y0=y0)\n"
"```\n"
"In this case, the diffusion returns an array of shape `(2,)` and "
"the Brownian motion is of shape `(2,)`. By the rules of "
"`ControlTerm`, they are then dotted together so that the "
"diffusion term returns a scalar. Under previous versions of "
"Diffrax, this would then be broadcast out to both elements of the "
"evolving state `y`, corresponding to the SDE:\n"
"```\n"
"dy₁(t) = -y₁(t) dt + dW₁ + 0.5 dW₂\n"
"dy₂(t) = -y₂(t) dt + dW₁ + 0.5 dW₂\n"
"```\n"
"or the equivalent in vector notation, with `y(t), W(t) ⋹ R²`\n"
"```\n"
"dy(t) = -y(t) dt + [[1, 0.5], [1, 0.5]] dW\n"
"```\n"
"Which may have been unexpected! Quite possibly what was actually "
"intended was an SDE with diagonal noise:\n"
"```\n"
"dy(t) = -y(t) dt + [[1, 0], [0, 0.5]] dW\n"
"```\n"
"\n"
"As of Diffrax 0.7.0, the recommended way to express the "
f"{diffusion_phrase} is to use a Lineax linear operator. "
"(https://docs.kidger.site/lineax/api/operators/) For example, to "
"represent diagonal noise in the example above:\n"
"```python\n"
"import lineax as lx\n"
"\n"
"def diffusion(t, y, args):\n"
" diagonal = jnp.array([1., 0.5])\n"
" return lx.DiagonalLinearOperator(diagonal)\n"
"```\n"
)

if jtu.tree_structure(y) != jtu.tree_structure(out):
_raise()

def _check_shape(yi, out_i):
if jnp.shape(yi) != jnp.shape(out_i):
_raise()

jtu.tree_map(_check_shape, y, out)
return out

def to_ode(self) -> ODETerm:
r"""If the control is differentiable then $f(t, y(t), args) \mathrm{d}x(t)$
may be thought of as an ODE as

$f(t, y(t), args) \frac{\mathrm{d}x}{\mathrm{d}t}\mathrm{d}t$.

This method converts this `ControlTerm` into the corresponding
[`diffrax.ODETerm`][] in this way.
"""
vector_field = _ControlToODE(self)
return ODETerm(vector_field=vector_field)


ControlTerm.__init__.__doc__ = """**Arguments:**

- `vector_field`: A callable representing the vector field. This callable takes three
arguments `(t, y, args)`. `t` is a scalar representing the integration time. `y` is
the evolving state of the system. `args` are any static arguments as passed to
[`diffrax.diffeqsolve`][]. This `vector_field` can either be

1. a function that returns a PyTree of JAX arrays, or
2. it can return a
[Lineax linear operator](https://docs.kidger.site/lineax/api/operators),
as described above.

- `control`: The control. Should either be

1. a [`diffrax.AbstractPath`][], in which case its `.evaluate(t0, t1)` method
will be used to give the increment of the control over a time interval
`[t0, t1]`, or
2. a callable `(t0, t1) -> increment`, which returns the increment directly.
"""

class WeaklyDiagonalControlTerm(_AbstractControlTerm[_VF, _Control]):

def WeaklyDiagonalControlTerm(vector_field, control):
r"""
DEPRECATED. Prefer:

Expand All @@ -469,6 +590,9 @@ def vector_field(t, y, args):
diffrax.ControlTerm(vector_field, ...)
```

The current implementation is a backward-compatible shim that returns something like
the code snippet above.

---

A term representing the case of $f(t, y(t), args) \mathrm{d}x(t)$, in
Expand All @@ -492,45 +616,46 @@ def vector_field(t, y, args):
without the "weak". (This stronger property is useful in some SDE solvers.)
"""

def __check_init__(self):
warnings.warn(
"`WeaklyDiagonalControlTerm` is now deprecated, in favour combining "
"`ControlTerm` with a `lineax.AbstractLinearOperator`. This offers a way "
"to define a vector field with any kind of structure -- diagonal or "
"otherwise.\n"
"For a diagonal linear operator, then this can be easily converted as "
"follows. What was previously:\n"
"```\n"
"def vector_field(t, y, args):\n"
" ...\n"
" return some_vector\n"
"\n"
"diffrax.WeaklyDiagonalControlTerm(vector_field)\n"
"```\n"
"is now:\n"
"```\n"
"import lineax\n"
"\n"
"def vector_field(t, y, args):\n"
" ...\n"
" return lineax.DiagonalLinearOperator(some_vector)\n"
"\n"
"diffrax.ControlTerm(vector_field)\n"
"```\n"
"Lineax is available at `https://github.com/patrick-kidger/lineax`.\n",
stacklevel=3,
)

def prod(self, vf: _VF, control: _Control) -> Y:
with jax.numpy_dtype_promotion("standard"):
return jtu.tree_map(operator.mul, vf, control)
warnings.warn(
"`WeaklyDiagonalControlTerm` is now deprecated, in favour combining "
"`ControlTerm` with a `lineax.AbstractLinearOperator`. This offers a way "
"to define a vector field with any kind of structure -- diagonal or "
"otherwise.\n"
"For a diagonal linear operator, then this can be easily converted as "
"follows. What was previously:\n"
"```\n"
"def vector_field(t, y, args):\n"
" ...\n"
" return some_vector\n"
"\n"
"diffrax.WeaklyDiagonalControlTerm(vector_field)\n"
"```\n"
"is now:\n"
"```\n"
"import lineax\n"
"\n"
"def vector_field(t, y, args):\n"
" ...\n"
" return lineax.DiagonalLinearOperator(some_vector)\n"
"\n"
"diffrax.ControlTerm(vector_field)\n"
"```\n"
"Lineax is available at `https://github.com/patrick-kidger/lineax`.\n",
stacklevel=2,
)

def new_vector_field(t, y, args):
vf = vector_field(t, y, args)
return lx.DiagonalLinearOperator(vf)

return ControlTerm(new_vector_field, control)


class _ControlToODE(eqx.Module):
control_term: _AbstractControlTerm
control_term: ControlTerm

def __call__(self, t: RealScalarLike, y: Y, args: Args) -> Y:
control = self.control_term.control.derivative(t) # pyright: ignore
control = self.control_term.control.derivative(t)
return self.control_term.vf_prod(t, y, args, control)


Expand Down
Loading
Loading