Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pr branch #331

Closed
wants to merge 1 commit into from
Closed

Pr branch #331

wants to merge 1 commit into from

Conversation

andyElking
Copy link
Contributor

@andyElking andyElking commented Nov 8, 2023

Hi, I added computation of space-time Levy Area to VirtualBrownianTree and UnsafeBrownianPath, as well as implemented several methods for additive-noise SDEs. This includes an abstract Additive-Noise Stochastic Runge-Kutta solver (ANSR) which allows the user to supply a StochasticButcherTableau to specify an exact solver. I implemented three instances of ANSR:

  • Shifted Euler for Additive-noise (SEA)
  • Shifted Additive-noise Runge-Kutta (ShARK)
  • SRA1.

In addition I implemented the Adaptive Langevin via Interpolated Gradients and Noise (ALIGN) method, which does not fall into the "ANSR" class of solvers, and can only be used for Langevin-like SDEs, but has the benefit of being FSAL, as well as having a 2nd strong order of convergence.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this! I think you've done things very cleanly overall.

I've left a fair few comments (it's a large PR); sometimes they apply to multiple places. E.g. I haven't called out every place where a method should be private.

@@ -34,12 +42,61 @@ class _State(eqx.Module):
w_t: Scalar
w_u: Scalar
key: "jax.random.PRNGKey"
J_s: Scalar = field(default=None)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you're only specified a default you can write J_s: Optional[Scalar] = None.
Also note that I've fixed the type annotation.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(The type annotations are going to be checked properly in the next release of Diffrax, by the way.)

J_u: Scalar = field(default=None)


@jax.jit
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You shouldn't need to JIT something that isn't public API.



@jax.jit
def wj_to_wh_diff(x0: LevyVal, x1: LevyVal) -> LevyVal:
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As these aren't public API (=they're never imported anywhere) then their names should begin with an underscore.

diffrax/brownian/tree.py Outdated Show resolved Hide resolved
x1: LevyVal(W_u, J_u)
"""
h = (x1.h - x0.h).astype(x0.W.dtype)
inverse_h = jnp.nan_to_num(1 / h).astype(x0.W.dtype)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For autodifferentiation reasons, it is better to avoid ever creating a NaN in the first place:

h = ...
h = jnp.where(jnp.abs(h) < jnp.finfo(h).eps, jnp.inf, h)
inverse_h = 1 / h

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, what's the reason for the dtype casts?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think without this it wasn't passing one of the tests. But I think that after I added interval normalisation, I was more principled about the types, so yes, this should no longer be required.


a = self.embed_a_lower(jnp.dtype(y0))
c = jnp.insert(jnp.asarray(self.tableau.c, dtype=jnp.dtype(y0)), 0, 0.0)
b = jnp.asarray(self.tableau.b, dtype=jnp.dtype(y0))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These jnp.dtype calls will fail for pytree-valued y0.

)
carry = (0, hks)

# output of lax.scan is ((num_stages, hks), [None, None, ... , None])
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: it will just be a None, not a [None, ...].

# TODO: add checks for whether the method is FSAL


class ANSR(AbstractItoSolver):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, for consistency I try to name abstract classes beginning with Abstract.

return 1 # should be modified depending on tableau

def strong_order(self, terms):
return 0.5 # should be modified depending on tableau
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd suggest leaving these out here if they get overriden on a per-solver basis.

diffrax/term.py Outdated
@@ -413,6 +426,10 @@ def is_vf_expensive(
_t1 = jnp.where(self.direction == 1, t1, -t0)
return self.term.is_vf_expensive(_t0, _t1, y, args)

def levy_contr(self, t0: Scalar, t1: Scalar) -> (PyTree, PyTree):
assert isinstance(self.term, _ControlTerm)
return self.term.levy_contr(t0, t1)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd suggest removing these extra methods and instead just have contr pipe any **kwargs forward.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wanted to make it explicit that a given method will always return LevyVal. But I suppose we can get rid of that. I'll just add the following check to all solvers requiring STLA:

        assert isinstance(levy, LevyVal) and (levy.H is not None), \
            ("The diffusion should be a ControlTerm controlled by either a"
             "VirtualBrownianTree or an UnsafeBrownianPath with"
             "`spacetime_levyarea` set to True.")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm I thought I removed this...

diffrax/__init__.py Outdated Show resolved Hide resolved
diffrax/brownian/base.py Outdated Show resolved Hide resolved
"Abstract base class for all Brownian paths."
"""Abstract base class for all Brownian paths."""

spacetime_levyarea: bool = eqx.field(static=True)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's generally an antipattern to put fields on abstract base classes.

Are you looking to declare that there must be a spacetime_levyarea attribute? (E.g. that the following should typecheck:

def f(x: AbstractBrownianPath):
    y = x.spacetime_levyarea

)
or something else?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that was exactly the intention. It is not strictly needed except in some assert statements, but I think it would be useful to have. I suppose I can also use hasattr() first.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I removed this. In addition, in VBT and UBP I swapped this out for levy_area: str, which can be either "", "space-time", or eventually "space-time-time".

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Gotcha! For completeness, the way to express this would be to use eqx.AbstractVar -- this declares (and enforces) that this attribute must exist -- which can then be implemented by a field or property.

So in this case you'd want levy_area: eqx.AbstractVar[Literal["", "space-time"]]

diffrax/brownian/base.py Outdated Show resolved Hide resolved
diffrax/brownian/base.py Outdated Show resolved Hide resolved
diffrax/brownian/tree.py Show resolved Hide resolved
diffrax/custom_types.py Outdated Show resolved Hide resolved
diffrax/path.py Outdated Show resolved Hide resolved
diffrax/path.py Outdated Show resolved Hide resolved
diffrax/path.py Outdated Show resolved Hide resolved
Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, another partial review! I think we're most of the way there for the changes to the Brownian motions. I've also started to leave a few comments on the solvers.

"""

shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
levy_area: str = eqx.field(static=True)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The type annotation here should be Literal["", "space-time", "space-time-time"]. This helps to be much more precise than simply str.

diffrax/brownian/path.py Show resolved Hide resolved
):
if t0 >= t1:
raise ValueError("t0 must be strictly less than t1.")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

t0 and t1 might be traced values, in which case this will fail.

You want (t0, t1) = eqx.error_if((t0, t1), t0 >= t1, "t0 must be strictly less than t1")

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The reason I didn't do that is I thought that would fail if they were not JAX arrays (as there was nothing to thread the error onto). But anyway, I changed it as you suggested and it somehow works. Thanks for the heads up.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If they're not JAX arrays then I the t0 >= t1 condition should be evaluated statically, in which case the error_if will eagerly raise at trace-time rather than runtime. Thus it should always work :)

Comment on lines 182 to 188
if self.levy_area == "space-time":
# set tH_t to None
levy_out = jtu.tree_map(
lambda x: dataclasses.replace(x, tH_t=None), levy_0
)
else:
levy_out = levy_0
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all of these comparisons should explicitly check every branch, and raise an error if necessary:

if self.levy_area == "space-time":
    ...
elif self.levy_area == "":
    ...
else:
    assert False

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair.

Comment on lines 359 to 361
# NOTE: this gives a different result than the original implementation of the
# VirtualBrownianTree by Patrick Kidger.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's remove this note, since this code lives under patrick-kidger/diffrax ;)

assert np.allclose(sum(a_i), c_i)
assert np.allclose(sum(self.b_sol), 1.0)

# TODO: add checks for whether the method is FSAL
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the status of this btw? If I understand correctly, right now we simply don't support FSAL?

For that matter, does FSAL even appear in these kinds of stochastic solvers? FSAL is usually used just to form an error estimate for the step, but adaptive stochastic solvers are historically pretty niche. (Which is a topic I have strong opinions on, but best to tackle that some other time!)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ALIGN, for example is FSAL, and perhaps it could appear. But yes, you're right, FSAL is just not supported. This is a little remnent of when I thought I might implement it soon, but I realised it is not really relevant for the solvers that I currently have, so no point adding it. Anyway, I just deleted that.

