Expose Iterators for manual stepping or autoregressive rollout #132
Comments
I think this is a reasonable request. If I understand correctly you don't need any particular detail of the state, you just need to be able to handle the iteration yourself. The bit I think might take some thought is what a good API for this would look like. For example to avoid stuff like

    class _CgState(eqx.Module):
        _diff: ...
        _y: ...
        ...

with the underscore attributes making it clear that they shouldn't be treated as accessible. Meanwhile the Lines 231 to 232 in 58f2a8b (and we'd rely on the compiler to DCE this computation whenever it isn't used). If the same pattern comes up across multiple iterative solvers then perhaps some of the logic could be tidied up into an

Anyway: broadly interested provided a good (= maintainable) API can be found. :)
Good to hear you like the suggestion 👍. I agree; it needs a good and maintainable API. What do you think about the entire carry being a PyTree container with a property method, for example:

    class _CgState(eqx.Module):
        _diff: ...
        _y: ...
        _negated: bool
        ...

        @property
        def value(self):
            if self._negated:
                return -(self._y**ω).ω
            else:
                return self._y

Makes sense. 👍 If you are still interested, I could set up a draft for a first version.
Hmm, not sure it could be a method on the state, as in general postprocessing is currently a property of the solver. No immediate strong feelings either way though.
At the moment, all iterative solvers (CG, NormalCG, BiCGStab, and GMRES) define a `body_fun` and `cond_fun` in their `compute(...)` method (e.g., for CG see here). This is then used to run a `jax.lax.while_loop`.

I have multiple use cases for which I want to directly access the `body_fun` as an iterator (examples described below). However, since the function is private to the `compute` method, I cannot directly access it. In a private fork of the repository, I made a small modification to all iterative solvers, adding a method that effectively has the same input signature as `compute(...)` but returns the `initial_carry`, `cond_fun`, `body_fun`, and `max_steps`. (This can then be used as part of `compute(...)` to remove duplicate code.) Do you think something like this could be worthwhile to add?

Examples for using iterators
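As a minimal sketch of what such a method might return, assume a hypothetical free function `cg_iterators` implementing textbook CG in plain JAX. The name, the carry layout `(x, r, p, step)`, and the `rtol` stopping rule are illustrative assumptions, not lineax's actual internals or API.

```python
import jax
import jax.numpy as jnp


def cg_iterators(matvec, b, rtol=1e-5, max_steps=50):
    """Hypothetical stand-in for the proposed method; NOT part of lineax."""
    x0 = jnp.zeros_like(b)
    # With x0 = 0 the initial residual and search direction both equal b.
    initial_carry = (x0, b, b, jnp.zeros((), dtype=jnp.int32))

    def cond_fun(carry):
        _, r, _, step = carry
        not_converged = jnp.linalg.norm(r) > rtol * jnp.linalg.norm(b)
        return not_converged & (step < max_steps)

    def body_fun(carry):
        x, r, p, step = carry
        Ap = matvec(p)
        rr = jnp.vdot(r, r)
        alpha = rr / jnp.vdot(p, Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = jnp.vdot(r_new, r_new) / rr
        return x, r_new, r_new + beta * p, step + 1

    return initial_carry, cond_fun, body_fun, max_steps


# 1D Poisson matrix (tridiagonal -1, 2, -1) and a random right-hand side.
n = 8
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jax.random.normal(jax.random.PRNGKey(0), (n,))

# A compute-style solve is then a while_loop over the returned pieces.
initial_carry, cond_fun, body_fun, max_steps = cg_iterators(lambda v: A @ v, b)
solution, _, _, num_steps = jax.lax.while_loop(cond_fun, body_fun, initial_carry)
```

Returning the pieces rather than running the loop leaves the choice of driver (`while_loop`, `scan`, or a Python loop) to the caller.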
Access Convergence History
Say we want to plot the residual norm over the number of steps; the current approach (correct me if I am wrong) is to run multiple solves with different `rtol` (or `atol`) values, then extract the `num_steps` from the `stats`. (Alternatively, with #129, we could also fix `max_steps` and then compute the residual norm achieved at the end of the iteration, which gives an equidistant array over the number of steps.)

We could do this directly by using the iterators, for example using CG on the Poisson matrix together with `jax.lax.scan`.
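A sketch of the convergence-history idea, assuming a hypothetical `cg_body_fun` (textbook CG on a carry `(x, r, p)`) in place of the solver's private `body_fun`; none of these names come from lineax. The recorded quantity is the norm of the recurrence residual carried by CG, not a freshly recomputed `b - A x`.

```python
import jax
import jax.numpy as jnp


def cg_body_fun(matvec):
    """Hypothetical CG update on a carry (x, r, p); not lineax API."""
    def body_fun(carry):
        x, r, p = carry
        Ap = matvec(p)
        rr = jnp.vdot(r, r)
        alpha = rr / jnp.vdot(p, Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = jnp.vdot(r_new, r_new) / rr
        return x, r_new, r_new + beta * p
    return body_fun


# 1D Poisson matrix (tridiagonal -1, 2, -1) and a random right-hand side.
n = 8
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jax.random.normal(jax.random.PRNGKey(0), (n,))
body_fun = cg_body_fun(lambda v: A @ v)


def scan_step(carry, _):
    # Advance one CG step and emit the residual norm as the per-step output.
    new_carry = body_fun(carry)
    return new_carry, jnp.linalg.norm(new_carry[1])


initial_carry = (jnp.zeros_like(b), b, b)
_, residual_norms = jax.lax.scan(scan_step, initial_carry, None, length=n)
```

`residual_norms` then holds one value per step from a single solve and can be plotted directly.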
Manually Step through the Solution
Say we want to investigate how the state develops over the iterations (for educational reasons or applications like multigrid or some more esoteric research :D). The current approach would be to run the iterative solver for `max_steps=1`, extract the solution, and then run it again under the same setting to step forward.

With the iterators that would be more straightforward, either with a plain Python loop or alternatively with a scan.
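Both variants of manual stepping can be sketched as follows, again assuming a hypothetical `cg_body_fun` (textbook CG on a carry `(x, r, p)`) rather than anything lineax exposes today.

```python
import jax
import jax.numpy as jnp


def cg_body_fun(matvec):
    """Hypothetical CG update on a carry (x, r, p); not lineax API."""
    def body_fun(carry):
        x, r, p = carry
        Ap = matvec(p)
        rr = jnp.vdot(r, r)
        alpha = rr / jnp.vdot(p, Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = jnp.vdot(r_new, r_new) / rr
        return x, r_new, r_new + beta * p
    return body_fun


n, num_manual_steps = 8, 3
A = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jax.random.normal(jax.random.PRNGKey(0), (n,))
body_fun = jax.jit(cg_body_fun(lambda v: A @ v))
initial_carry = (jnp.zeros_like(b), b, b)

# Variant 1: a plain Python loop, inspecting the state after every step.
carry = initial_carry
iterates = []
for _ in range(num_manual_steps):
    carry = body_fun(carry)
    iterates.append(carry[0])


# Variant 2: the same steps under a scan, stacking the iterates.
def scan_step(carry, _):
    new_carry = body_fun(carry)
    return new_carry, new_carry[0]


_, trajectory = jax.lax.scan(scan_step, initial_carry, None, length=num_manual_steps)
```

The Python loop is convenient for interactive inspection, while the scan keeps the whole rollout inside one traced computation.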
Unrolled Differentiation to assess Jacobian Convergence
In a recent line of research, I investigated the unrolled differentiation over iterative linear solvers, similar in spirit to https://arxiv.org/abs/2209.13271 . The Jacobian suboptimality can be extremely efficiently computed using forward-mode AD over a scan. Abstract example:
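One way to sketch this, under assumptions of my own choosing: the system matrix is taken as `poisson + theta * I` with a scalar parameter `theta`, CG is written out by hand (not lineax's implementation), and the suboptimality is measured for a single tangent direction against the implicit-differentiation reference `dx/dtheta = -A^{-1} x`.

```python
import jax
import jax.numpy as jnp

n, num_steps = 16, 10
poisson = 2.0 * jnp.eye(n) - jnp.eye(n, k=1) - jnp.eye(n, k=-1)
b = jax.random.normal(jax.random.PRNGKey(0), (n,))


def unrolled_cg(theta):
    """Return all CG iterates for (poisson + theta * I) x = b (illustrative)."""
    def matvec(v):
        return poisson @ v + theta * v

    def scan_step(carry, _):
        x, r, p = carry
        Ap = matvec(p)
        rr = jnp.vdot(r, r)
        alpha = rr / jnp.vdot(p, Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        p = r_new + (jnp.vdot(r_new, r_new) / rr) * p
        return (x, r_new, p), x

    init = (jnp.zeros_like(b), b, b)
    _, iterates = jax.lax.scan(scan_step, init, None, length=num_steps)
    return iterates


# Forward-mode AD through the scan gives d x_k / d theta for every step k.
theta = jnp.array(1.0)
iterates, tangents = jax.jvp(unrolled_cg, (theta,), (jnp.ones_like(theta),))

# Reference via implicit differentiation; dA/dtheta = I for this toy setup.
A = poisson + theta * jnp.eye(n)
x_exact = jnp.linalg.solve(A, b)
reference = -jnp.linalg.solve(A, x_exact)
jacobian_suboptimality = jnp.linalg.norm(tangents - reference, axis=1)
```

A single `jax.jvp` call yields the derivative of every intermediate iterate, so the full suboptimality curve costs roughly one extra solve.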