Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add exponential Euler. #567

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

GJBoth
Copy link

@GJBoth GJBoth commented Jan 13, 2025

This is a (very...) rough PR to see if there's interest in adding support for exponential integrators. Here's a good source with some background. Their main use is for stiff, oscillatory problems where neither IMEX nor stiff methods work well, such as brain dynamics or biochemical reaction networks.

I added the so-called exponential Euler method, which uses forward Euler to evaluate the integral, and which according to this paper also works for the stochastic version (additive, commutative noise).

In short, this solver

  • handles both the ODE and SDE version.
  • Has a term structure that is compatible with other solvers.
  • Has a specialised version for when the linear operator is diagonal.

The solver is quite straightforward, but the approach requires splitting the equation into terms, a linear and non-linear one. Since the ODETerm does not seem to support lineax terms in its vector field (not sure why?), I added a simple LinearODETerm which does allow this. Finally, I added the make_operators function which specialises to when the operator is diagonal.

If you like it I can clean this up and merge it.

Copy link
Owner

@patrick-kidger patrick-kidger left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, left some high-level comments, but basically I think I like this! I'd be interested in taking this as a PR. Thanks for putting this together :)



class ExponentialEuler(diffrax.AbstractItoSolver):
term_structure: ClassVar = AbstractTerm
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This might just be part of the cleanup you mentioned, but I think you want MultiTerm[tuple[LinearODETerm, AbstractTerm]] to make this well-defined with the unpacking of the term in .step. (Likewise for the annotations for the terms: arguments.)

# but it just complicates things here as we return two different
# operators (linear vs full)
exp_Ah = jnp.exp(control * operator.diagonal)
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure this is numerically stable if the denominator is small? (Maybe switch to a Taylor expansion below some cutoff?)

# and solve the linear problem
exp_Ah = expm(control * operator.as_matrix())
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1])
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you might be able to skip the vmap by doing the dot against Gh_sdW before the linear solve.

Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(+same comment about numerical stability)

Comment on lines +791 to +797
class LinearODETerm(ODETerm):
operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator

def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator):
self.operator = operator
vf = lambda t, y, args: self.operator.mv(y)
super().__init__(vector_field=vf)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Btw, I try to avoid subclassing concrete classes -- c.f. https://docs.kidger.site/equinox/pattern/

Also, does this strictly need to be an ODE term? I think in other problems, we might be interested in expressing linear terms in this way (for any operator: lx.AbstractLinearOperator) regardless of the control.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants