Skip to content

Commit

Permalink
Add explicity dtype parameter to tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl authored and patrick-kidger committed Nov 7, 2023
1 parent 3f34189 commit d803c1e
Show file tree
Hide file tree
Showing 9 changed files with 64 additions and 45 deletions.
23 changes: 12 additions & 11 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,9 @@ def getkey():


@ft.lru_cache(maxsize=None)
def _construct_matrix_impl(getkey, cond_cutoff, tags, size):
def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype=jnp.float64):
while True:
matrix = jr.normal(getkey(), (size, size))
matrix = jr.normal(getkey(), (size, size), dtype=dtype)
if has_tag(tags, lx.diagonal_tag):
matrix = jnp.diag(jnp.diag(matrix))
if has_tag(tags, lx.symmetric_tag):
Expand All @@ -60,18 +60,19 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size):
return matrix


def construct_matrix(getkey, solver, tags, num=1, *, size=3):
def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64):
if isinstance(solver, lx.NormalCG):
cond_cutoff = math.sqrt(1000)
else:
cond_cutoff = 1000
return tuple(
_construct_matrix_impl(getkey, cond_cutoff, tags, size) for _ in range(num)
_construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype)
for _ in range(num)
)


def construct_singular_matrix(getkey, solver, tags, num=1):
matrices = construct_matrix(getkey, solver, tags, num)
def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64):
matrices = construct_matrix(getkey, solver, tags, num, dtype=dtype)
if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)):
return tuple(matrix.at[0, :].set(0) for matrix in matrices)
else:
Expand Down Expand Up @@ -213,10 +214,10 @@ def make_function_operator(matrix, tags):
@_operators_append
def make_jac_operator(matrix, tags):
out_size, in_size = matrix.shape
x = jr.normal(getkey(), (in_size,))
a = jr.normal(getkey(), (out_size,))
b = jr.normal(getkey(), (out_size, in_size))
c = jr.normal(getkey(), (out_size, in_size))
x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype)
a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype)
b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype)
fn_tmp = lambda x, _: a + b @ x + c @ x**2
jac = jax.jacfwd(fn_tmp)(x, None)
diff = matrix - jac
Expand Down Expand Up @@ -269,7 +270,7 @@ def make_mul_operator(matrix, tags):
@_operators_append
def make_composed_operator(matrix, tags):
_, size = matrix.shape
diag = jr.normal(getkey(), (size,))
diag = jr.normal(getkey(), (size,), dtype=matrix.dtype)
diag = jnp.where(jnp.abs(diag) < 0.05, 0.8, diag)
operator1 = make_trivial_pytree_operator(matrix / diag, ())
operator2 = lx.DiagonalLinearOperator(diag)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,17 +43,18 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_jvp(
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None

if (make_matrix is construct_matrix) or pseudoinverse:
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2)
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)

out_size, _ = matrix.shape
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if has_tag(tags, lx.unit_diagonal_tag):
# For all the other tags, A + εB with A, B \in {matrices satisfying the tag}
Expand Down
13 changes: 7 additions & 6 deletions tests/test_jvp_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,14 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_jvp_jvp(
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix
getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype
):
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
if (make_matrix is construct_matrix) or pseudoinverse:
matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix(
getkey, solver, tags, num=4
getkey, solver, tags, num=4, dtype=dtype
)