But you'll be happy to know that I implemented embedded error estimates for both ShARK and SRA1, so adaptive time stepping works with both. And same for ALIGN.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

makes use of space-time Lévy area.

Given the SDE
$dX_t = f(t, X_t) dt + σ dW_t$
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it have to be constant noise or is time-dependent noise allowed too? (Additive noise usually still allows time-dependent noise, since this still satisfies the conditions of commutative noise.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know. I have a general SRK in the works, which will support additive, time-dependent additive, and non-additive noise. I think I'll have it ready soon (see experiments branch), but for this PR I'd stick with what is in here. Besides, Trevor doesn't need time-dependent noise, so I think that for the time being, nobody in the world will miss that feature.

Copy link
Owner

@patrick-kidger patrick-kidger Dec 4, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just for my own education: is there actually any scenario in which time-dependent and time-independent additive noise need to be treated differently?

As for a general SRK -- if this is landing later, perhaps we should remove AbstractAdditiveSRK from the public interface now, as presumably it'll be superseded later?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be fair it is a very minor difference. The way it is now I just evaluate diffusion.vf once, and if it was time-varying I'd have to evaluate it in every stage with the appropriate t0 + cj*h. And to be fair I can add that now if you want, it really is minor. I just thought given that it's coming later anyway, there's no need to edit it now.

Yes, I can remove it from the public API. Does that mean I should just remove it from the docs or also something else?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On your first paragraph: no strong feelings; up to you.

But yes, definitely do remove it from the docs / public API if it's not going to be long for this world!


Do you think it would it be worth splitting this PR into two pieces? Keep just the BM changes here, and then do the solvers separately? The BM changes are now very nearly ready, and I'm thinking it might not be worth ironing out the details of the solvers if they're going to be changing soon anyway. (The bottleneck here is clearly my review time, so anything that minimises that is good :) )

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I can try to split it, that's probably a good idea. What's the best way to accomplish this? Should I just start from a fresh branch (up to date with just main Diffrax) and then add in all the brownian-related files, and not the solvers? There are some files (e.g. term.py) which have changes relating to both one and the other, and I think it would be a mess trying to separate all of that out. What do you suggest (mostly in terms of the git procedure + useful commands)?

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should I just start from a fresh branch (up to date with just main Diffrax) and then add in all the brownian-related files, and not the solvers?

Yup, exactly.

For term.py, let's leave that out for now as it isn't reviewed yet.

In terms of git commands: the simplest way is probably to copy-paste the browian files somewhere else, then create a new branch, then copy them back!

in the `StochasticButcherTableau`.
"""

term_structure = MultiTerm[Tuple[ODETerm, _ControlTerm]]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think replace _ControlTerm with just AbstractTerm for now. Right now we don't really have a good way to communicate things like "the noise must be a Brownian motion" in the type system, and in principle it'd be valid for someone to write their own kind of control term.

Also, nit: replace Tuple with tuple as the latter is now the norm as of Python 3.9. (I have some ongoing work to clean this up in the rest of Diffrax, but we might as well get it right here.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fair enough :)

if not path.levy_area == "space-time":
raise ValueError(
"The Brownian path controlling the diffusion "
"should be initialised with `compute_stla=True`"
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this shoudl be levy_area now.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, good point. I tried to replace it everywhere, but this one flew under the radar.

Comment on lines 138 to 146
# check that the vector field of the diffusion term is constant
sigma, (t_sigma, y_sigma) = eqx.filter_jvp(
lambda t, y: diffusion.vf(t, y, args), (t0, y0), (t0, y0)
)
if (t_sigma is not None) or (y_sigma is not None):
raise ValueError(
"Vector field of the diffusion term should be constant, "
"independent of t and y."
)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following on from above: note that this should allow dependence on t.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The way it is currently set up, time dependence is not supported. I will modify this in the next PR accordingly.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants