Parametric control types #364

Merged
34 commits merged into patrick-kidger:dev from the parametric-types-variant branch on Feb 20, 2024

@tttc3
Contributor

tttc3 commented Jan 30, 2024

Pull request that closes #359.

Core Changes:

  • Change _ControlTerm to AbstractControlTerm[_Control];
  • Allow type hinting of the form ControlTerm[AbstractBrownianPath]. A demonstration is provided in _solver/milstein.py. Will all relevant solvers need updating?

Additional Changes:

  • Bumped ruff and pyright pre-commit versions;
  • Corrected typos in _solver/milstein.py.

Known Issues:

  • In test/test_term.py:133 the following type error is thrown (not sure why):
test/test_term.py:133:46 - error: Argument of type "(t0: Unknown, t1: Unknown) -> tuple[Array]" 
cannot be assigned to parameter "control" of type "_Control@ControlTerm" in function "__init__"
    Type "(t0: Unknown, t1: Unknown) -> tuple[Array]" cannot be assigned to type "AbstractPath"
      "function" is incompatible with "AbstractPath" (reportArgumentType)
  • In diffrax/_term.py:266 the following type error is being suppressed (as I can't find a nice way to avoid it):
diffrax/_term.py:266:14 - error: Cannot assign member "control" for type "AbstractControlTerm[_Control@AbstractControlTerm]*"
    Expression of type "AbstractPath | _CallableToPath" cannot be assigned to member "control" of class "AbstractControlTerm[_Control@AbstractControlTerm]"
      Member "__set__" is unknown
      Type "AbstractPath | _CallableToPath" cannot be assigned to type "_Control@AbstractControlTerm" (reportAttributeAccessIssue)

@patrick-kidger
Owner

I think this looks excellent!

Allow type hinting of the form ControlTerm[AbstractBrownianPath]. A demonstration is provided in _solver/milstein.py. Will all relevant solvers need updating?

So for this, I'm now thinking that what we really need is a way to indicate the type of the control directly in AbstractTerm, not just AbstractControlTerm. The fact that we had _ControlTerm was only ever just an implementation detail.
Put another way, from a user perspective, then ideally they should only need to know about AbstractTerm and its concrete subclasses; I don't want to introduce a whole abstract hierarchy!

It's a bit odd, but would something like this work do you think?

class AbstractTerm(eqx.Module, Generic[_Control]):
    ...  # doesn't actually use _Control

class _AbstractControlTerm(AbstractTerm[_Control]):
    control: _Control
    ...

(possibly with a pyright: ignore to prevent any complaints about an unused generic parameter?)

I'm not completely sold on the design here, so happy to hear alternate suggestions.

In test/test_term.py:133 the following type error is thrown (not sure why):

Ah, this is due to a quirk of __init__ methods and inheritance. This takes a small amount of explanation.

Dataclasses will create a new __init__ method on subclasses, even if an __init__ method already exists in the parent class. (I.e. dataclasses do not inherit __init__ methods.) So in this case pyright thinks that this is what is happening, and that ControlTerm.__init__ is being created just by looking at the fields... including control: _Control. (And not control: Union[AbstractPath, Callable].)

In practice we've found that this is a bit awkward ergonomically, so Equinox deliberately diverges from standard dataclass behaviour here, by just inheriting __init__ methods as normal if one is provided. (If you're curious, Equinox has precisely two divergences from dataclasses: the one above, and the fact that you can write self.foo = bar inside __init__ methods, rather than needing the object.__setattr__(self, "foo", bar) that regular frozen dataclasses require.)
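The standard-dataclass behaviour described here can be reproduced with the standard library alone. The following is a minimal sketch using hypothetical `Parent` and `Child` classes (neither Diffrax nor Equinox code): the child's generated `__init__` replaces the parent's hand-written one, which is the behaviour pyright assumes and which Equinox deliberately does not follow.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Parent:
    control: object

    def __init__(self, control):
        # Frozen dataclasses forbid `self.control = ...`, hence the
        # object.__setattr__ workaround.
        object.__setattr__(self, "control", ("wrapped", control))


@dataclass(frozen=True)
class Child(Parent):
    extra: int = 0


parent = Parent("path")
child = Child("path")

# Parent keeps its hand-written __init__, so the control is wrapped.
print(parent.control)
# Child gets a fresh field-based __init__, so the wrapping is skipped.
print(child.control)
```

A type checker that models `Child.__init__` from the fields alone is therefore correct for plain dataclasses, which is why the annotation on the field (rather than on the parent's `__init__`) drives the reported error.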

In diffrax/_term.py:266 the following type error is being suppressed (as I can't find a nice way to avoid it):

I've bumped into versions of this one before. It's an annoying case with generics + fields + __init__ methods that I've not found a nice way to resolve. I think adding a pyright: ignore annotation is indeed appropriate here.

@tttc3
Contributor Author

tttc3 commented Jan 31, 2024

Thank you for the information and explanations; they will save me from many future headaches.

Regarding the design of the type hinting for AbstractTerm: an alternative approach would be to type hint the output type of contr. For example, with the new UnsafeBrownianPath, you could do something like:

class AbstractTerm(eqx.Module, Generic[_Control]):
    def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control:
        ...

# Restrict to controls that evaluate to a LevyArea.
term: AbstractTerm[LevyArea]
# This would be valid.
term = ControlTerm(UnsafeBrownianPath(..., levy_area="space-time"))
# This would be invalid.
term = ControlTerm(UnsafeBrownianPath(...))

This approach is more permissive than hinting on the paths themselves. It is also closer to my original problem, where I don't really care what path the user provides, only that the control evaluates to a certain shape.

The downside is that you might need to define new types for distinguishing between different controls.

@patrick-kidger
Owner

Ah, that's a nice alternative solution. Indeed it does mean that there's no way to say something like 'this solver only accepts Brownian motions', since our existing Brownian paths just return normal JAX arrays. Another possible difficulty is that I'm not sure how easy this will be for either runtime or static type checkers to pick up on.

However on the plus side, it's something which really does make sense for AbstractTerm itself -- and arguably a point of Diffrax is that you shouldn't need to distinguish between different types of control, with our term system deliberately erasing any difference in types there. So I think on balance I like your suggestion better - let's go with that!

I think one change that we would want to introduce is an AbstractLevyVal, so that we can distinguish between the different kinds of Levy area that we might want. The goal I have in mind here is that if we're going to do parameterisation by control, then we should be able to replace this (from #344):

is_bm = lambda x: isinstance(x, AbstractBrownianPath)
leaves = jtu.tree_leaves(diffusion, is_leaf=is_bm)
paths = [x for x in leaves if is_bm(x)]
for path in paths:
    if sttla:
        if not path.levy_area == "space-time-time":
            raise ValueError(
                "The Brownian path controlling the diffusion "
                "should be initialised with `levy_area='space-time-time'`"
            )
    elif stla:
        if path.levy_area not in ["space-time", "space-time-time"]:
            raise ValueError(
                "The Brownian path controlling the diffusion "
                "should be initialised with `levy_area='space-time'` "
                "or `levy_area='space-time-time'`"
            )

with something more principled.

@tttc3
Contributor Author

tttc3 commented Feb 1, 2024

Implemented so far:

  • Generic typing of AbstractTerm.contr
  • Generic typing of AbstractPath.evaluate and AbstractPath.derivative.

Remaining:

  • Abstraction of LevyVal.

Known Issues:

  • test/test_term.py:135 from the original comment.
  • Unfortunately it doesn't appear that any of mypy, pyright or beartype can spot the type violation in the example test/test_term.py:67. I've tested with reveal_type and the Generic term remains as Unknown.

@patrick-kidger
Owner

Unfortunately it doesn't appear that any of mypy, pyright or beartype can spot the type violation in the example test/test_term.py:67.

It's not obvious to me what the error is supposed to be here? ControlTwo() has its return type annotated as LevyVal, which is consistent with the term_typed: diffrax.ControlTerm[diffrax.LevyVal] annotation.

For what it's worth, when it comes to tackling the fact that our Brownian motions switch their return type based on an __init__ argument, I think I've managed to figure out how to get pyright to check that (by adding some dummy annotations for __new__):

from typing import cast, Generic, Literal, overload, TypeVar, TYPE_CHECKING, Union

T = TypeVar("T")

class Foo(Generic[T]):
    def meth(self) -> T:
        ...

class Bar(Foo[T]):
    if TYPE_CHECKING:
        @overload
        def __new__(cls, arg: Literal["use_int"]) -> "Bar[int]":
            ...

        @overload
        def __new__(cls, arg: Literal["use_str"]) -> "Bar[str]":
            ...

        def __new__(cls, arg: Literal["use_int", "use_str"]):
            if arg == "use_int":
                return cast(Bar[int], None)
            elif arg == "use_str":
                return cast(Bar[str], None)
            else:
                assert False

    def __init__(self, arg: Literal["use_int", "use_str"]):
        self.arg = arg

    def meth(self) -> T:
        if self.arg == "use_int":
            return 2  # pyright: ignore
        elif self.arg == "use_str":
            return "hi"  # pyright: ignore
        else:
            assert False

def f(x: Foo[int]):
    pass

f(Bar("use_str"))  # this produces an error with pyright!

@tttc3
Contributor Author

tttc3 commented Feb 3, 2024

So I've managed to avoid the error in test/test_term.py:136 by utilising a property that calls _callable_to_path on the _control attribute.

I've also simplified the type checking tests and added a placeholder for beartype runtime type checking, which can be enabled if beartype/beartype#238 is implemented. Without this feature there will be no way to perform checking of any control typed as "ControlTerm[PyTree[...]]", as static checking evaluates such hints to Unknown.
That said, the hints will still be useful as documentation.

I like your trick for the init-argument-dependent hints!

Is there any reason why use_levy and levy_area would both need to exist if we had types like the following?

class BrownianIncrement(eqx.Module):
    dt: ...
    W: ...

class SpaceTimeLevyArea(BrownianIncrement):
    H: ...
    bar_H: ...
    K: ...
    bar_K: ...

Maybe not exactly like this as they don't obey concrete is final?

Owner

@patrick-kidger patrick-kidger left a comment

Okay, I think this is starting to come together :D

Regarding having both use_levy and levy_area, this is to decouple two things:
(a) whether the Brownian motion is capable of generating Levy area;
(b) whether the solver would like to request Levy area.

In particular we'd like to be able to support code like this:

bm = VirtualBrownianTree(..., levy_area="space-time")
term = ControlTerm(..., bm)
if foo:
    solver = Euler()  # no levy area
else:
    solver = SomeKindOfSRK()  # uses levy area
diffeqsolve(term, solver, ...)

@patrick-kidger
Owner

Oh for completeness -- regarding the BrownianIncrement and SpaceTimeLevyArea types, this is how you'd write them in a concrete-means-final way:

class AbstractBrownianReturn(eqx.Module):
    dt: AbstractVar[...]
    W: AbstractVar[...]

class BrownianIncrement(AbstractBrownianReturn):
    dt: ...
    W: ...

class SpaceTimeLevyArea(AbstractBrownianReturn):
    dt: ...
    W: ...
    H: ...
    bar_H: ...
    K: ...
    bar_K: ...

If need be, you could also introduce an AbstractLevyArea that sits between SpaceTimeLevyArea and AbstractBrownianReturn in the type hierarchy, if you wanted to additionally have another kind of FooLevyArea type.

(Or you could keep just AbstractBrownianReturn, and then use an ad-hoc supertype like Union[SpaceTimeLevyArea, FooLevyArea] rather than a named supertype like AbstractLevyArea -- this is particularly useful in cases where you don't control the type hierarchy because it's in someone else's library.)
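As a rough illustration of the ad-hoc supertype option, here is a minimal sketch using empty stand-in classes. The class names mirror the ones discussed here but are not the Diffrax implementations, and `AnyLevyArea` and `has_levy_area` are hypothetical names introduced only for this example.

```python
from typing import Union, get_args


class SpaceTimeLevyArea:  # stand-in, not the Diffrax class
    pass


class FooLevyArea:  # hypothetical second kind of Levy area
    pass


class BrownianIncrement:  # stand-in for a return type with no Levy area
    pass


# Ad-hoc supertype: no shared base class is required.
AnyLevyArea = Union[SpaceTimeLevyArea, FooLevyArea]


def has_levy_area(value: object) -> bool:
    # get_args unpacks the Union into a tuple of classes for isinstance.
    return isinstance(value, get_args(AnyLevyArea))


print(has_levy_area(SpaceTimeLevyArea()))
print(has_levy_area(BrownianIncrement()))
```

Unpacking the union with `get_args` keeps the runtime check working on Python versions where `isinstance` does not accept a `Union` directly.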

@patrick-kidger
Owner

Also also, would it be possible to add checking for this type into the existing runtime term_structure check here:

def _term_compatible(terms, term_structure):

?

Probably by doing a jax.eval_shape(term.contr, ...) and checking the return type.

That way we can have diffeqsolve enforce the validity of term_structure (and we don't depend on beartype to do so).

@patrick-kidger patrick-kidger mentioned this pull request Feb 6, 2024
Merged
@patrick-kidger
Owner

(let me know when you next want my input on this PR by the way)

@tttc3
Contributor Author

tttc3 commented Feb 11, 2024

I think almost everything should be there now. The primary additions since the last review are:

  • Corrections to directly address the review comments;
  • The addition of a control check in _integrate.py, along with a dedicated test in test_integrate.py;
  • The addition of AbstractLevyArea and the subclasses TimeLevyArea and SpaceTimeLevyArea.
  • Updated Brownian paths to expect one of the new AbstractLevyArea types for the levy_area.

TimeLevyArea might be better named BrownianIncrement. However, my reservation is that even if the path evaluates to a PyTree[Array], semantically, we still expect this array to represent a Brownian increment of the path?

Owner

@patrick-kidger patrick-kidger left a comment

That's great! I've just done a review. You can see my main comments are really just about a few points of either programming style (single-assignment, explicit bools in if statements) or PR style (try to minimise diffs so that they're easy for me to review!)

In other words, they're nits and I like how this looks overall :)

One question -- how do you want to tackle this alongside #370? I'm thinking we probably want the VF type to appear in the signature before the control type, e.g. AbstractTerm[_VF, _Control].

@tttc3
Contributor Author

tttc3 commented Feb 12, 2024

Regarding #370, I'm happy to include vector_field typing in this pull request, and agree with your suggested format for AbstractTerm[_VF, _Control].

@tttc3 tttc3 force-pushed the parametric-types-variant branch from ffd14a2 to f607ebf Compare February 16, 2024 20:40
@patrick-kidger patrick-kidger changed the base branch from main to dev February 16, 2024 23:46
@patrick-kidger
Owner

Heads-up that I've just (a) merged #378 into dev, and (b) switched the target branch on this PR to dev. The conflicts now being reported are just the changes from that PR -- I think just rebase on top of dev and things should work.

@tttc3
Contributor Author

tttc3 commented Feb 18, 2024

I've added comments to, and attempted to simplify the logic in, the term_compatibility check. I've also minimised the diffs where possible and extended the tests to cover the types for which better_isinstance is needed. Provided the new tests are sufficient, I think this is now (hopefully) feature complete.

@tttc3 tttc3 force-pushed the parametric-types-variant branch from d359af8 to b0b18e4 Compare February 19, 2024 10:31
@patrick-kidger
Owner

Okay, I think I'm happy with this PR now! I can see what's going on with the failing pre-commit checks: this is due to bumping the version of pyright and hitting an incompatibility with Equinox's abstract variables. I'll submit a PR against yours fixing this.

@patrick-kidger
Owner

patrick-kidger commented Feb 19, 2024

Okay, take a look at the parametric-tweaks branch I've just pushed. It's not 100% ready because (a) it's completely untested -- I gotta dash for now -- and (b) I want to add some tests for the tweaks I've made to the term compatibility. To explain what's going on here:

  • I don't think the single-argument return from get_args is ever possible; Python demands that all type annotations be filled.
  • However, they can be filled with TypeVars, so I've added a check for that. EDIT: actually, this ended up being way more complicated due to subclassing changing the number of type variables.
  • The y should be tree-mapped alongside the terms. This is admittedly an entirely undocumented point that I need to make clearer somewhere at some point. (But take a look at SemiImplicitEuler for an example of this.)

@patrick-kidger
Owner

Aaaaand opened tttc3#1 with these changes.

patrick-kidger and others added 2 commits February 20, 2024 04:25
Phew, this ended up being a complicated one!

Let's start with the easy stuff:
- Disabled spurious pyright errors due to an incompatibility between pyright and `eqx.AbstractVar`.
- Now using ruff.lint and pinned exact typeguard version.

Now on to the hard stuff:
- Fixed term compatibility missing some edge cases.

Edge cases? What edge cases? Well, what we had before was basically predicated around doing
```python
vf, contr = get_args(term_cls)
```
recalling that we may have e.g. `term_cls = AbstractTerm[SomeVectorField, SomeControl]`. So far so simple: get the arguments of a subscripted generic, no big deal.

What this failed to account for is that we may also have subclasses of this generic, e.g. `term_cls = ODETerm[SomeVectorField]`, such that some of the type variables have already been filled in when defining it:
```python
class ODETerm(AbstractTerm[_VF, RealScalarLike]): ...
```
so in this case, `get_args(term_cls)` simply returns a 1-tuple of `(SomeVectorField,)`. Oh no! Somehow we have to traverse both the filled-in type variables (to find that one of our type variables is `SomeVectorField` due to subscripting) *and* the type hierarchy (to figure out that the other type variable was filled in during the definition).

Once again, for clarity: given a subscriptable base class `AbstractTerm[_VF, _Control]` and some arbitrary possibly-subscripted subclass, we need to find the values of `_VF` and `_Control`, regardless of whether they have been passed in via subscripting the final class (and are `get_args`-able) or have been filled in during subclassing (and require traversing pseudo-type-hierarchies of `__orig_bases__`).

Any sane implementation would simply... not bother. There is no way that the hassle of figuring this out was going to be worth the small amount of type safety this brings...

So anyway, after a few hours working on this *far* past the point I should be going to sleep, this problem is now solved. This PR introduces a new `get_args_of` function, called as `get_args_of(superclass, subclass, error_msg_if_necessary)`. This acts analogously to `get_args`, but instead of looking up both parameters (the type variables we want filled in) and the arguments (the values those type variables have been filled in with) on the same class, it looks up the parameters on the superclass, and their filled-in values on the subclass. Pure madness.

(I'm also tagging @leycec here because this is exactly the kind of insane typing hackery that he seems to really enjoy.)

Does anyone else remember the days when this was a package primarily concerned about solving differential equations?
Parameterised terms: fixed term compatibility + spurious pyright errors
@patrick-kidger patrick-kidger merged commit 89980d0 into patrick-kidger:dev Feb 20, 2024
2 checks passed
@patrick-kidger
Owner

And merged! Gosh, that ended up being much harder than we all thought it would be, huh?

Thank you for the contribution, and great work putting this all together. I'm really happy to have this in, it's a huge improvement to the codebase. :)

@tttc3 tttc3 deleted the parametric-types-variant branch February 20, 2024 17:22
@tttc3
Contributor Author

tttc3 commented Feb 24, 2024

I just noticed that at some point we lost the reportUnnecessaryTypeIgnoreComment = true setting from the pyproject.toml, so the contrapositive static type checking tests in test_terms aren't actually doing anything. Also, the commented-out lines 53-73 in test_terms can probably be removed now that equivalent handling is provided directly in the term_compatibility checks?

@patrick-kidger
Owner

Ah, agreed throughout. Can you send a PR?

patrick-kidger added a commit that referenced this pull request Apr 20, 2024
* Added parametric control types

* Demo of AbstractTerm change

* Correct initial parametric control type implementation.

* Parametric AbstractTerm initial implementation.

* Update tests and fix hinting

* Implement review comments.

* Add parametric control check to integrator.

* Update and test parametric control check

* Introduce new LevyArea types

* Updated Brownian path LevyArea types

* Replace Union types in isinstance checks

* Remove rogue comment

* Revert _brownian_arch to single assignment

* Revert _evaluate_leaf key splitting

* Rename variables in test_term

* Update isinstance and issubclass checks

* Safer handling in _denormalise_bm_inc

* Fix style in integrate control type check

* Add draft vector_field typing

* Add draft vector_field typing

* Fix term test

* Revert extemporaneous modifications in _tree

* Rename TimeLevyArea to BrownianIncrement and simplify diff

* Rename AbstractLevyReturn to AbstractBrownianReturn

* Rename _LevyArea to _BrownianReturn

* Enhance _term_compatiblity checks

* Fix merge issues

* Bump pre-commit and fix type hints

* Clean up from self-review

* Explicitly add typeguard to deps

* Bump ruff config to new syntax

* Parameterised terms: fixed term compatibility + spurious pyright errors

---------

Co-authored-by: Patrick Kidger
patrick-kidger added a commit that referenced this pull request May 19, 2024
@@ -23,7 +23,7 @@ classifiers = [
 "Topic :: Scientific/Engineering :: Mathematics",
 ]
 urls = {repository = "https://github.com/patrick-kidger/diffrax" }
-dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"]
+dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"]

Why is typeguard pinned to an exact version, especially one that is so old?

It's generally considered very bad packaging practice to pin to exact versions as it's ~never required and prevents getting even security patches without publishing a new version.

Also, it's ancient and won't be getting any updates - current HEAD is 4.2.1.

@dhirschfeld dhirschfeld mentioned this pull request May 19, 2024