Parametric control types #364
I think this looks excellent!
So for this, I'm now thinking that what we really need is a way to indicate the type of the control directly in It's a bit odd, but would something like this work do you think?

```python
class AbstractTerm(eqx.Module, Generic[_Control]):
    ...  # doesn't actually use _Control

class _AbstractControlTerm(AbstractTerm[_Control]):
    control: _Control
    ...
```

(possibly with a I'm not completely sold on the design here, so happy to hear alternate suggestions.
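For readers who want to try the idea outside the library, here is a minimal sketch of a term that is generic over its control type. It assumes plain dataclasses in place of `eqx.Module`, and `ControlTermSketch` and `describe` are hypothetical names used only for illustration; this is not the Diffrax implementation.

```python
from dataclasses import dataclass
from typing import Generic, TypeVar, get_args

_Control = TypeVar("_Control")


class AbstractTerm(Generic[_Control]):
    """Stand-in base class; `_Control` is only a type-level label here."""


@dataclass(frozen=True)
class ControlTermSketch(AbstractTerm[_Control]):
    # Hypothetical concrete term that stores its control directly.
    control: _Control


def describe(term: AbstractTerm[float]) -> str:
    # A static checker can reject terms whose control is not a float;
    # the annotation alone enforces nothing at runtime.
    return type(term).__name__


term = ControlTermSketch(control=0.5)
# The subscripted alias remembers the control type it was given.
control_type = get_args(AbstractTerm[float])[0]
```

The point of the sketch is that the base class never touches `_Control` itself; the parameter only exists so that annotations such as `AbstractTerm[float]` carry the control type.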
Ah, this is due to a quirk of Dataclasses will create a new In practice we've found that this is a bit awkward ergonomically, so Equinox deliberately diverges from standard dataclass behaviour here, by just inheriting
I've bumped into versions of this one before. It's an annoying case with generics + fields + |
Thank you for the information, and explanations, they will save me from many future headaches. Regarding the design of the type hinting for

```python
class AbstractTerm(eqx.Module, Generic[_Control]):
    def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> _Control:
        ...

# Restrict to control which evaluate to provide a LevyArea.
term: AbstractTerm[LevyArea]
# This would be valid
term = ControlTerm(UnsafeBrownianPath(..., levy_area="space-time"))
# This would be invalid.
term = ControlTerm(UnsafeBrownianPath(...))
```

This approach is more permissive than hinting on the paths themselves. It is also closer to my original problem, where I don't really care what path the user provides, only that the control evaluates to a certain shape. The downside is that you might need to define new types for distinguishing between different controls. |
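As a rough illustration of typing the term by what `contr` returns, the sketch below pairs the static annotation with a runtime `isinstance` check. Everything in it is an assumption made for the example: `IncrementSketch`, `LevyAreaSketch`, `TermSketch` and `levy_term_of` are hypothetical, and plain dataclasses stand in for the library's modules.

```python
from dataclasses import dataclass
from typing import Callable, Generic, TypeVar

_Control = TypeVar("_Control")


@dataclass(frozen=True)
class IncrementSketch:
    # Hypothetical control carrying only the increment W.
    W: float


@dataclass(frozen=True)
class LevyAreaSketch(IncrementSketch):
    # Hypothetical control that additionally carries a Levy-area-like value H.
    H: float


@dataclass(frozen=True)
class TermSketch(Generic[_Control]):
    # The path is any callable; only the type it evaluates to matters.
    path: Callable[[float, float], _Control]

    def contr(self, t0: float, t1: float) -> _Control:
        return self.path(t0, t1)


def levy_term_of(term: "TermSketch[LevyAreaSketch]", t0: float, t1: float) -> float:
    # Static checkers read the annotation; this check covers the runtime side.
    control = term.contr(t0, t1)
    if not isinstance(control, LevyAreaSketch):
        raise TypeError("this consumer needs a control that provides H")
    return control.H


def plain_path(t0: float, t1: float) -> IncrementSketch:
    return IncrementSketch(W=t1 - t0)


def levy_path(t0: float, t1: float) -> LevyAreaSketch:
    return LevyAreaSketch(W=t1 - t0, H=0.0)
```

Note that the consumer never inspects the path itself, which matches the point above: any path is acceptable as long as the control it produces has the required type.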
Ah, that's a nice alternative solution. Indeed it does mean that there's no way to say something like 'this solver only accepts Brownian motions', since our existing Brownian paths just return normal JAX arrays. Another possible difficulty is that I'm not sure how easy this will be for either runtime or static type checkers to pick up on. However on the plus side, it's something which really does make sense for I think one change that we would want to introduce is an diffrax/diffrax/_solver/srk.py Lines 285 to 301 in 59778bd
with something more principled. |
Implemented so far:
Remaining:
Known Issues:
|
It's not obvious to me what the error is supposed to be here? For what it's worth, when it comes to tackling the fact that our Brownian motions switch their return type based on an

```python
from typing import cast, Generic, Literal, overload, TypeVar, TYPE_CHECKING, Union

T = TypeVar("T")

class Foo(Generic[T]):
    def meth(self) -> T:
        ...

class Bar(Foo[T]):
    if TYPE_CHECKING:
        @overload
        def __new__(cls, arg: Literal["use_int"]) -> "Bar[int]":
            ...
        @overload
        def __new__(cls, arg: Literal["use_str"]) -> "Bar[str]":
            ...
        def __new__(cls, arg: Literal["use_int", "use_str"]):
            if arg == "use_int":
                return cast(Bar[int], None)
            elif arg == "use_str":
                return cast(Bar[str], None)
            else:
                assert False

    def __init__(self, arg: Literal["use_int", "use_str"]):
        self.arg = arg

    def meth(self) -> T:
        if self.arg == "use_int":
            return 2  # pyright: ignore
        elif self.arg == "use_str":
            return "hi"  # pyright: ignore
        else:
            assert False

def f(x: Foo[int]):
    pass

f(Bar("use_str"))  # this produces an error with pyright!
```
|
So I've managed to avoid the error in I've also simplified the type checking tests and added a placeholder for beartype runtime type checking, which can be enabled if beartype/beartype#238 is implemented. Without this feature there will be no way to perform checking of any control typed as "ControlTerm[PyTree[...]]", as static checking evaluates such hints to I like your trick for the init argument dependent hints! Is there any reason why

```python
class BrownianIncrement(eqx.Module):
    dt: ...
    W: ...

class SpaceTimeLevyArea(BrownianIncrement):
    H: ...
    bar_H: ...
    K: ...
    bar_k: ...
```

Maybe not exactly like this as they don't obey concrete is final? |
Okay, I think this is starting to come together :D
Regarding having both `use_levy` and `levy_area`, this is to decouple two things:
(a) whether the Brownian motion is capable of generating Levy area;
(b) whether the solver would like to request Levy area.
In particular we'd like to be able to support code like this:
```python
bm = VirtualBrownianTree(..., levy_area="space-time")
term = ControlTerm(..., bm)
if foo:
    solver = Euler()  # no levy area
else:
    solver = SomeKindOfSRK()  # uses levy area
diffeqsolve(term, solver, ...)
```
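To make the (a)/(b) split concrete, here is a small sketch in which the path's capability and the solver's request are separate inputs. It is only an illustration under stated assumptions: `PathSketch`, `euler_like_step` and `srk_like_step` are hypothetical, the `evaluate` signature is simplified, and the numbers are deterministic placeholders rather than Brownian samples.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class PathSketch:
    # (a) Capability: what the path is able to generate.
    levy_area: Optional[str] = None

    def evaluate(self, t0: float, t1: float, use_levy: bool = False):
        w = t1 - t0  # placeholder standing in for a Brownian increment
        if not use_levy:
            # (b) The caller did not request Levy area, so return the increment only.
            return w
        if self.levy_area != "space-time":
            raise ValueError("Levy area requested from a path that cannot generate it")
        h = 0.0  # placeholder standing in for space-time Levy area
        return (w, h)


def euler_like_step(path: PathSketch, t0: float, t1: float):
    # A solver that never asks for Levy area works with any path.
    return path.evaluate(t0, t1)


def srk_like_step(path: PathSketch, t0: float, t1: float):
    # A solver that needs Levy area requests it explicitly.
    return path.evaluate(t0, t1, use_levy=True)


bm = PathSketch(levy_area="space-time")
```

With this split, the same `bm` can be handed to either kind of solver, which is the usage pattern the snippet above is meant to support.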
Oh for completeness -- regarding the

```python
class AbstractBrownianReturn(eqx.Module):
    dt: AbstractVar[...]
    W: AbstractVar[...]

class BrownianIncrement(AbstractBrownianReturn):
    dt: ...
    W: ...

class SpaceTimeLevyArea(AbstractBrownianReturn):
    dt: ...
    W: ...
    H: ...
    bar_H: ...
    K: ...
    bar_k: ...
```

if need be you could also introduce an (Or you could just keep just |
Also also, would it be possible to add checking for this type into the existing runtime Line 80 in 8aafefc
? Probably by doing a That way we can have |
(let me know when you next want my input on this PR by the way) |
I think almost everything should be there now. The primary additions since the last review are:
|
That's great! I've just done a review. You can see my main comments are really just about a few points of either programming style (single-assignment, explicit bools in if statements) or PR style (try to minimise diffs so that they're easy for me to review!)
In other words, they're nits and I like how this looks overall :)
One question -- how do you want to tackle this alongside #370? I'm thinking we probably want the VF type to appear in the signature before the control type, e.g. `AbstractTerm[_VF, _Control]`.
Regarding #370, I'm happy to include vector_field typing in this pull, and agree with your suggested format for |
Heads-up that I've just (a) merged #378 into |
I've added comments to, and attempted to simplify the logic in, the |
Okay, I think I'm happy with this PR now! I can see what's going on with the failing pre-commit checks: this is due to bumping the version of pyright and hitting an incompatibility with Equinox's abstract variables. I'll submit a PR against yours fixing this. |
Okay, take a look at the
|
Aaaaand opened tttc3#1 with these changes. |
Phew, this ended up being a complicated one!

Let's start with the easy stuff:

- Disabled spurious pyright errors due to an incompatibility between pyright and `eqx.AbstractVar`.
- Now using ruff.lint and pinned exact typeguard version.

Now on to the hard stuff:

- Fixed term compatibility missing some edge cases.

Edge cases? What edge cases? Well, what we had before was basically predicated around doing

```python
vf, contr = get_args(term_cls)
```

recalling that we may have e.g. `term_cls = AbstractTerm[SomeVectorField, SomeControl]`. So far so simple: get the arguments of a subscripted generic, no big deal.

What this failed to account for is that we may also have subclasses of this generic, e.g. `term_cls = ODETerm[SomeVectorField]`, such that some of the type variables have already been filled in when defining it:

```python
class ODETerm(AbstractTerm[_VF, RealScalarLike]): ...
```

so in this case, `get_args(term_cls)` simply returns a 1-tuple of `(SomeVectorField,)`. Oh no! Somehow we have to traverse both the filled-in type variables (to find that one of our type variables is `SomeVectorField` due to subscripting) *and* the type hierarchy (to figure out that the other type variable was filled in during the definition).

Once again, for clarity: given a subscriptable base class `AbstractTerm[_VF, _Control]` and some arbitrary possibly-subscripted subclass, we need to find the values of `_VF` and `_Control`, regardless of whether they have been passed in via subscripting the final class (and are `get_args`-able) or have been filled in during subclassing (and require traversing pseudo-type-hierarchies of `__orig_bases__`).

Any sane implementation would simply... not bother. There is no way that the hassle of figuring this out was going to be worth the small amount of type safety this brings...

So anyway, after a few hours working on this *far* past the point I should be going to sleep, this problem is now solved. This PR introduces a new `get_args_of` function, called as `get_args_of(superclass, subclass, error_msg_if_necessary)`. This acts analogously to `get_args`, but instead of looking up both parameters (the type variables we want filled in) and the arguments (the values those type variables have been filled in with) on the same class, it looks up the parameters on the superclass, and their filled-in values on the subclass.

Pure madness.

(I'm also tagging @leycec here because this is exactly the kind of insane typing hackery that he seems to really enjoy.)

Does anyone else remember the days when this was a package primarily concerned with solving differential equations?
Parameterised terms: fixed term compatibility + spurious pyright errors
And merged! Gosh, that ended up being much harder than we all thought it would be, huh? Thank you for the contribution, and great work putting this all together. I'm really happy to have this in, it's a huge improvement to the codebase. :) |
I just noticed that at some point we lost the |
Ah, agreed throughout. Can you send a PR? |
* Added parametric control types
* Demo of AbstractTerm change
* Correct initial parametric control type implementation.
* Parametric AbstractTerm initial implementation.
* Update tests and fix hinting
* Implement review comments.
* Add parametric control check to integrator.
* Update and test parametric control check
* Introduce new LevyArea types
* Updated Brownian path LevyArea types
* Replace Union types in isinstance checks
* Remove rogue comment
* Revert _brownian_arch to single assignment
* Revert _evaluate_leaf key splitting
* Rename variables in test_term
* Update isinstance and issubclass checks
* Safer handling in _denormalise_bm_inc
* Fix style in integrate control type check
* Add draft vector_field typing
* Add draft vector_field typing
* Fix term test
* Revert extemporaneous modifications in _tree
* Rename TimeLevyArea to BrownianIncrement and simplify diff
* Rename AbstractLevyReturn to AbstractBrownianReturn
* Rename _LevyArea to _BrownianReturn
* Enhance _term_compatiblity checks
* Fix merge issues
* Bump pre-commit and fix type hints
* Clean up from self-review
* Explicitly add typeguard to deps
* Bump ruff config to new syntax
* Parameterised terms: fixed term compatibility + spurious pyright errors

---------

Co-authored-by: Patrick Kidger <[email protected]>
```diff
@@ -23,7 +23,7 @@ classifiers = [
     "Topic :: Scientific/Engineering :: Mathematics",
 ]
 urls = {repository = "https://github.com/patrick-kidger/diffrax" }
-dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"]
+dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"]
```
Why is `typeguard` pinned to an exact version, especially one that is so old?
It's generally considered very bad packaging practice to pin to exact versions as it's ~never required and prevents getting even security patches without publishing a new version.
Also, it's ancient and won't be getting any updates - current HEAD is 4.2.1:
Pull request that closes #359.
Core Changes:
_ControlTerm to AbstractControlTerm[_Control]; ControlTerm[AbstractBrownianPath]. Provided demonstration in _solvers/milstein.py. Will all relevant solvers want updating?
Additional Changes:
_solvers/milstein.py.
Known Issues:
test/test_term.py:133 the following type error is thrown (not sure why):
diffrax/_term.py:266 the following type error is being suppressed (as I can't find a nice way to avoid it):