-
-
Notifications
You must be signed in to change notification settings - Fork 140
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add exponential Euler. #567
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, left some high-level comments, but basically I think I like this! I'd be interested in taking this as a PR. Thanks for putting this together :)
|
||
|
||
class ExponentialEuler(diffrax.AbstractItoSolver): | ||
term_structure: ClassVar = AbstractTerm |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This might just be part of the cleanup you mentioned, but I think you want MultiTerm[tuple[LinearODETerm, AbstractTerm]]
to make this well-defined with the unpacking of the term in .step
. (Likewise for the annotations for the terms:
arguments.)
# but it just complicates things here as we return two different | ||
# operators (linear vs full) | ||
exp_Ah = jnp.exp(control * operator.diagonal) | ||
phi = jnp.expm1(control * operator.diagonal) / (operator.diagonal * control) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure this is numerically stable if the denominator is small? (Maybe switch to a Taylor expansion below some cutoff?)
# and solve the linear problem | ||
exp_Ah = expm(control * operator.as_matrix()) | ||
A, b = operator * control, exp_Ah - jnp.eye(exp_Ah.shape[-1]) | ||
phi = jax.vmap(lx.linear_solve, in_axes=(None, 1))(A, b).value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you might be able to skip the vmap by doing the dot against Gh_sdW
before the linear solve.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(+same comment about numerical stability)
class LinearODETerm(ODETerm): | ||
operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator | ||
|
||
def __init__(self, operator: lx.MatrixLinearOperator | lx.DiagonalLinearOperator): | ||
self.operator = operator | ||
vf = lambda t, y, args: self.operator.mv(y) | ||
super().__init__(vector_field=vf) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, I try to avoid subclassing concrete classes -- c.f. https://docs.kidger.site/equinox/pattern/
Also, does this strictly need to be an ODE term? I think in other problems, we might be interested in expressing linear terms in this way (for any operator: lx.AbstractLinearOperator
) regardless of the control.
This is a (very...) rough PR to see if there's interest in adding support for exponential integrators. Here's a good source with some background. Their main use is for stiff, oscillatory problems where neither IMEX nor stiff methods work well, such as brain dynamics or biochemical reaction networks.
I added the so-called exponential Euler method, which uses forward Euler to evaluate the integral, and which according to this paper also works for the stochastic version (additive, commutative noise).
In short, this solver
The solver is quite straightforward, but the approach requires splitting the equation into terms, a linear and non-linear one. Since the
ODETerm
does not seem to support lineax terms in its vector field (not sure why?), I added a simpleLinearODETerm
which does allow this. Finally, I added themake_operators
function which specialises to when the operator is diagonal.If you like it I can clean this up and merge it.