From ec03b5b534fb98e4d7ecff3819fbdf40c7f0bccb Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Wed, 1 Nov 2023 21:50:05 +0200 Subject: [PATCH 1/2] Add explicity dtype parameter to tests --- tests/helpers.py | 23 ++++++++++++----------- tests/test_jvp.py | 9 +++++---- tests/test_jvp_jvp.py | 13 +++++++------ tests/test_singular.py | 5 +++-- tests/test_transpose.py | 9 +++++---- tests/test_vmap.py | 13 +++++++++---- tests/test_vmap_jvp.py | 13 +++++++------ tests/test_vmap_vmap.py | 17 ++++++++++++----- tests/test_well_posed.py | 7 ++++--- 9 files changed, 64 insertions(+), 45 deletions(-) diff --git a/tests/helpers.py b/tests/helpers.py index 391861a..e32c2ae 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -33,9 +33,9 @@ def getkey(): @ft.lru_cache(maxsize=None) -def _construct_matrix_impl(getkey, cond_cutoff, tags, size): +def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype=jnp.float64): while True: - matrix = jr.normal(getkey(), (size, size)) + matrix = jr.normal(getkey(), (size, size), dtype=dtype) if has_tag(tags, lx.diagonal_tag): matrix = jnp.diag(jnp.diag(matrix)) if has_tag(tags, lx.symmetric_tag): @@ -60,18 +60,19 @@ def _construct_matrix_impl(getkey, cond_cutoff, tags, size): return matrix -def construct_matrix(getkey, solver, tags, num=1, *, size=3): +def construct_matrix(getkey, solver, tags, num=1, *, size=3, dtype=jnp.float64): if isinstance(solver, lx.NormalCG): cond_cutoff = math.sqrt(1000) else: cond_cutoff = 1000 return tuple( - _construct_matrix_impl(getkey, cond_cutoff, tags, size) for _ in range(num) + _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype) + for _ in range(num) ) -def construct_singular_matrix(getkey, solver, tags, num=1): - matrices = construct_matrix(getkey, solver, tags, num) +def construct_singular_matrix(getkey, solver, tags, num=1, dtype=jnp.float64): + matrices = construct_matrix(getkey, solver, tags, num, dtype=dtype) if isinstance(solver, (lx.Diagonal, lx.CG, lx.BiCGStab, lx.GMRES)): return tuple(matrix.at[0, :].set(0) for matrix in matrices) else: @@ -213,10 +214,10 @@ def make_function_operator(matrix, tags): @_operators_append def make_jac_operator(matrix, tags): out_size, in_size = matrix.shape - x = jr.normal(getkey(), (in_size,)) - a = jr.normal(getkey(), (out_size,)) - b = jr.normal(getkey(), (out_size, in_size)) - c = jr.normal(getkey(), (out_size, in_size)) + x = jr.normal(getkey(), (in_size,), dtype=matrix.dtype) + a = jr.normal(getkey(), (out_size,), dtype=matrix.dtype) + b = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) + c = jr.normal(getkey(), (out_size, in_size), dtype=matrix.dtype) fn_tmp = lambda x, _: a + b @ x + c @ x**2 jac = jax.jacfwd(fn_tmp)(x, None) diff = matrix - jac @@ -269,7 +270,7 @@ def make_mul_operator(matrix, tags): @_operators_append def make_composed_operator(matrix, tags): _, size = matrix.shape - diag = jr.normal(getkey(), (size,)) + diag = jr.normal(getkey(), (size,), dtype=matrix.dtype) diag = jnp.where(jnp.abs(diag) < 0.05, 0.8, diag) operator1 = make_trivial_pytree_operator(matrix / diag, ()) operator2 = lx.DiagonalLinearOperator(diag) diff --git a/tests/test_jvp.py b/tests/test_jvp.py index 2260a36..221f591 100644 --- a/tests/test_jvp.py +++ b/tests/test_jvp.py @@ -43,17 +43,18 @@ construct_singular_matrix, ), ) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_jvp( - getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix + getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None if (make_matrix is construct_matrix) or pseudoinverse: - matrix, t_matrix = make_matrix(getkey, solver, tags, num=2) + matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) out_size, _ = matrix.shape - vec = jr.normal(getkey(), (out_size,)) - t_vec = jr.normal(getkey(), (out_size,)) + vec = jr.normal(getkey(), (out_size,), dtype=dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if has_tag(tags, lx.unit_diagonal_tag): # For all the other tags, A + εB with A, B \in {matrices satisfying the tag} diff --git a/tests/test_jvp_jvp.py b/tests/test_jvp_jvp.py index bc92368..7a87591 100644 --- a/tests/test_jvp_jvp.py +++ b/tests/test_jvp_jvp.py @@ -42,13 +42,14 @@ construct_singular_matrix, ), ) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_jvp_jvp( - getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix + getkey, solver, tags, pseudoinverse, make_operator, use_state, make_matrix, dtype ): t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None if (make_matrix is construct_matrix) or pseudoinverse: matrix, t_matrix, tt_matrix, tt_t_matrix = construct_matrix( - getkey, solver, tags, num=4 + getkey, solver, tags, num=4, dtype=dtype ) t_make_operator = lambda p, t_p: eqx.filter_jvp( @@ -62,10 +63,10 @@ def test_jvp_jvp( ) out_size, _ = matrix.shape - vec = jr.normal(getkey(), (out_size,)) - t_vec = jr.normal(getkey(), (out_size,)) - tt_vec = jr.normal(getkey(), (out_size,)) - tt_t_vec = jr.normal(getkey(), (out_size,)) + vec = jr.normal(getkey(), (out_size,), dtype=dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) + tt_vec = jr.normal(getkey(), (out_size,), dtype=dtype) + tt_t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if use_state: diff --git a/tests/test_singular.py b/tests/test_singular.py index f2a65df..2bedede 100644 --- a/tests/test_singular.py +++ b/tests/test_singular.py @@ -37,12 +37,13 @@ @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=True)) @pytest.mark.parametrize("ops", ops) -def test_small_singular(make_operator, solver, tags, ops, getkey): +@pytest.mark.parametrize("dtype", (jnp.float64,)) +def test_small_singular(make_operator, solver, tags, ops, getkey, dtype): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 - (matrix,) = construct_singular_matrix(getkey, solver, tags) + (matrix,) = construct_singular_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(matrix, tags) operator, matrix = ops(operator, matrix) assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) diff --git a/tests/test_transpose.py b/tests/test_transpose.py index 5c88c68..d16511b 100644 --- a/tests/test_transpose.py +++ b/tests/test_transpose.py @@ -40,14 +40,15 @@ def assert_transpose(operator, out_vec, in_vec, solver): return assert_transpose @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) + @pytest.mark.parametrize("dtype", (jnp.float64,)) def test_transpose( - _, make_operator, solver, tags, assert_transpose_fixture, getkey + _, make_operator, solver, tags, assert_transpose_fixture, dtype, getkey ): - (matrix,) = construct_matrix(getkey, solver, tags) + (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(matrix, tags) out_size, in_size = matrix.shape - out_vec = jr.normal(getkey(), (out_size,)) - in_vec = jr.normal(getkey(), (in_size,)) + out_vec = jr.normal(getkey(), (out_size,), dtype=dtype) + in_vec = jr.normal(getkey(), (in_size,), dtype=dtype) solver = lx.AutoLinearSolver(well_posed=True) assert_transpose_fixture(operator, out_vec, in_vec, solver) diff --git a/tests/test_vmap.py b/tests/test_vmap.py index 8ac692a..2165e9b 100644 --- a/tests/test_vmap.py +++ b/tests/test_vmap.py @@ -39,8 +39,9 @@ construct_singular_matrix, ), ) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_vmap( - getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix + getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: @@ -65,14 +66,18 @@ def wrap_solve(matrix, vector): out_axes = eqx.if_array(0) (matrix,) = eqx.filter_vmap( - make_matrix, axis_size=axis_size, out_axes=out_axes + lambda getkey, solver, tags: make_matrix( + getkey, solver, tags, dtype=dtype + ), + axis_size=axis_size, + out_axes=out_axes, )(getkey, solver, tags) out_dim = matrix.shape[-2] if vec_axis is None: - vec = jr.normal(getkey(), (out_dim,)) + vec = jr.normal(getkey(), (out_dim,), dtype=dtype) else: - vec = jr.normal(getkey(), (10, out_dim)) + vec = jr.normal(getkey(), (10, out_dim), dtype=dtype) jax_result, _, _, _ = eqx.filter_vmap( jnp.linalg.lstsq, in_axes=(op_axis, vec_axis) # pyright: ignore diff --git a/tests/test_vmap_jvp.py b/tests/test_vmap_jvp.py index 704e773..3b5b960 100644 --- a/tests/test_vmap_jvp.py +++ b/tests/test_vmap_jvp.py @@ -42,8 +42,9 @@ construct_singular_matrix, ), ) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_vmap_jvp( - getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix + getkey, solver, tags, make_operator, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: t_tags = (None,) * len(tags) if isinstance(tags, tuple) else None @@ -75,7 +76,7 @@ def linear_solve1(operator, vector): out_axes = None def _make(): - matrix, t_matrix = make_matrix(getkey, solver, tags, num=2) + matrix, t_matrix = make_matrix(getkey, solver, tags, num=2, dtype=dtype) operator, t_operator = eqx.filter_jvp( make_operator, (matrix, tags), (t_matrix, t_tags) ) @@ -91,11 +92,11 @@ def _make(): out_size, _ = matrix.shape if "vec" in mode: - vec = jr.normal(getkey(), (10, out_size)) - t_vec = jr.normal(getkey(), (10, out_size)) + vec = jr.normal(getkey(), (10, out_size), dtype=dtype) + t_vec = jr.normal(getkey(), (10, out_size), dtype=dtype) else: - vec = jr.normal(getkey(), (out_size,)) - t_vec = jr.normal(getkey(), (out_size,)) + vec = jr.normal(getkey(), (out_size,), dtype=dtype) + t_vec = jr.normal(getkey(), (out_size,), dtype=dtype) if mode == "op": linear_solve2 = lambda op: linear_solve1(op, vector=vec) diff --git a/tests/test_vmap_vmap.py b/tests/test_vmap_vmap.py index 3c22a9a..85eade6 100644 --- a/tests/test_vmap_vmap.py +++ b/tests/test_vmap_vmap.py @@ -40,8 +40,9 @@ construct_singular_matrix, ), ) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_vmap_vmap( - getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix + getkey, make_operator, solver, tags, pseudoinverse, use_state, make_matrix, dtype ): if (make_matrix is construct_matrix) or pseudoinverse: # combinations with nontrivial application across both vmaps @@ -69,7 +70,13 @@ def test_vmap_vmap( out_axis2 = None (matrix,) = eqx.filter_vmap( - eqx.filter_vmap(make_matrix, axis_size=axis_size1, out_axes=out_axis1), + eqx.filter_vmap( + lambda getkey, solver, tags: make_matrix( + getkey, solver, tags, dtype=dtype + ), + axis_size=axis_size1, + out_axes=out_axis1, + ), axis_size=axis_size2, out_axes=out_axis2, )(getkey, solver, tags) @@ -83,11 +90,11 @@ def test_vmap_vmap( out_size, _ = matrix.shape if vmap1_vec is None: - vec = jr.normal(getkey(), (out_size,)) + vec = jr.normal(getkey(), (out_size,), dtype=dtype) elif (vmap1_vec is not None) and (vmap2_vec is None): - vec = jr.normal(getkey(), (10, out_size)) + vec = jr.normal(getkey(), (10, out_size), dtype=dtype) else: - vec = jr.normal(getkey(), (10, 10, out_size)) + vec = jr.normal(getkey(), (10, 10, out_size), dtype=dtype) operator = eqx.filter_vmap( eqx.filter_vmap( diff --git a/tests/test_well_posed.py b/tests/test_well_posed.py index 193c8af..9bd48c5 100644 --- a/tests/test_well_posed.py +++ b/tests/test_well_posed.py @@ -30,17 +30,18 @@ @pytest.mark.parametrize("make_operator,solver,tags", params(only_pseudo=False)) @pytest.mark.parametrize("ops", ops) -def test_small_wellposed(make_operator, solver, tags, ops, getkey): +@pytest.mark.parametrize("dtype", (jnp.float64,)) +def test_small_wellposed(make_operator, solver, tags, ops, getkey, dtype): if jax.config.jax_enable_x64: # pyright: ignore tol = 1e-10 else: tol = 1e-4 - (matrix,) = construct_matrix(getkey, solver, tags) + (matrix,) = construct_matrix(getkey, solver, tags, dtype=dtype) operator = make_operator(matrix, tags) operator, matrix = ops(operator, matrix) assert shaped_allclose(operator.as_matrix(), matrix, rtol=tol, atol=tol) out_size, _ = matrix.shape - true_x = jr.normal(getkey(), (out_size,)) + true_x = jr.normal(getkey(), (out_size,), dtype=dtype) b = matrix @ true_x x = lx.linear_solve(operator, b, solver=solver).value jax_x = jnp.linalg.solve(matrix, b) # pyright: ignore From fd77828da7d0109d8a77d444ed484be6afe09932 Mon Sep 17 00:00:00 2001 From: Evgenii Zheltonozhskii Date: Wed, 1 Nov 2023 23:11:51 +0200 Subject: [PATCH 2/2] Remove default --- tests/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/helpers.py b/tests/helpers.py index e32c2ae..10d38f1 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -33,7 +33,7 @@ def getkey(): @ft.lru_cache(maxsize=None) -def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype=jnp.float64): +def _construct_matrix_impl(getkey, cond_cutoff, tags, size, dtype): while True: matrix = jr.normal(getkey(), (size, size), dtype=dtype) if has_tag(tags, lx.diagonal_tag):