Pr branch #331
I like this! I think you've done things very cleanly overall.
I've left a fair few comments (it's a large PR); sometimes they apply to multiple places. E.g. I haven't called out every place where a method should be private.
diffrax/brownian/tree.py
Outdated
@@ -34,12 +42,61 @@ class _State(eqx.Module):
    w_t: Scalar
    w_u: Scalar
    key: "jax.random.PRNGKey"
    J_s: Scalar = field(default=None)
If you're only specifying a default, you can write J_s: Optional[Scalar] = None.
Also note that I've fixed the type annotation.
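As a minimal sketch of the shorthand suggested above: Equinox modules are dataclasses, so the point can be shown with a plain standard-library dataclass. `_DemoState` is a hypothetical stand-in (not the `_State` class from this PR), and `float` stands in for `Scalar`.

```python
import dataclasses
from typing import Optional


@dataclasses.dataclass(frozen=True)
class _DemoState:
    w_s: float
    # Verbose spelling: a field() call whose only argument is the default.
    J_s: Optional[float] = dataclasses.field(default=None)
    # Equivalent shorthand: assign the default directly.
    J_u: Optional[float] = None


state = _DemoState(w_s=0.5)
# Both spellings record the same default on the generated field.
defaults = {f.name: f.default for f in dataclasses.fields(_DemoState)}
```

Both fields end up with the same default, so the `field(...)` call only earns its place when it carries extra options.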
(The type annotations are going to be checked properly in the next release of Diffrax, by the way.)
diffrax/brownian/tree.py
Outdated
    J_u: Scalar = field(default=None)


@jax.jit
You shouldn't need to JIT something that isn't public API.
diffrax/brownian/tree.py
Outdated
@jax.jit
def wj_to_wh_diff(x0: LevyVal, x1: LevyVal) -> LevyVal:
As these aren't public API (i.e. they're never imported anywhere), their names should begin with an underscore.
diffrax/brownian/tree.py
Outdated
    x1: LevyVal(W_u, J_u)
    """
    h = (x1.h - x0.h).astype(x0.W.dtype)
    inverse_h = jnp.nan_to_num(1 / h).astype(x0.W.dtype)
For autodifferentiation reasons, it is better to avoid ever creating a NaN in the first place:
h = ...
h = jnp.where(jnp.abs(h) < jnp.finfo(h).eps, jnp.inf, h)
inverse_h = 1 / h
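A small self-contained sketch of why this matters, assuming a scalar floating-point `h`. The function names are illustrative only, and `jnp.finfo(h.dtype)` is used here instead of `jnp.finfo(h)` purely to keep the sketch conservative. The idea is that sanitising after the division leaves an infinite intermediate in the computation graph, which can turn into a NaN once gradients are taken, whereas guarding the input keeps every intermediate finite or harmless.

```python
import jax
import jax.numpy as jnp


def naive_inverse(h):
    # Creates inf at h == 0 and cleans it up afterwards; the backward pass
    # still sees the inf intermediate.
    return jnp.nan_to_num(1 / h)


def safe_inverse(h):
    # Replace (near-)zero inputs *before* dividing, so 1 / h never blows up.
    eps = jnp.finfo(h.dtype).eps
    h_safe = jnp.where(jnp.abs(h) < eps, jnp.inf, h)
    return 1 / h_safe


h0 = jnp.asarray(0.0)
naive_grad = jax.grad(naive_inverse)(h0)  # inspect this: it need not be finite
safe_grad = jax.grad(safe_inverse)(h0)
```

With the guarded version, a zero-length interval gives an inverse of zero and a finite gradient.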
Also, what's the reason for the dtype casts?
I think without this it wasn't passing one of the tests. But I think that after I added interval normalisation, I was more principled about the types, so yes, this should no longer be required.
diffrax/solver/ansr.py
Outdated
a = self.embed_a_lower(jnp.dtype(y0))
c = jnp.insert(jnp.asarray(self.tableau.c, dtype=jnp.dtype(y0)), 0, 0.0)
b = jnp.asarray(self.tableau.b, dtype=jnp.dtype(y0))
These jnp.dtype calls will fail for pytree-valued y0.
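One possible way to get a dtype that also works when `y0` is a pytree is to promote over its leaves. This is only a sketch under that assumption; `pytree_dtype` is a hypothetical helper, not something taken from Diffrax, and it assumes `y0` has at least one array leaf.

```python
import jax.numpy as jnp
import jax.tree_util as jtu


def pytree_dtype(y0):
    # Flatten the pytree and take the promoted dtype of all of its leaves.
    leaves = jtu.tree_leaves(y0)
    return jnp.result_type(*leaves)


y0 = {"position": jnp.zeros(2, dtype=jnp.float32),
      "velocity": (jnp.ones(3, dtype=jnp.float32),)}
dtype = pytree_dtype(y0)
# The tableau coefficients can then be cast once, using the promoted dtype.
b = jnp.asarray([0.5, 0.5], dtype=dtype)
```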
diffrax/solver/ansr.py
Outdated
)
carry = (0, hks)

# output of lax.scan is ((num_stages, hks), [None, None, ... , None])
Nit: it will just be a None, not a [None, ...].
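The nit above can be checked with a tiny scan whose body returns `None` as its per-step output. The carry layout here (a counter plus an accumulator) is made up for illustration and is not the solver's actual carry.

```python
import jax.numpy as jnp
from jax import lax


def body(carry, x):
    i, acc = carry
    # No per-step output is needed, so the second return value is None.
    return (i + 1, acc + x), None


(final_i, final_acc), ys = lax.scan(body, (0, jnp.zeros(())), jnp.arange(4.0))
```

Because `None` is an empty pytree, there is nothing to stack across steps, and `ys` comes back as a single `None`.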
diffrax/solver/ansr.py
Outdated
# TODO: add checks for whether the method is FSAL


class ANSR(AbstractItoSolver):
Btw, for consistency I try to name abstract classes beginning with Abstract.
diffrax/solver/ansr.py
Outdated
    return 1  # should be modified depending on tableau

def strong_order(self, terms):
    return 0.5  # should be modified depending on tableau
I'd suggest leaving these out here if they get overridden on a per-solver basis.
diffrax/term.py
Outdated
@@ -413,6 +426,10 @@ def is_vf_expensive(
    _t1 = jnp.where(self.direction == 1, t1, -t0)
    return self.term.is_vf_expensive(_t0, _t1, y, args)

def levy_contr(self, t0: Scalar, t1: Scalar) -> (PyTree, PyTree):
    assert isinstance(self.term, _ControlTerm)
    return self.term.levy_contr(t0, t1)
I think I'd suggest removing these extra methods and instead just having contr pipe any **kwargs forward.
I wanted to make it explicit that a given method will always return LevyVal. But I suppose we can get rid of that. I'll just add the following check to all solvers requiring STLA:
assert isinstance(levy, LevyVal) and (levy.H is not None), \
    ("The diffusion should be a ControlTerm controlled by either a "
     "VirtualBrownianTree or an UnsafeBrownianPath with "
     "`spacetime_levyarea` set to True.")
Hmm I thought I removed this...
diffrax/brownian/base.py
Outdated
"Abstract base class for all Brownian paths."
"""Abstract base class for all Brownian paths."""

spacetime_levyarea: bool = eqx.field(static=True)
I think it's generally an antipattern to put fields on abstract base classes.
Are you looking to declare that there must be a spacetime_levyarea attribute? (E.g. that the following should typecheck:
def f(x: AbstractBrownianPath):
    y = x.spacetime_levyarea
)
or something else?
Yes, that was exactly the intention. It is not strictly needed except in some assert statements, but I think it would be useful to have. I suppose I can also use hasattr() first.
I removed this. In addition, in VBT and UBP I swapped this out for levy_area: str, which can be either "", "space-time", or eventually "space-time-time".
Gotcha! For completeness, the way to express this would be to use eqx.AbstractVar -- this declares (and enforces) that this attribute must exist -- which can then be implemented by a field or property.
So in this case you'd want levy_area: eqx.AbstractVar[Literal["", "space-time"]]
Okay, another partial review! I think we're most of the way there for the changes to the Brownian motions. I've also started to leave a few comments on the solvers.
diffrax/brownian/path.py
Outdated
"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: str = eqx.field(static=True)
The type annotation here should be Literal["", "space-time", "space-time-time"]. This helps to be much more precise than simply str.
diffrax/brownian/tree.py
Outdated
):
    if t0 >= t1:
        raise ValueError("t0 must be strictly less than t1.")
t0 and t1 might be traced values, in which case this will fail.
You want (t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1")
The reason I didn't do that is I thought that would fail if they were not JAX arrays (as there was nothing to thread the error onto). But anyway, I changed it as you suggested and it somehow works. Thanks for the heads up.
If they're not JAX arrays then the t0 >= t1 condition should be evaluated statically, in which case the error_if will eagerly raise at trace-time rather than runtime. Thus it should always work :)
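The failure mode that started this thread can be reproduced with plain JAX. This is only an illustrative sketch: `interval_length` is a hypothetical function, not the constructor from the PR. Eagerly, the Python `if` sees concrete numbers; under `jax.jit` the arguments become tracers, and converting the comparison to a Python bool raises a concretization error.

```python
import jax


def interval_length(t0, t1):
    # A Python-level branch on the values of t0 and t1.
    if t0 >= t1:
        raise ValueError("t0 must be strictly less than t1.")
    return t1 - t0


# Works with concrete Python floats.
eager_result = interval_length(0.0, 1.0)

# Fails once t0 and t1 are traced, even though the inputs are valid.
try:
    jax.jit(interval_length)(0.0, 1.0)
    traced_ok = True
except jax.errors.ConcretizationTypeError:
    traced_ok = False
```

This is the situation the `eqx.error_if` suggestion is meant to handle, since it accepts both static and traced conditions.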
diffrax/brownian/tree.py
Outdated
if self.levy_area == "space-time":
    # set tH_t to None
    levy_out = jtu.tree_map(
        lambda x: dataclasses.replace(x, tH_t=None), levy_0
    )
else:
    levy_out = levy_0
I think all of these comparisons should explicitly check every branch, and raise an error if necessary:
if self.levy_area == "space-time":
...
elif self.levy_area == "":
...
else:
assert False
Fair.
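A standalone sketch of the exhaustive-branch pattern, combined with the `Literal` annotation suggested earlier in the review. The function name and the returned strings are invented for illustration, and the final branch raises `ValueError` rather than using `assert False` so that the check also survives `python -O`.

```python
from typing import Literal, get_args

LevyArea = Literal["", "space-time", "space-time-time"]


def describe_levy_area(levy_area: LevyArea) -> str:
    # Every allowed value gets its own branch; nothing falls through silently.
    if levy_area == "":
        return "no Levy area requested"
    elif levy_area == "space-time":
        return "space-time Levy area requested"
    elif levy_area == "space-time-time":
        return "space-time-time Levy area requested"
    else:
        allowed = get_args(LevyArea)
        raise ValueError(f"levy_area must be one of {allowed}, got {levy_area!r}")
```

A misspelt option such as `"spacetime"` then fails loudly instead of being treated as the `else` case.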
diffrax/brownian/tree.py
Outdated
# NOTE: this gives a different result than the original implementation of the
# VirtualBrownianTree by Patrick Kidger.
Let's remove this note, since this code lives under patrick-kidger/diffrax ;)
diffrax/solver/additive_srk.py
Outdated
assert np.allclose(sum(a_i), c_i)
assert np.allclose(sum(self.b_sol), 1.0)

# TODO: add checks for whether the method is FSAL
What's the status of this btw? If I understand correctly, right now we simply don't support FSAL?
For that matter, does FSAL even appear in these kinds of stochastic solvers? FSAL is usually used just to form an error estimate for the step, but adaptive stochastic solvers are historically pretty niche. (Which is a topic I have strong opinions on, but best to tackle that some other time!)
ALIGN, for example, is FSAL, and perhaps it could appear. But yes, you're right, FSAL is just not supported. This is a little remnant of when I thought I might implement it soon, but I realised it is not really relevant for the solvers that I currently have, so no point adding it. Anyway, I just deleted that.
But you'll be happy to know that I implemented embedded error estimates for both ShARK and SRA1, so adaptive time stepping works with both. And same for ALIGN.
Great!
diffrax/solver/additive_srk.py
Outdated
makes use of space-time Lévy area.

Given the SDE
$dX_t = f(t, X_t) dt + σ dW_t$
Does it have to be constant noise or is time-dependent noise allowed too? (Additive noise usually still allows time-dependent noise, since this still satisfies the conditions of commutative noise.)
I know. I have a general SRK in the works, which will support additive, time-dependent additive, and non-additive noise. I think I'll have it ready soon (see experiments branch), but for this PR I'd stick with what is in here. Besides, Trevor doesn't need time-dependent noise, so I think that for the time being, nobody in the world will miss that feature.
Just for my own education: is there actually any scenario in which time-dependent and time-independent additive noise need to be treated differently?
As for a general SRK -- if this is landing later, perhaps we should remove AbstractAdditiveSRK from the public interface now, as presumably it'll be superseded later?
To be fair it is a very minor difference. The way it is now I just evaluate diffusion.vf once, and if it was time-varying I'd have to evaluate it in every stage with the appropriate t0 + cj*h. And to be fair I can add that now if you want, it really is minor. I just thought given that it's coming later anyway, there's no need to edit it now.
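The difference described here can be sketched in a few lines. Everything below is hypothetical: `sigma_fn` stands in for the diffusion vector field as a function of time only, and `c` for the stage nodes of some tableau. None of it is taken from the PR's solver code.

```python
import jax.numpy as jnp


def diffusion_once(sigma_fn, t0, num_stages):
    # Constant noise: a single evaluation is reused for every stage.
    sigma = sigma_fn(t0)
    return jnp.stack([sigma] * num_stages)


def diffusion_per_stage(sigma_fn, t0, h, c):
    # Time-dependent noise: re-evaluate at each stage time t0 + c_j * h.
    stage_times = t0 + jnp.asarray(c) * h
    return jnp.stack([sigma_fn(t) for t in stage_times])


c = (0.0, 0.5, 1.0)  # made-up stage nodes
constant_sigma = lambda t: jnp.asarray(2.0)
ramp_sigma = lambda t: 1.0 + t

same = diffusion_per_stage(constant_sigma, 0.0, 0.1, c)
varying = diffusion_per_stage(ramp_sigma, 0.0, 0.1, c)
```

For constant noise the two functions agree, which is why a single evaluation suffices in the current setup; for time-varying noise only the per-stage version picks up the dependence on `t`.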
Yes, I can remove it from the public API. Does that mean I should just remove it from the docs or also something else?
On your first paragraph: no strong feelings; up to you.
But yes, definitely do remove it from the docs / public API if it's not going to be long for this world!
Do you think it would be worth splitting this PR into two pieces? Keep just the BM changes here, and then do the solvers separately? The BM changes are now very nearly ready, and I'm thinking it might not be worth ironing out the details of the solvers if they're going to be changing soon anyway. (The bottleneck here is clearly my review time, so anything that minimises that is good :) )
Yes, I can try to split it, that's probably a good idea. What's the best way to accomplish this? Should I just start from a fresh branch (up to date with just main Diffrax) and then add in all the brownian-related files, and not the solvers? There are some files (e.g. term.py) which have changes relating to both one and the other, and I think it would be a mess trying to separate all of that out. What do you suggest (mostly in terms of the git procedure + useful commands)?
Should I just start from a fresh branch (up to date with just main Diffrax) and then add in all the brownian-related files, and not the solvers?
Yup, exactly.
For term.py, let's leave that out for now as it isn't reviewed yet.
In terms of git commands: the simplest way is probably to copy-paste the Brownian files somewhere else, then create a new branch, then copy them back!
diffrax/solver/additive_srk.py
Outdated
in the `StochasticButcherTableau`.
"""

term_structure = MultiTerm[Tuple[ODETerm, _ControlTerm]]
I think replace _ControlTerm with just AbstractTerm for now. Right now we don't really have a good way to communicate things like "the noise must be a Brownian motion" in the type system, and in principle it'd be valid for someone to write their own kind of control term.
Also, nit: replace Tuple with tuple as the latter is now the norm as of Python 3.9. (I have some ongoing work to clean this up in the rest of Diffrax, but we might as well get it right here.)
Fair enough :)
diffrax/solver/additive_srk.py
Outdated
if not path.levy_area == "space-time":
    raise ValueError(
        "The Brownian path controlling the diffusion "
        "should be initialised with `compute_stla=True`"
I think this should be levy_area now.
Yep, good point. I tried to replace it everywhere, but this one flew under the radar.
diffrax/solver/additive_srk.py
Outdated
# check that the vector field of the diffusion term is constant
sigma, (t_sigma, y_sigma) = eqx.filter_jvp(
    lambda t, y: diffusion.vf(t, y, args), (t0, y0), (t0, y0)
)
if (t_sigma is not None) or (y_sigma is not None):
    raise ValueError(
        "Vector field of the diffusion term should be constant, "
        "independent of t and y."
    )
Following on from above: note that this should allow dependence on t.
The way it is currently set up, time dependence is not supported. I will modify this in the next PR accordingly.
Hi, I added computation of space-time Lévy area to VirtualBrownianTree and UnsafeBrownianPath, and implemented several methods for additive-noise SDEs. This includes an abstract Additive-Noise Stochastic Runge-Kutta solver (ANSR) which allows the user to supply a StochasticButcherTableau to specify an exact solver. I implemented three instances of ANSR:
In addition, I implemented the Adaptive Langevin via Interpolated Gradients and Noise (ALIGN) method, which does not fall into the "ANSR" class of solvers and can only be used for Langevin-like SDEs, but has the benefit of being FSAL as well as having a strong order of convergence of 2.