Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explicity dtype parameter to tests #63

Merged
merged 2 commits into from
Nov 7, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 12 additions & 11 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def getkey():


@ft.lru_cache(maxsize=None)
def _construct_matrix_impl(getkey, cond_cutoff, tags, size):
def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype=jnp.float64):
Randl marked this conversation as resolved.
Show resolved Hide resolved
while True:
matrix = jr.normal(getkey(), (size, size))
matrix = jr.normal(getkey(), (size, size), dtype=dtype)
if has_tag(tags, lx.diagonal_tag):
matrix = jnp.diag(jnp.diag(matrix))
if has_tag(tags, lx.symmetric_tag):
Expand All @@ -60,18 +60,19 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size):
return matrix


def construct_matrix(getkey, solver, tags, num=1, *, size=3):
def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64):
if isinstance(solver, lx.NormalCG):
cond_cutoff = math.sqrt(1000)
else:
cond_cutoff = 1000
return tuple(
_construct_matrix_impl(getkey, cond_cutoff, tags, size) for _ in range(num)
_construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype)
for _ in range(num)
)


def construct_singular_matrix(getkey, solver, tags, num=1):
matrices = construct_matrix(getkey, solver, tags, num)
def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64):
matrices = construct_matrix(getkey, solver, tags, num, dtype=dtype)
if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):
return tuple(matrix.at[0, :].set(0) for matrix in matrices)
else:
Expand Down Expand Up @@ -213,10 +214,10 @@ def make_function_operator(matrix, tags):
@_operators_append
def make_jac_operator(matrix, tags):
out_size, in_size = matrix.shape
x = jr.normal(getkey(), (in_size,))
a = jr.normal(getkey(), (out_size,))
b = jr.normal(getkey(), (out_size, in_size))
c = jr.normal(getkey(), (out_size, in_size))
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2
jac = jax.jacfwd(fn_tmp)(x, None)
diff = matrix - jac
Expand Down Expand Up @@ -269,7 +270,7 @@ def make_mul_operator(matrix, tags):
@_operators_append
def make_composed_operator(matrix, tags):
_, size = matrix.shape
diag = jr.normal(getkey(), (size,))
diag = jr.normal(getkey(), (size,), dtype=matrix.dtype)
diag = jnp.where(jnp.abs(diag) < 0.05, 0.8, diag)
operator1 = make_trivial_pytree_operator(matrix / diag, ())
operator2 = lx.DiagonalLinearOperator(diag)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_jvp(
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None

if (make_matrix is construct_matrix) or pseudoinverse:
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2)
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)

