Add Smoothers (Iterative Linear Solvers based on Matrix Decomposition) #131
Conversation
Smoothers are only guaranteed to converge for strictly diagonally dominant matrices, which this function does not test for. They probably have to be adjusted for the corresponding operators.
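For reference, such a check is cheap for a materialized matrix. A minimal sketch (this helper is hypothetical, not part of the PR):

```python
import jax.numpy as jnp


def is_strictly_diagonally_dominant(matrix):
    """True iff |a_ii| > sum over j != i of |a_ij|, for every row i.

    Equivalently: twice the absolute diagonal exceeds the absolute row sum.
    """
    abs_mat = jnp.abs(matrix)
    return jnp.all(2.0 * jnp.diag(abs_mat) > jnp.sum(abs_mat, axis=1))
```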
Thank you for putting this together! To work through your questions:
Besides this, and as you note in #116, the use-cases for these are fairly unusual. Cf. also https://people.math.ethz.ch/~mhg/pub/biksm.pdf, which likewise highlights that Jacobi iterations are not usually a good way to do things. So I'm a little concerned about adding these, precisely because of how rarely they are useful. I'm tagging @vboussange, who also expressed interest in these in #116. Maybe these are okay to add as long as we add big warning signs in the documentation? FWIW, I am curious what your own interest is, both in this and in #132. My experience of iterative solvers is that they really don't work well for 'typical' dense matrices, whilst the larger sparse systems they're designed for are anyway fairly hard to make efficient in JAX.
Thanks a lot for your thoughts. 😊 I agree that smoothing methods are (most of the time) not competitive with Krylov-based solvers. The fact that they need to materialize the matrix (because of limited support for sparse arrays in JAX) makes them even less attractive. The reason I would like to see them in Lineax is that they combine with the other features of the library (support for all kinds of linear operators, and the elegant handling of AD through the linear solve). For example, I am working on a library of time-implicit finite-difference solvers. Defining the discretization purely in terms of its residual via a PyTree-valued function is extremely handy (a sketch of this workflow follows below), and being able to choose among all the different solvers in Lineax then allows for great flexibility. Regarding the other points:
Overall, adding the smoothers completes the library with all the "textbook" linear solvers. A warning regarding their efficiency is necessary, also because they materialize the matrix. I would also be curious what @vboussange thinks.
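For illustration, a minimal sketch of that workflow using Lineax's public API (the operator here, a 1D finite-difference Laplacian with Dirichlet boundaries, along with the sizes and tolerances, is illustrative only):

```python
import jax
import jax.numpy as jnp
import lineax as lx


def laplacian_matvec(u):
    # Action of the 1D finite-difference Laplacian with zero Dirichlet
    # boundaries: (A u)_i = 2 u_i - u_{i-1} - u_{i+1}.
    padded = jnp.pad(u, 1)
    return 2.0 * padded[1:-1] - padded[:-2] - padded[2:]


in_structure = jax.eval_shape(lambda: jnp.zeros(8))
operator = lx.FunctionLinearOperator(
    laplacian_matvec, in_structure, tags=lx.positive_semidefinite_tag
)
rhs = jnp.ones(8)
solution = lx.linear_solve(operator, rhs, solver=lx.CG(rtol=1e-6, atol=1e-6))
x = solution.value
```

Gradients can then be taken straight through `lx.linear_solve`, which is the AD handling referred to above.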
Hey there, awesome work @Ceyron. I am myself quite interested in these solvers, as I would love to see a multigrid solver in JAX at some point. This work is an important milestone towards that.

```python
import jax.numpy as jnp
from equinox import filter_jit
from jax.experimental.sparse import BCOO


@filter_jit
def bcoo_tril(mat: BCOO, k: int = 0) -> BCOO:
    """
    Return the lower-triangular part of the given 2D BCOO matrix.
    The result has zeros above the k-th diagonal.
    """
    rows = mat.indices[:, 0]
    cols = mat.indices[:, 1]
    # Keep entries on or below the k-th diagonal, i.e. row >= col - k.
    mask = jnp.where(rows >= cols - k, 1.0, 0.0)
    new_data = mat.data * mask
    out = BCOO((new_data, mat.indices), shape=mat.shape)
    return out


@filter_jit
def bcoo_triu(mat: BCOO, k: int = 0) -> BCOO:
    """
    Return the upper-triangular part of the given 2D BCOO matrix.
    The result has zeros below the k-th diagonal.
    """
    rows = mat.indices[:, 0]
    cols = mat.indices[:, 1]
    # Keep entries on or above the k-th diagonal, i.e. row <= col - k.
    mask = jnp.where(rows <= cols - k, 1.0, 0.0)
    new_data = mat.data * mask
    out = BCOO((new_data, mat.indices), shape=mat.shape)
    return out
```
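For what it's worth, a quick sanity check of these helpers against their dense counterparts (note that masked entries stay in the index set as explicit zeros rather than being dropped):

```python
dense = jnp.array([[1.0, 2.0], [3.0, 4.0]])
sparse = BCOO.fromdense(dense)

assert jnp.allclose(bcoo_tril(sparse).todense(), jnp.tril(dense))
assert jnp.allclose(bcoo_triu(sparse).todense(), jnp.triu(dense))
```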
Thanks for the kind words and the support 👍
Interesting, do you refer to using `jax.experimental.sparse`?
Correct.
I have been using it quite extensively in the context of graph analysis (see a prototype package for graph analysis here), although it requires some tweaks like the `bcoo_tril`/`bcoo_triu` helpers above.

Not sure, but I hope so!
Nice library 👍. I suggest proceeding with this pull request using dense matrix materialization (similar to what happens to any matrix-free linear operator in Lineax's direct solvers), and potentially adding sparse matrix support in a future update. I can imagine sparse matrix support naturally leading to ways of implementing algebraic multigrid. Wdyt, @patrick-kidger?
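For context, materializing any Lineax operator is already a one-liner, so the smoothers could fall back on it; a small illustration (the operator is arbitrary):

```python
import jax
import jax.numpy as jnp
import lineax as lx

operator = lx.FunctionLinearOperator(
    lambda x: 3.0 * x, jax.eval_shape(lambda: jnp.zeros(4))
)
dense = operator.as_matrix()  # (4, 4) dense materialization of the operator
```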
Yup!
Agreed. Hopefully (!) we can write the solvers in a matrix-free way that is agnostic to whether the operator is sparse or not. (Or even to what kind of sparse representation is being used -- BCOO, BCSR, ...)
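A hedged sketch of what such a representation-agnostic step could look like, touching the operator only through its action and its diagonal (names and signature are illustrative, not a proposed API):

```python
import jax.numpy as jnp


def jacobi_step(matvec, diagonal, b, x):
    # One Jacobi iteration, x <- x + D^{-1} (b - A x). Dense, BCOO, BCSR,
    # or fully matrix-free operators all enter only through `matvec`.
    residual = b - matvec(x)
    return x + residual / diagonal
```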
Hi @patrick-kidger,
This is to implement #116. I opened it as a draft because I wanted to get some feedback and had some questions on handling Lineax internals and on what features to add.
So far, I have implemented the `Jacobi` and `GaussSeidel` methods as a mixture of the dense solvers (with their matrix materialization) and the Krylov solvers (with their implementation of the iterative process). Moreover, I added a new tag `strictly_diagonally_dominant`, because this is a sufficient (but not necessary) condition for their convergence (for Jacobi and for Gauss-Seidel). If `max_steps=None`, the smoothers run for `100 * size` iterations instead of `10 * size` (what the Krylov solvers default to) because of their slower convergence. I am thinking of bumping it to `200 * size`, since I otherwise observe that some tests occasionally fail (due to the randomized matrices).

I have some questions to discuss before I polish up the PR with docs etc.:
1. (The new tag lives in `_operator.py`, with tests in `test_operator.py`.) Do you think this new tag is a good addition? As far as I understand, Lineax does not implement methods to actually check whether operators obey their tags; it only relies on the user's specification. (One could add a simple function to check whether twice the absolute diagonal is larger than the absolute row sum.)
2. One could add a relaxation parameter to `Jacobi` and `GaussSeidel`. For the latter, this gives the SOR method (see the sketch at the end of this post). For symmetric matrices, there is also the SSOR method. Do you think it is worth adding these as well?
3. How should the `strictly_diagonally_dominant` tag be handled in the smoothing methods? Giving a warning if it is not present? It is not a necessary condition, though.

These are the open points I will tackle after your feedback:
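Regarding the relaxation question above, a purely illustrative dense sketch of how an over-relaxation parameter changes the Gauss-Seidel update (with `omega = 1.0` this reduces to plain Gauss-Seidel):

```python
import jax
import jax.numpy as jnp


def sor_step(A, b, x, omega=1.5):
    # One SOR sweep: each component gets the Gauss-Seidel update, blended
    # with its old value via the relaxation parameter omega.
    def body(i, x):
        sigma = A[i] @ x - A[i, i] * x[i]  # uses already-updated entries for j < i
        x_gs = (b[i] - sigma) / A[i, i]
        return x.at[i].set((1.0 - omega) * x[i] + omega * x_gs)

    return jax.lax.fori_loop(0, b.shape[0], body, x)
```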