t_make_operator = lambda p, t_p: eqx.filter_jvp(
Expand All @@ -62,10 +63,10 @@ def test_jvp_jvp(
)

out_size, _ = matrix.shape
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
tt_vec = jr.normal(getkey(), (out_size,))
tt_t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if use_state:

Expand Down
5 changes: 3 additions & 2 deletions tests/test_singular.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,13 @@

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True))
@pytest.mark.parametrize("ops", ops)
def test_small_singular(make_operator, solver, tags, ops, getkey):
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_small_singular(make_operator, solver, tags, ops, getkey, dtype):
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
tol = 1e-4
(matrix,) = construct_singular_matrix(getkey, solver, tags)
(matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
operator, matrix = ops(operator, matrix)
assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)
Expand Down
9 changes: 5 additions & 4 deletions tests/test_transpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,15 @@ def assert_transpose(operator, out_vec, in_vec, solver):
return assert_transpose

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_transpose(
_, make_operator, solver, tags, assert_transpose_fixture, getkey
_, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey
):
(matrix,) = construct_matrix(getkey, solver, tags)
(matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
out_size, in_size = matrix.shape
out_vec = jr.normal(getkey(), (out_size,))
in_vec = jr.normal(getkey(), (in_size,))
out_vec = jr.normal(getkey(), (out_size,), dtype=dtype)
in_vec = jr.normal(getkey(), (in_size,), dtype=dtype)
solver = lx.AutoLinearSolver(well_posed=True)
assert_transpose_fixture(operator, out_vec, in_vec, solver)

Expand Down
13 changes: 9 additions & 4 deletions tests/test_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap(
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:

Expand All @@ -65,14 +66,18 @@ def wrap_solve(matrix, vector):
out_axes = eqx.if_array(0)

(matrix,) = eqx.filter_vmap(
make_matrix, axis_size=axis_size, out_axes=out_axes
lambda getkey, solver, tags: make_matrix(
getkey, solver, tags, dtype=dtype
),
axis_size=axis_size,
out_axes=out_axes,
)(getkey, solver, tags)
out_dim = matrix.shape[-2]

if vec_axis is None:
vec = jr.normal(getkey(), (out_dim,))
vec = jr.normal(getkey(), (out_dim,), dtype=dtype)
else:
vec = jr.normal(getkey(), (10, out_dim))
vec = jr.normal(getkey(), (10, out_dim), dtype=dtype)

jax_result, _, _, _ = eqx.filter_vmap(
jnp.linalg.lstsq, in_axes=(op_axis, vec_axis) # pyright: ignore
Expand Down
13 changes: 7 additions & 6 deletions tests/test_vmap_jvp.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap_jvp(
getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix
getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:
t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None
Expand Down Expand Up @@ -75,7 +76,7 @@ def linear_solve1(operator, vector):
out_axes = None

def _make():
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2)
matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype)
operator, t_operator = eqx.filter_jvp(
make_operator, (matrix, tags), (t_matrix, t_tags)
)
Expand All @@ -91,11 +92,11 @@ def _make():
out_size, _ = matrix.shape

if "vec" in mode:
vec = jr.normal(getkey(), (10, out_size))
t_vec = jr.normal(getkey(), (10, out_size))
vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
else:
vec = jr.normal(getkey(), (out_size,))
t_vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
t_vec = jr.normal(getkey(), (out_size,), dtype=dtype)

if mode == "op":
linear_solve2 = lambda op: linear_solve1(op, vector=vec)
Expand Down
17 changes: 12 additions & 5 deletions tests/test_vmap_vmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,9 @@
construct_singular_matrix,
),
)
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_vmap_vmap(
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix
getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype
):
if (make_matrix is construct_matrix) or pseudoinverse:
# combinations with nontrivial application across both vmaps
Expand Down Expand Up @@ -69,7 +70,13 @@ def test_vmap_vmap(
out_axis2 = None

(matrix,) = eqx.filter_vmap(
eqx.filter_vmap(make_matrix, axis_size=axis_size1, out_axes=out_axis1),
eqx.filter_vmap(
lambda getkey, solver, tags: make_matrix(
getkey, solver, tags, dtype=dtype
),
axis_size=axis_size1,
out_axes=out_axis1,
),
axis_size=axis_size2,
out_axes=out_axis2,
)(getkey, solver, tags)
Expand All @@ -83,11 +90,11 @@ def test_vmap_vmap(
out_size, _ = matrix.shape

if vmap1_vec is None:
vec = jr.normal(getkey(), (out_size,))
vec = jr.normal(getkey(), (out_size,), dtype=dtype)
elif (vmap1_vec is not None) and (vmap2_vec is None):
vec = jr.normal(getkey(), (10, out_size))
vec = jr.normal(getkey(), (10, out_size), dtype=dtype)
else:
vec = jr.normal(getkey(), (10, 10, out_size))
vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype)

operator = eqx.filter_vmap(
eqx.filter_vmap(
Expand Down
7 changes: 4 additions & 3 deletions tests/test_well_posed.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,17 +30,18 @@

@pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False))
@pytest.mark.parametrize("ops", ops)
def test_small_wellposed(make_operator, solver, tags, ops, getkey):
@pytest.mark.parametrize("dtype", (jnp.float64,))
def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype):
if jax.config.jax_enable_x64: # pyright: ignore
tol = 1e-10
else:
tol = 1e-4
(matrix,) = construct_matrix(getkey, solver, tags)
(matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype)
operator = make_operator(matrix, tags)
operator, matrix = ops(operator, matrix)
assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol)
out_size, _ = matrix.shape
true_x = jr.normal(getkey(), (out_size,))
true_x = jr.normal(getkey(), (out_size,), dtype=dtype)
b = matrix @ true_x
x = lx.linear_solve(operator, b, solver=solver).value
jax_x = jnp.linalg.solve(matrix, b) # pyright: ignore
Expand Down

0 comments on commit d803c1e

Please sign in to comment.