Add Smoothers (Iterative Linear Solvers based on Matrix Decomposition) #131

Draft · wants to merge 17 commits into main

Conversation

@Ceyron (Contributor) commented Jan 15, 2025

Hi @patrick-kidger,

This implements #116. I opened it as a draft because I wanted to get some feedback and have a few questions about handling Lineax internals and which features to add.

So far, I have implemented the Jacobi & GaussSeidel methods as a mixture of the dense solvers (borrowing their matrix materialization) and the Krylov solvers (borrowing their implementation of the iterative process). Moreover, I added a new tag strictly_diagonally_dominant, because this is a sufficient (but not necessary) condition for the convergence of both Jacobi and GaussSeidel. If max_steps=None, the smoothers run for 100 * size instead of 10 * size (the Krylov solvers' default) because of their slower convergence. I am thinking of bumping it to 200 * size, since otherwise I occasionally see some tests fail (due to the randomized matrices).
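
For reference, these are the classical matrix-splitting iterations; a minimal dense sketch of one step of each (purely illustrative, not the actual PR code):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def jacobi_step(A, b, x):
    # Split A = D + R with D the diagonal part; one Jacobi update is
    #   x_new = D^{-1} (b - R x).
    diag = jnp.diag(A)
    off_diag = A - jnp.diag(diag)
    return (b - off_diag @ x) / diag


def gauss_seidel_step(A, b, x):
    # Split A = (D + L) + U; one Gauss-Seidel update solves the triangular system
    #   (D + L) x_new = b - U x.
    lower = jnp.tril(A)        # D + L
    upper = jnp.triu(A, k=1)   # strictly upper part U
    return solve_triangular(lower, b - upper @ x, lower=True)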

I have some questions to discuss before I polish up the PR with docs etc.:

  1. Is my handling of the tag correct? (with regards to _operator.py and test_operator.py) Do you think this new tag is a good addition? As far as I understand, Lineax does not implement methods to actually check whether operators obey their tags; it only relies on the user's specification. (One could add a simple function that checks whether twice the absolute diagonal is larger than the absolute row sum; see the small sketch after this list.)
  2. Are there additional tests for which I have to register the two new solvers?
  3. I will implement the option for relaxation in both Jacobi and GaussSeidel. For the latter, this gives the SOR method. For symmetric matrices, there is also the SSOR method. Do you think it is worth adding it as well?
  4. Do you think there should be some handling of the strictly_diagonally_dominant tag in the smoothing methods, e.g., giving a warning if it is not present? It is not a necessary condition, though.
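
Regarding the check mentioned in point 1, I mean something along these lines (an illustrative helper, not currently part of the PR):

import jax.numpy as jnp


def is_strictly_diagonally_dominant(matrix) -> bool:
    # Strict row diagonal dominance: |a_ii| > sum_{j != i} |a_ij| for every row,
    # which is equivalent to 2 * |a_ii| > sum_j |a_ij|.
    abs_matrix = jnp.abs(matrix)
    return bool(jnp.all(2.0 * jnp.diag(abs_matrix) > jnp.sum(abs_matrix, axis=1)))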

These are the open points I will tackle after your feedback:

  • Add Documentation
  • Add a relaxation option to Jacobi and GaussSeidel and test it
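
For the relaxation option, the updates I have in mind are the standard weighted-Jacobi and SOR sweeps, roughly (again a dense, purely illustrative sketch; omega is the relaxation parameter):

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular


def weighted_jacobi_step(A, b, x, omega=1.0):
    # Damped/weighted Jacobi: convex combination of the old iterate and the plain
    # Jacobi update; omega = 1 recovers plain Jacobi.
    diag = jnp.diag(A)
    x_jacobi = (b - (A - jnp.diag(diag)) @ x) / diag
    return (1.0 - omega) * x + omega * x_jacobi


def sor_step(A, b, x, omega=1.0):
    # With A = D + L + U, one SOR sweep solves
    #   (D + omega L) x_new = omega b + ((1 - omega) D - omega U) x;
    # omega = 1 recovers plain Gauss-Seidel.
    D = jnp.diag(jnp.diag(A))
    L = jnp.tril(A, k=-1)
    U = jnp.triu(A, k=1)
    rhs = omega * b + ((1.0 - omega) * D - omega * U) @ x
    return solve_triangular(D + omega * L, rhs, lower=True)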

@patrick-kidger (Owner) commented:

Thank you for putting this together! To work through your questions:

  1. I think I'm inclined against this. As you note this is sufficient but not necessary -- e.g. Sassenfeld's criterion also suffices for the convergence of Gauss-Seidel -- and the tag system is pretty much entirely opt-in anyway.
  2. I don't think so!
  3. No strong feelings.
  4. I think what would be best would be to arrange for a failed solve, together with a failed check of this sort, to produce a RESULT that notes this and gives a more helpful error message.

Besides this, and as you note in #116, the use cases for these are fairly unusual. Cf. also https://people.math.ethz.ch/~mhg/pub/biksm.pdf, which also highlights that Jacobi iterations are not usually a good way to do things. So I'm a little concerned about adding these, simply because of how rarely they are useful. I'm tagging @vboussange, who also expressed interest in these in #116. Maybe these are okay to add as long as we add big warning signs in the documentation?

FWIW I am curious what your own interest is, both in this and in #132. My experience of iterative solvers is that they really don't work well for 'typical' dense matrices, whilst the larger sparse systems they're designed for are anyway fairly hard to make efficient in JAX.

@Ceyron (Contributor, Author) commented Feb 5, 2025

Thanks a lot for your thoughts. 😊

I agree that smoothing methods are (most of the time) not competitive with Krylov-based solvers. The fact that they need to materialize the matrix (because of limited support for sparse arrays in JAX) makes them even less attractive. The reason I would still like to see them in Lineax is to combine them with the other features of the library (working with all kinds of linear operators, and the elegant handling of AD through the linear solve). For example, I am working on a library of time-implicit finite difference solvers. Defining the discretization just in terms of its residuum via a PyTree-valued function is extremely handy, and then choosing among all the different solvers Lineax offers allows for great flexibility.
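
To make that concrete, here is the kind of pattern I mean (a minimal, illustrative sketch; the stencil and the solver choice are just placeholders):

import jax
import jax.numpy as jnp
import lineax as lx


def residuum(u):
    # Illustrative 1D implicit-diffusion stencil acting on the state; any linear
    # PyTree -> PyTree function works here.
    laplacian = jnp.roll(u, 1) - 2.0 * u + jnp.roll(u, -1)
    return u - 0.1 * laplacian


u0 = jnp.zeros(100)
operator = lx.FunctionLinearOperator(residuum, jax.eval_shape(lambda: u0))
rhs = jnp.ones(100)

# Any Lineax solver can be swapped in here, e.g. lx.GMRES, lx.NormalCG, or a
# (potential) smoother from this PR.
solution = lx.linear_solve(operator, rhs, solver=lx.GMRES(rtol=1e-6, atol=1e-6))
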
In part, my interest is in smaller-scale research with (sparse) matrices of size $n\in (100, 10000)$, for which smoothers are still imperfect but usable 😅. They allow for interesting insights into downstream tasks: e.g., if I train a neural emulator on PDE trajectories whose data was generated with linear solver X (with X being anything Lineax offers), truncated to tolerance Y or to a number of steps Z, what effect does this have on the neural network? Could this be exploited? Going further, what effect does it have on the gradient if I differentiate through the solution (with unrolled or implicit differentiation), etc.? For example, we investigated something along these lines for a recent ICLR paper.

Regarding the other points:

  1. Agreed, I will remove the strictly_diagonally_dominant tag.
  2. Nice, thanks 👍
  3. To keep it simple, I would leave it out for now. It could still be added in a future PR.
  4. Interesting. Do you mean adding a new RESULTS.non_diagonally_dominant or modifying the RESULTS.singular message?

Overall, adding the smoothers rounds out the library with all the "textbook linear solvers". A warning regarding their efficiency is certainly needed, also because they materialize the matrix.

I would also be curious about what @vboussange thinks.

@vboussange commented Feb 7, 2025

Hey there, awesome work @Ceyron. I am myself quite interested in these solvers, as I would love to see a multigrid solver in JAX at some point. This work is an important milestone towards that.
However, matrix materialisation is a bummer for me. Could we avoid materialising matrices by introducing custom bcoo_tril and bcoo_triu?

from equinox import filter_jit
from jax.experimental.sparse import BCOO
import jax.numpy as jnp


@filter_jit
def bcoo_tril(mat: BCOO, k: int = 0) -> BCOO:
    """
    Return the lower-triangular part of the given 2D BCOO matrix.
    The result has zeros above the k-th diagonal.
    """
    rows = mat.indices[:, 0]
    cols = mat.indices[:, 1]
    # Keep entries on or below the k-th diagonal (rows >= cols - k).
    mask = jnp.where(rows >= cols - k, 1.0, 0.0)
    new_data = mat.data * mask
    out = BCOO((new_data, mat.indices), shape=mat.shape)
    return out


@filter_jit
def bcoo_triu(mat: BCOO, k: int = 0) -> BCOO:
    """
    Return the upper-triangular part of the given 2D BCOO matrix.
    The result has zeros below the k-th diagonal.
    """
    rows = mat.indices[:, 0]
    cols = mat.indices[:, 1]
    # Keep entries on or above the k-th diagonal (rows <= cols - k).
    mask = jnp.where(rows <= cols - k, 1.0, 0.0)
    new_data = mat.data * mask
    out = BCOO((new_data, mat.indices), shape=mat.shape)
    return out
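
For what it's worth, a quick sanity check of how these could provide the strict lower/upper splitting that a Gauss-Seidel-style sweep needs (assuming the two helpers above):

import jax.numpy as jnp
from jax.experimental.sparse import BCOO

A = BCOO.fromdense(jnp.array([[4.0, -1.0, 0.0],
                              [-1.0, 4.0, -1.0],
                              [0.0, -1.0, 4.0]]))
strictly_lower = bcoo_tril(A, k=-1)
strictly_upper = bcoo_triu(A, k=1)
diagonal = jnp.diag(jnp.diag(A.todense()))  # diagonal kept dense here for simplicity

assert jnp.allclose(
    strictly_lower.todense() + strictly_upper.todense() + diagonal, A.todense()
)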

@Ceyron (Contributor, Author) commented Feb 7, 2025

Hey there, awesome work @Ceyron. I am myself quite interested in these solvers, as I would love to see a multigrid solver in JAX at some point.

Thanks for the kind words and the support 👍

This work is an important milestone towards that. However, matrix materialisation is a bummer for me. Could we avoid materialising matrices by introducing custom bcoo_tril and bcoo_triu?

Interesting, do you refer to using jax.experimental.sparse.BCOO? What is your experience with the experimental sparse features in JAX? Are they efficient? Is it a goal of the JAX team to eventually have them be part of stable JAX?

@vboussange commented:

Interesting, do you refer to using jax.experimental.sparse.BCOO?

Correct.

What is your experience with the experimental sparse features in JAX? Are they efficient?

I have been using it quite extensively in the context of graph analysis (see a prototype package for graph analysis here), and although it requires some tweaks like bcoo_triu, I am happy with it.

Is it a goal of the JAX team to eventually have them be part of stable JAX?

Not sure, but I hope so!

@Ceyron (Contributor, Author) commented Feb 7, 2025

Nice library 👍.
I guess supporting BCOO sparse matrices in Lineax would require adding a new linear operator, like lineax.BCOOLinearOperator, which could potentially trigger more efficient solve paths in some solvers (like a sparse matrix decomposition for the smoothers, or efficient sparse (incomplete) direct solvers if JAX ever starts wrapping the corresponding cuSPARSE routines). AFAIK, JAX does not offer a way to translate a (linear) function into its sparse matrix representation like, e.g., Julia has. Hence, if Lineax wants to support sparse matrices, they would need to be supplied directly by the user. (Just leaving a pointer to PhiFlow/PhiML, which has a custom sparse matrix format and finds the sparsity pattern by tracing stencils.)
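
To make that concrete, here is a rough, untested sketch of what such an operator might look like (the name BCOOLinearOperator and all details are hypothetical, not an existing Lineax API):

import jax
import lineax as lx
from jax.experimental.sparse import BCOO


class BCOOLinearOperator(lx.AbstractLinearOperator):
    # Hypothetical operator wrapping a jax.experimental.sparse.BCOO matrix.
    matrix: BCOO

    def mv(self, vector):
        # Sparse matrix-vector product; avoids materializing a dense matrix.
        return self.matrix @ vector

    def as_matrix(self):
        return self.matrix.todense()

    def transpose(self):
        return BCOOLinearOperator(self.matrix.T)

    def in_structure(self):
        return jax.ShapeDtypeStruct((self.matrix.shape[1],), self.matrix.dtype)

    def out_structure(self):
        return jax.ShapeDtypeStruct((self.matrix.shape[0],), self.matrix.dtype)


# A real integration would also need to register the operator with Lineax's
# operator-property functions (is_symmetric, is_diagonal, etc.), omitted here.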

I suggest proceeding with this pull request using dense matrix materialization (similar to what happens to any matrix-free linear operator in Lineax's direct solvers) and potentially adding sparse matrix support in a future update. I can imagine sparse matrix support naturally leading to ways of implementing algebraic multigrid. Wdyt, @patrick-kidger?

@patrick-kidger (Owner) commented:

Do you mean adding a new RESULTS.non_diagonally_dominant

Yup!

I suggest proceeding with this pull request using dense matrix materialization (similar to what happens to any matrix-free linear operator in Lineax's direct solvers) and potentially adding sparse matrix support in a future update.

Agreed. Hopefully (!) we can write the solvers in a matrix-free way that is agnostic to whether the operator is sparse or not. (Or even to what kind of sparse representation is being used -- BCOO, BCSR, ...)
