Hi, thank you for creating these awesome JAX libraries. I recently started using Lineax and compared it with the linear solver in JAX. With the code below, Lineax takes 931 µs per solve and `jnp.linalg.solve` takes 171 µs. Is there anything wrong with my implementation, or should I just stick with `jnp.linalg.solve`? Is there no way to use the `_gesv` Fortran routine through Lineax?
```python
from jax import random
import jax.numpy as jnp
import lineax as lx

matrix_key, vector_key = random.split(random.PRNGKey(0))
matrix = random.normal(matrix_key, (10, 10))
vector = random.normal(vector_key, (10,))

operator = lx.MatrixLinearOperator(matrix)
# Solve once with Lineax's default solver; the result is in solution.value.
solution = lx.linear_solve(operator, vector)

# %timeit is IPython/Jupyter magic; it does not run in a plain Python script.
%timeit lx.linear_solve(operator, vector, solver=lx.LU())
%timeit jnp.linalg.solve(matrix, vector)
```
Two things differ between the two calls:

1. Error checking on the Lineax output. By default, Lineax performs an extra check that the result does not contain NaNs etc., i.e. that the solve was successful. This check can be disabled by passing `linear_solve(..., throw=False)`.
2. PyTree flattening/unflattening across JIT boundaries. `matrix` and `vector` are simpler PyTrees than `operator` and `lx.LU()`.
With this benchmark I obtain identical performance: