Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parametric control types #364

Merged
merged 34 commits into from
Feb 20, 2024
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
87b5201
Added parametric control types
tttc3 Jan 30, 2024
c05054c
Demo of AbstractTerm change
tttc3 Jan 30, 2024
61e97db
Correct initial parametric control type implementation.
tttc3 Jan 30, 2024
8a504b8
Parametric AbstractTerm initial implementation.
tttc3 Feb 1, 2024
1e20f33
Update tests and fix hinting
tttc3 Feb 3, 2024
cd23a0b
Implement review comments.
tttc3 Feb 6, 2024
6dbda14
Add parametric control check to integrator.
tttc3 Feb 9, 2024
813e85d
Update and test parametric control check
tttc3 Feb 9, 2024
267a41a
Introduce new LevyArea types
tttc3 Feb 11, 2024
9577d43
Updated Brownian path LevyArea types
tttc3 Feb 11, 2024
2e4fc83
Replace Union types in isinstance checks
tttc3 Feb 12, 2024
a744b9d
Remove rogue comment
tttc3 Feb 12, 2024
589f26a
Revert _brownian_arch to single assignment
tttc3 Feb 12, 2024
9e3021b
Revert _evaluate_leaf key splitting
tttc3 Feb 12, 2024
f1aceea
Rename variables in test_term
tttc3 Feb 12, 2024
d9e874a
Update isinstance and issubclass checks
tttc3 Feb 12, 2024
2661b35
Safer handling in _denormalise_bm_inc
tttc3 Feb 12, 2024
fb572d6
Fix style in integrate control type check
tttc3 Feb 13, 2024
e5f15e6
Add draft vector_field typing
tttc3 Feb 16, 2024
787ceee
Add draft vector_field typing
tttc3 Feb 16, 2024
3249ff9
Fix term test
tttc3 Feb 16, 2024
3c834f1
Revert extemporaneous modifications in _tree
tttc3 Feb 16, 2024
f607ebf
Rename TimeLevyArea to BrownianIncrement and simplify diff
tttc3 Feb 16, 2024
80fd358
Rename AbstractLevyReturn to AbstractBrownianReturn
tttc3 Feb 16, 2024
796b40f
Rename _LevyArea to _BrownianReturn
tttc3 Feb 16, 2024
ed073d3
Enhance _term_compatiblity checks
tttc3 Feb 17, 2024
3fd7f3f
Merge branch 'dev' into parametric-types-variant
tttc3 Feb 17, 2024
50f3363
Fix merge issues
tttc3 Feb 18, 2024
5ef24fa
Bump pre-commit and fix type hints
tttc3 Feb 18, 2024
b9fc23f
Clean up from self-review
tttc3 Feb 18, 2024
06fd6a9
Explicitly add typeguard to deps
tttc3 Feb 18, 2024
b0b18e4
Bump ruff config to new syntax
tttc3 Feb 19, 2024
3791d6d
Parameterised terms: fixed term compatibility + spurious pyright errors
patrick-kidger Feb 19, 2024
c08202c
Merge pull request #1 from patrick-kidger/parametric-tweaks
tttc3 Feb 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Parameterised terms: fixed term compatibility + spurious pyright errors
Phew, this ended up being a complicated one!

Let's start with the easy stuff:
- Disabled spurious pyright errors due to incompatible between pyright and `eqx.AbstractVar`.
- Now using ruff.lint and pinned exact typeguard version.

Now on to the hard stuff:
- Fixed term compatibibility missing some edge cases.

Edge cases? What edge cases? Well, what we had before was basically predicated around doing
```python
vf, contr = get_args(term_cls)
```
recalling that we may have e.g. `term_cls = AbstractTerm[SomeVectorField, SomeControl]`. So far so simple: get the arguments of a subscripted generic, no big deal.