out_size, _ = matrix.shape
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if has_tag(tags, lx.unit_diagonal_tag):
# For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
Expand Down
13 changes: 7 additions & 6 deletions tests/test_jvp_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_jvp_jvp(
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
if (make_matrix is construct_matrix) or pseudoinverse:
matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix(
getkey, solver, tags, num=4
getkey, solver, tags, num=4, dtype=dtype
)

t_make_operator = lambda p, t_p: eqx.filter_jvp(
Expand All @@ -62,10 +63,10 @@ def test_jvp_jvp(
)

out_size, _ = matrix.shape
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
tt_vec = jr.normal(getkey(), (out_size,))
tt_t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if use_state:

Expand Down
5 changes: 3 additions & 2 deletions tests/test_singular.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True))
@pytest.mark.parametrize("ops", ops)
def test_small_singular(make_operator, solver, tags, ops, getkey):
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_small_singular(make_operator, solver, tags, ops, getkey, dtype):
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
tol = 1e-4
(matrix,) = construct_singular_matrix(getkey, solver, tags)
(matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
operator, matrix = ops(operator, matrix)
assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def assert_transpose(operator, out_vec, in_vec, solver):
return assert_transpose

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_transpose(
_, make_operator, solver, tags, assert_transpose_fixture, getkey
_, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey
):
(matrix,) = construct_matrix(getkey, solver, tags)
(matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
out_size, in_size = matrix.shape
out_vec = jr.normal(getkey(), (out_size,))
in_vec = jr.normal(getkey(), (in_size,))
out_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
in_vec = jr.normal(getkey(), (in_size,), dtype=dtype)
solver = lx.AutoLinearSolver(well_posed=True)
assert_transpose_fixture(operator, out_vec, in_vec, solver)

Expand Down
13 changes: 9 additions & 4 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap(
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:

Expand All @@ -65,14 +66,18 @@ def wrap_solve(matrix, vector):
out_axes = eqx.if_array(0)

(matrix,) = eqx.filter_vmap(
make_matrix, axis_size=axis_size, out_axes=out_axes
lambda getkey, solver, tags: make_matrix(
getkey, solver, tags, dtype=dtype
),
axis_size=axis_size,
out_axes=out_axes,
)(getkey, solver, tags)
out_dim = matrix.shape[-2]

if vec_axis is None:
vec = jr.normal(getkey(), (out_dim,))
vec = jr.normal(getkey(), (out_dim,), dtype=dtype)
else:
vec = jr.normal(getkey(), (10, out_dim))
vec = jr.normal(getkey(), (10, out_dim), dtype=dtype)

jax_result, _, _, _ = eqx.filter_vmap(
jnp.linalg.lstsq, in_axes=(op_axis, vec_axis) # pyright: ignore
Expand Down
13 changes: 7 additions & 6 deletions tests/test_vmap_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap_jvp(
getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix
getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
Expand Down Expand Up @@ -75,7 +76,7 @@ def linear_solve1(operator, vector):
out_axes = None

def _make():
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2)
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)
operator, t_operator = eqx.filter_jvp(
make_operator, (matrix, tags), (t_matrix, t_tags)
)
Expand All @@ -91,11 +92,11 @@ def _make():
out_size, _ = matrix.shape

if "vec" in mode:
vec = jr.normal(getkey(), (10, out_size))
t_vec = jr.normal(getkey(), (10, out_size))
vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
else:
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if mode == "op":
linear_solve2 = lambda op: linear_solve1(op, vector=vec)
Expand Down
17 changes: 12 additions & 5 deletions tests/test_vmap_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap_vmap(
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:
# combinations with nontrivial application across both vmaps
Expand Down Expand Up @@ -69,7 +70,13 @@ def test_vmap_vmap(
out_axis2 = None

(matrix,) = eqx.filter_vmap(
eqx.filter_vmap(make_matrix, axis_size=axis_size1, out_axes=out_axis1),
eqx.filter_vmap(
lambda getkey, solver, tags: make_matrix(
getkey, solver, tags, dtype=dtype
),
axis_size=axis_size1,
out_axes=out_axis1,
),
axis_size=axis_size2,
out_axes=out_axis2,
)(getkey, solver, tags)
Expand All @@ -83,11 +90,11 @@ def test_vmap_vmap(
out_size, _ = matrix.shape

if vmap1_vec is None:
vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
elif (vmap1_vec is not None) and (vmap2_vec is None):
vec = jr.normal(getkey(), (10, out_size))
vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
else:
vec = jr.normal(getkey(), (10, 10, out_size))
vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype)

operator = eqx.filter_vmap(
eqx.filter_vmap(
Expand Down
7 changes: 4 additions & 3 deletions tests/test_well_posed.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
@pytest.mark.parametrize("ops", ops)
def test_small_wellposed(make_operator, solver, tags, ops, getkey):
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
tol = 1e-4
(matrix,) = construct_matrix(getkey, solver, tags)
(matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
operator, matrix = ops(operator, matrix)
assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)
out_size, _ = matrix.shape
true_x = jr.normal(getkey(), (out_size,))
true_x = jr.normal(getkey(), (out_size,), dtype=dtype)
b = matrix @ true_x
x = lx.linear_solve(operator, b, solver=solver).value
jax_x = jnp.linalg.solve(matrix, b) # pyright: ignore
Expand Down