
Expose Iterators for manual stepping or autoregressive rollout #132

Open
Ceyron opened this issue Jan 15, 2025 · 3 comments
Labels
feature New feature

Comments

@Ceyron
Contributor

Ceyron commented Jan 15, 2025

At the moment, all iterative solvers (CG, NormalCG, BiCGStab, and GMRES) define a body_fun and cond_fun inside their compute(...) method (e.g., for CG see here). These are then used to run a jax.lax.while_loop.

I have multiple use cases in which I want to access the body_fun directly as an iterator (examples below). However, since the function is local to the compute method, I cannot access it. In a private fork of the repository, I made a small modification to all iterative solvers that adds the following method

    def get_iterators(
        self,
        state: ITERATIVE_SOLVER_STATE,
        vector: PyTree[Array],
        options: dict[str, Any],
    ) -> tuple[ITERATIVE_SOLVER_CARRY_TYPE, Callable, Callable, int]:

which effectively has the same input signature as compute(...), but returns the initial_carry, cond_fun, body_fun, and max_steps. (This can then be used inside compute(...) to remove duplicate code; a rough sketch follows below.) Do you think something like this would be worthwhile to add?
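As a minimal sketch (not the actual lineax code), compute(...) could then be expressed on top of the proposed method, so that the loop pieces are defined in only one place:

# Rough sketch only: reuse the pieces returned by the proposed get_iterators(...)
# inside compute(...), keeping the loop logic in a single place.
def compute(self, state, vector, options):
    initial_carry, cond_fun, body_fun, max_steps = self.get_iterators(
        state, vector, options
    )
    final_carry = jax.lax.while_loop(cond_fun, body_fun, initial_carry)
    # ... unpack `final_carry` into the returned solution and stats, as today ...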

Examples for using iterators

Access Convergence History

Say we want to plot the residuum norm against the number of steps; the current approach (correct me if I am wrong) is to run multiple solves with different rtol (or atol) values and then extract num_steps from the stats, as sketched below. (Alternatively, with #129, we could fix max_steps and compute the residuum norm reached at the end of the iteration, which gives an equidistant array over the number of steps.)
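For concreteness, a rough sketch of that workaround on the same Poisson system used in the example below (I am assuming here that the stats dict exposes "num_steps" for CG, as described above):

import jax
import jax.numpy as jnp
import lineax

# 1D Poisson matrix: tridiagonal, symmetric, negative semi-definite.
A = -2 * jnp.diag(jnp.ones(100)) + jnp.diag(jnp.ones(99), 1) + jnp.diag(jnp.ones(99), -1)
op = lineax.MatrixLinearOperator(A, tags=(lineax.negative_semidefinite_tag, lineax.symmetric_tag))
b = jax.random.normal(jax.random.key(0), (100,))

# One full solve per tolerance; the iteration count is read off the stats.
num_steps_per_tol = []
for tol in (1e-1, 1e-2, 1e-3, 1e-4, 1e-5):
    sol = lineax.linear_solve(op, b, solver=lineax.CG(rtol=tol, atol=tol))
    num_steps_per_tol.append(sol.stats["num_steps"])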

We could do this directly by using the iterators. For example, using CG on the Poisson matrix together with jax.lax.scan

import jax
import jax.numpy as jnp
import lineax

A = -2 * jnp.diag(jnp.ones(100)) + jnp.diag(jnp.ones(99), 1) + jnp.diag(jnp.ones(99), -1)
op = lineax.MatrixLinearOperator(A, tags=(lineax.negative_semidefinite_tag, lineax.symmetric_tag))
b = jax.random.normal(jax.random.key(0), (100,))

cg_solver = lineax.CG(1e-5, 1e-5)
cg_state = cg_solver.init(op, {})

initial_carry, _, body_fun, max_steps = cg_solver.get_iterators(cg_state, b, {})

def scan_fun(carry, _):
    # One CG iteration per scan step; record the residuum norm held in the incoming carry.
    next_carry = body_fun(carry)
    next_res_norm = jnp.linalg.norm(carry[2])
    return next_carry, next_res_norm

_, res_norm_history = jax.lax.scan(scan_fun, initial_carry, None, length=max_steps)

Manually Step through the Solution

Say we want to investigate how the state develops over the iterations (for educational purposes, or for applications like multigrid or some more esoteric research :D). The current approach would be to run the iterative solver with max_steps=1, extract the solution, and then run it again under the same settings to step forward (sketched below).
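Roughly, that workaround looks as follows. This is only a sketch: it reuses op and b from the example above, and it assumes that throw=False suppresses the max-steps error and that a y0 option is accepted as the initial guess.

one_step_cg = lineax.CG(rtol=1e-5, atol=1e-5, max_steps=1)

# First step: a single CG iteration from the default initial guess.
sol = lineax.linear_solve(op, b, solver=one_step_cg, throw=False)
first_state = sol.value
# Second step: another single-iteration solve, warm-started from the first result
# (assumes the "y0" option is supported as the initial guess).
sol = lineax.linear_solve(op, b, solver=one_step_cg, options={"y0": first_state}, throw=False)
second_state = sol.value
# ... one extra full solve per additional step.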

With the iterators, this would be more straightforward, e.g.:

initial_carry, _, body_fun, max_steps = cg_solver.get_iterators(cg_state, b, {})
zeroth_state = initial_carry[1]
next_carry = body_fun(initial_carry)
first_state = next_carry[1]
#....

Or alternatively with a scan:

_, state_history = jax.lax.scan(lambda carry, _: (body_fun(carry), carry[1]), initial_carry, None, max_steps)

Unrolled Differentiation to assess Jacobian Convergence

In a recent line of research, I investigated unrolled differentiation through iterative linear solvers, similar in spirit to https://arxiv.org/abs/2209.13271. The Jacobian suboptimality can be computed extremely efficiently using forward-mode AD over a scan. Abstract example:

def assemble_spd_system(params: Array):
    # ...
    return matrix, rhs

def produce_state_rollout(params):
    matrix, rhs = assemble_spd_system(params)
    op = lineax.MatrixLinearOperator(
        matrix,
        tags=(lineax.positive_semidefinite_tag, lineax.symmetric_tag),
    )
    cg_solver = lineax.CG(1e-5, 1e-5)
    cg_state = cg_solver.init(op, {})

    initial_carry, _, body_fun, max_steps = cg_solver.get_iterators(cg_state, rhs, {})

    _, state_history = jax.lax.scan(
        lambda carry, _: (body_fun(carry), carry[1]),
        initial_carry,
        None,
        max_steps,
    )

    return state_history

def solve_system_direct(params):
    matrix, rhs = assemble_spd_system(params)
    return jnp.linalg.solve(matrix, rhs)

jacobian_rollout = jax.jacfwd(produce_state_rollout)(params)
true_jacobian = jax.jacobian(solve_system_direct)(params)

jacobian_suboptimality = jnp.linalg.norm(jacobian_rollout - true_jacobian, axis=(-1, -2))
@patrick-kidger
Owner

I think this is a reasonable request. If I understand correctly, you don't need any particular detail of the state; you just need to be able to handle the iteration yourself.

The bit I think might take some thought is what a good API for this would look like. For example, to avoid stuff like state[2], I imagine a body function that returns a 2-tuple of (solution_if_we_were_to_stop_at_this_step, opaque_state). Then at each step you either use the provided solution, or run another iteration using the state. The state itself would not necessarily be stable between versions; this could be communicated by making it something like

class _CgState(eqx.Module):
    _diff: ...
    _y: ...
    ...

with the underscore attributes making it clear that they shouldn't be treated as accessible.
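Purely as an illustration of that shape (nothing below is an existing lineax API; body_fun, initial_state, and run_n_steps are hypothetical stand-ins):

# Hypothetical helper illustrating the proposed shape: `body_fun` returns a
# 2-tuple of (solution-if-stopped-now, opaque state). Nothing here is an
# existing lineax API.
def run_n_steps(body_fun, initial_state, n):
    state = initial_state
    solution = None
    for _ in range(n):
        # Either use `solution` directly, or feed `state` back in for another
        # iteration; callers never index into the state.
        solution, state = body_fun(state)
    return solution, state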

Meanwhile the solution_if_we_were_to_stop_at_this_step would need to include the extra postprocessing of

lineax/lineax/_solver/cg.py, lines 231 to 232 (at 58f2a8b):

if is_nsd and not self._normal:
    solution = -(solution**ω).ω

(and we'd rely on the compiler to DCE this computation whenever it isn't used).

If the same pattern comes up across multiple iterative solvers, then perhaps some of the logic could be tidied up into an AbstractIterativeLinearSolver.

Anyway: broadly interested provided a good(=maintainable) API can be found. :)

@Ceyron
Contributor Author

Ceyron commented Feb 5, 2025

Good to hear you like the suggestion 👍.

I agree; it needs a good and maintainable API. What do you think about the entire carry being a PyTree container with a property method, for example:

class _CgState(eqx.Module):
    _diff: ...
    _y: ...
    _negated: bool
    ...

    @property
    def value(self):
        if self._negated:
            return -(self._y**ω).ω
        else:
            return self._y
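For instance (hypothetical, and assuming the scan carry is this _CgState itself), the scan examples above would no longer need to index into the carry:

# Hypothetical: `carry` is assumed to be the _CgState sketched above.
_, state_history = jax.lax.scan(
    lambda carry, _: (body_fun(carry), carry.value),
    initial_carry,
    None,
    max_steps,
)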

> If the same pattern comes up across multiple iterative solvers, then perhaps some of the logic could be tidied up into an AbstractIterativeLinearSolver.

Makes sense. 👍

If you are still interested, I could set up a draft for a first version.

@patrick-kidger
Owner

Hmm, not sure it could be a method on the state, as in general postprocessing is currently a property of the solver. No immediate strong feelings either way though.