What this failed to account for is that we may also have subclasses of this generic, e.g. `term_cls = ODETerm[SomeVectorField]`, such that some of the type variables have already been filled in when defining it:
```python
class ODETerm(AbstractTerm[_VF, RealScaleLike]): ...
```
so in this case, `get_args(term_cls)` simply returns a 1-tuple of `(SomeVectorField,)`. Oh no! Somehow we have to traverse both the filled-in type variables (to find that one of our type variables is `SomeVectorField` due to subscripting) *and* the type hierarchy (to figure out that the other type variable was filled in during the definition).

Once again, for clarity: given a subscriptable base class `AbstractTerm[_VF, _Control]` and some arbitrary possible-subscripted subclass, we need to find the values of `_VF` and `_Control`, regardless of whehther they have been passed in via subscripting the final class (and are `get_args`-able) or have been filled in during subclassing (and require traversing pseudo-type-hierarchies of `__orig_bases__`).

Any sane implementation would simply... not bother. There is no way that the hassle of figuring this out was going to be worth the small amount of type safety this brings...

So anyway, after a few hours working on this *far* past the point I should be going to sleep, this problem this is now solved. This PR introduces a new `get_args_of` function, called as `get_args_of(superclass, subclass, error_msg_if_necessary)`. This acts analogous to `get_args`, but instead of looking up both parameters (the type variables we want filled in) and the arguments (the values those type variables have been filled in with) on the same class, it looks up the parameters on the superclass, and their filled-in-values on the subclass. Pure madness.

(I'm also tagging @leycec here because this is exactly the kind of insane typing hackery that he seems to really enjoy.)

Does anyone else remember the days when this was a package primarily concerned about solving differential equations?
patrick-kidger committed Feb 20, 2024
commit 3791d6d5e552a38a8623c5bf5749474a0766609d
7 changes: 5 additions & 2 deletions diffrax/_adjoint.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import abc
import functools as ft
import warnings
from collections.abc import Iterable
from typing import Any, Optional, Union
from collections.abc import Callable, Iterable
from typing import Any, cast, Optional, Union

import equinox as eqx
import equinox.internal as eqxi
@@ -20,6 +20,9 @@
from ._term import AbstractTerm, AdjointTerm


ω = cast(Callable, ω)


def _is_none(x):
return x is None

5 changes: 4 additions & 1 deletion diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import functools as ft
from collections.abc import Callable
from typing import Optional, TYPE_CHECKING
from typing import cast, Optional, TYPE_CHECKING

import equinox as eqx
import equinox.internal as eqxi
@@ -24,6 +24,9 @@
from ._path import AbstractPath


ω = cast(Callable, ω)


class AbstractGlobalInterpolation(AbstractPath):
ts: AbstractVar[Real[Array, " times"]]
ts_size: AbstractVar[IntScalarLike]
172 changes: 76 additions & 96 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
@@ -2,7 +2,14 @@
import typing
import warnings
from collections.abc import Callable
from typing import Any, get_args, get_origin, Optional, Tuple, TYPE_CHECKING
from typing import (
Any,
get_args,
get_origin,
Optional,
Tuple,
TYPE_CHECKING,
)

import equinox as eqx
import equinox.internal as eqxi
@@ -11,7 +18,6 @@
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
import typeguard
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
@@ -50,6 +56,7 @@
StepTo,
)
from ._term import AbstractTerm, MultiTerm, ODETerm, WrapTerm
from ._typing import better_isinstance, get_args_of, get_origin_no_specials


class SaveState(eqx.Module):
@@ -79,40 +86,6 @@ class State(eqx.Module):
progress_meter_state: PyTree[Array]


def _better_isinstance(x, annotation) -> bool:
"""isinstance check for parameterized generics.

!!! Example
```python
x = (1, jnp.array([2.0]))
y = ("test", Float[Array, "foo"])
expected_type = tuple[int, Float[Array, "foo"]]

assert _better_isinstance(x, expected_type) # passes as expected.
assert _better_isinstance(y, expected_type) # raises AssertionError as expected.

# Whereas with isinstance:
assert isinstance(x, expected_type) # raises TypeError.
assert isinstance(y, expected_type) # raises TypeError.
```
"""

@typeguard.typechecked
def f(y: annotation):
pass

try:
f(x)
except TypeError:
return False
else:
return True


def _assert_term_arg(x, term_arg) -> bool:
return _better_isinstance(x, term_arg)


def _is_none(x: Any) -> bool:
return x is None

@@ -123,52 +96,59 @@ def _term_compatible(
terms: PyTree[AbstractTerm],
term_structure: PyTree,
) -> bool:
def _check(term_cls, term):
if get_origin(term_cls) is MultiTerm:
error_msg = "term_structure"

def _check(term_cls, term, yi):
if get_origin_no_specials(term_cls, error_msg) is MultiTerm:
if isinstance(term, MultiTerm):
[_tmp] = get_args(term_cls)
assert get_origin(_tmp) in (tuple, Tuple), "Malformed term_structure"
if _term_compatible(y, args, term.terms, get_args(_tmp)):
return
raise ValueError
assert len(term.terms) == len(get_args(_tmp))
for term, arg in zip(term.terms, get_args(_tmp)):
if not _term_compatible(yi, args, term, arg):
raise ValueError
else:
raise ValueError
else:
# Check that `term` is an instance of `term_cls` (ignoring any generic
# parameterization).
origin_cls = get_origin(term_cls)
_term_cls = origin_cls if origin_cls is not None else term_cls
if not isinstance(term, _term_cls):
origin_cls = get_origin_no_specials(term_cls, error_msg)
if origin_cls is None:
origin_cls = term_cls
if not isinstance(term, origin_cls):
raise ValueError

# Now check the generic parametrization of `term_cls`; can be one of:
# -----------------------------------------
# `term_cls` | `term_args`
# --------------------------|--------------
# AbstractTerm | ()
# AbstractTerm[VF] | (VF,)
# AbstractTerm[VF, Control] | (VF, Control)
# -----------------------------------------
term_args = get_args(term_cls)
term_args = get_args_of(AbstractTerm, term_cls, error_msg)
n_term_args = len(term_args)
if n_term_args >= 1:
vf_type = eqx.filter_eval_shape(term.vf, 0.0, y, args)
vf_type_expected = term_args[0]
if n_term_args == 0:
pass
elif n_term_args == 2:
vf_type_expected, control_type_expected = term_args
vf_type = eqx.filter_eval_shape(term.vf, 0.0, yi, args)
vf_type_compatible = eqx.filter_eval_shape(
_assert_term_arg, vf_type, vf_type_expected
better_isinstance, vf_type, vf_type_expected
)
if not vf_type_compatible:
raise ValueError
if n_term_args == 2:
control_type = jax.eval_shape(term.contr, 0.0, 0.0)
control_type_expected = term_args[1]
control_type_compatible = eqx.filter_eval_shape(
_assert_term_arg, control_type, control_type_expected
better_isinstance, control_type, control_type_expected
)
if not control_type_compatible:
raise ValueError
return # If we've got to this point then the term is compatible
else:
assert False, "Malformed term structure"
# If we've got to this point then the term is compatible

try:
jtu.tree_map(_check, term_structure, terms)
jtu.tree_map(_check, term_structure, terms, y)
except ValueError:
# ValueError may also arise from mismatched tree structures
return False
@@ -732,47 +712,6 @@ def diffeqsolve(
stacklevel=2,
)

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
) and _term_compatible(y0, args, terms, (ODETerm, AbstractTerm)):
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general and SDE-specific "
"solvers!",
stacklevel=2,
)
terms = MultiTerm(*terms)

# Error checking
if not _term_compatible(y0, args, terms, solver.term_structure):
raise ValueError(
"`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
f"structure {solver.term_structure}"
)

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
warnings.warn(
f"`{type(solver).__name__}` is not marked as converging to either the "
"Itô or the Stratonovich solution.",
stacklevel=2,
)
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
if isinstance(solver, Euler):
raise ValueError(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
)
if is_unsafe_sde(terms):
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
raise ValueError(
"`UnsafeBrownianPath` cannot be used with adaptive step sizes."
)

# Allow setting e.g. t0 as an int with dt0 as a float.
timelikes = [t0, t1, dt0] + [
s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)
@@ -816,6 +755,47 @@ def _promote(yi):
y0 = jtu.tree_map(_promote, y0)
del timelikes

# Backward compatibility
if isinstance(
solver, (EulerHeun, ItoMilstein, StratonovichMilstein)
) and _term_compatible(y0, args, terms, (ODETerm, AbstractTerm)):
warnings.warn(
"Passing `terms=(ODETerm(...), SomeOtherTerm(...))` to "
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general and SDE-specific "
"solvers!",
stacklevel=2,
)
terms = MultiTerm(*terms)

# Error checking
if not _term_compatible(y0, args, terms, solver.term_structure):
raise ValueError(
"`terms` must be a PyTree of `AbstractTerms` (such as `ODETerm`), with "
f"structure {solver.term_structure}"
)

if is_sde(terms):
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
warnings.warn(
f"`{type(solver).__name__}` is not marked as converging to either the "
"Itô or the Stratonovich solution.",
stacklevel=2,
)
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
if isinstance(solver, Euler):
raise ValueError(
"An SDE should not be solved with adaptive step sizes with Euler's "
"method, as it may not converge to the correct solution."
)
if is_unsafe_sde(terms):
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
raise ValueError(
"`UnsafeBrownianPath` cannot be used with adaptive step sizes."
)

# Normalises time: if t0 > t1 then flip things around.
direction = jnp.where(t0 < t1, 1, -1)
t0 = t0 * direction
6 changes: 5 additions & 1 deletion diffrax/_local_interpolation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import Optional, TYPE_CHECKING
from collections.abc import Callable
from typing import cast, Optional, TYPE_CHECKING

import jax.numpy as jnp
import jax.tree_util as jtu
@@ -17,6 +18,9 @@
from ._path import AbstractPath


ω = cast(Callable, ω)


class AbstractLocalInterpolation(AbstractPath):
pass

5 changes: 4 additions & 1 deletion diffrax/_root_finder/_verychord.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections.abc import Callable
from typing import Any
from typing import Any, cast

import equinox as eqx
import jax
@@ -15,6 +15,9 @@
from .._custom_types import Y


ω = cast(Callable, ω)


def _small(diffsize: Scalar) -> Bool[Array, ""]:
# TODO(kidger): make a more careful choice here -- the existence of this
# function is pretty ad-hoc.
6 changes: 5 additions & 1 deletion diffrax/_solver/sil3.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from typing import ClassVar
from collections.abc import Callable
from typing import cast, ClassVar

import numpy as np
import optimistix as optx
@@ -15,6 +16,9 @@
)


ω = cast(Callable, ω)


# See
# https://docs.kidger.site/diffrax/devdocs/predictor_dirk/
# for the construction of the a_predictor tableau, which is new here.
3 changes: 3 additions & 0 deletions diffrax/_step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
@@ -33,6 +33,9 @@
from .base import AbstractStepSizeController


ω = cast(Callable, ω)


def _select_initial_step(
terms: PyTree[AbstractTerm],
t0: RealScalarLike,
2 changes: 1 addition & 1 deletion diffrax/_step_size_controller/constant.py
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ def adapt_step_size(
y1_candidate: Y,
args: Args,
y_error: Optional[Y],
error_order: RealScalarLike,
error_order: Optional[RealScalarLike],
controller_state: RealScalarLike,
) -> tuple[bool, RealScalarLike, RealScalarLike, bool, RealScalarLike, RESULTS]:
del t0, y0, y1_candidate, args, y_error, error_order
Loading