From 376ce9b1043c0afaa079d1d8626d5a2fdc18be91 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 21 Oct 2024 15:25:57 +0200 Subject: [PATCH] Split SDE tests in half, to try and avoid GitHub runner issues? --- test/{test_sde.py => test_sde1.py} | 158 +---------------------------- test/test_sde2.py | 154 ++++++++++++++++++++++++++++ 2 files changed, 155 insertions(+), 157 deletions(-) rename test/{test_sde.py => test_sde1.py} (56%) create mode 100644 test/test_sde2.py diff --git a/test/test_sde.py b/test/test_sde1.py similarity index 56% rename from test/test_sde.py rename to test/test_sde1.py index cdac924f..b4504872 100644 --- a/test/test_sde.py +++ b/test/test_sde1.py @@ -1,13 +1,9 @@ from typing import Literal import diffrax -import jax import jax.numpy as jnp import jax.random as jr -import jax.tree_util as jtu -import lineax as lx import pytest -from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm from .helpers import ( get_mlp_sde, @@ -119,10 +115,7 @@ def get_dt_and_controller(level): # using a single reference solution. We use Euler if the solver is Ito # and Heun if the solver is Stratonovich. @pytest.mark.parametrize("solver_ctr,noise,theoretical_order", _solvers_and_orders()) -@pytest.mark.parametrize( - "dtype", - (jnp.float64,), -) +@pytest.mark.parametrize("dtype", (jnp.float64,)) def test_sde_strong_limit( solver_ctr, noise: Literal["any", "com", "add"], theoretical_order, dtype ): @@ -184,152 +177,3 @@ def test_sde_strong_limit( ) error = path_l2_dist(correct_sol, sol) assert error < 0.05 - - -def _solvers(): - yield diffrax.SPaRK - yield diffrax.GeneralShARK - yield diffrax.SlowRK - yield diffrax.ShARK - yield diffrax.SRA1 - yield diffrax.SEA - - -# Define the SDE -def dict_drift(t, y, args): - pytree, _ = args - return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y) - - -def dict_diffusion(t, y, args): - pytree, additive = args - - def get_matrix(y_leaf): - if additive: - return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64) - else: - return 2.0 * jnp.broadcast_to( - jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,) - ) - - return jtu.tree_map(get_matrix, y) - - -@pytest.mark.parametrize("shape", [(), (5, 2)]) -@pytest.mark.parametrize("solver_ctr", _solvers()) -@pytest.mark.parametrize( - "dtype", - (jnp.float64, jnp.complex128), -) -def test_sde_solver_shape(shape, solver_ctr, dtype): - pytree = ({"a": 0, "b": [0, 0]}, 0, 0) - key = jr.PRNGKey(0) - y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree) - t0, t1, dt0 = 0.0, 1.0, 0.3 - - # Some solvers only work with additive noise - additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA] - args = (pytree, additive) - solver = solver_ctr() - bmkey = jr.key(1) - struct = jax.ShapeDtypeStruct((3,), dtype) - bm_shape = jtu.tree_map(lambda _: struct, pytree) - bm = diffrax.VirtualBrownianTree( - t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea - ) - terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm)) - solution = diffrax.diffeqsolve( - terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True) - ) - assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0) - for leaf in jtu.tree_leaves(solution.ys): - assert leaf[0].shape == shape - - -def _weakly_diagonal_noise_helper(solver, dtype): - w_shape = (3,) - args = (0.5, 1.2) - - def _diffusion(t, y, args): - a, b = args - return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype) - - def _drift(t, y, args): - a, b = args - return -a * y - - y0 = jnp.ones(w_shape, dtype) - - bm = diffrax.VirtualBrownianTree( - 0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea - ) - - terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm)) - saveat = diffrax.SaveAt(t1=True) - solution = diffrax.diffeqsolve( - terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat - ) - assert solution.ys is not None - assert solution.ys.shape == (1, 3) - - -def _lineax_weakly_diagonal_noise_helper(solver, dtype): - w_shape = (3,) - args = (0.5, 1.2) - - def _diffusion(t, y, args): - a, b = args - return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype)) - - def _drift(t, y, args): - a, b = args - return -a * y - - y0 = jnp.ones(w_shape, dtype) - - bm = diffrax.VirtualBrownianTree( - 0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea - ) - - terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm)) - saveat = diffrax.SaveAt(t1=True) - solution = diffrax.diffeqsolve( - terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat - ) - assert solution.ys is not None - assert solution.ys.shape == (1, 3) - - -@pytest.mark.parametrize("solver_ctr", _solvers()) -@pytest.mark.parametrize( - "dtype", - (jnp.float64, jnp.complex128), -) -@pytest.mark.parametrize( - "weak_type", - ("old", "lineax"), -) -def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type): - if weak_type == "old": - _weakly_diagonal_noise_helper(solver_ctr(), dtype) - elif weak_type == "lineax": - _lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype) - else: - raise ValueError("Invalid weak_type") - - -@pytest.mark.parametrize( - "dtype", - (jnp.float64, jnp.complex128), -) -@pytest.mark.parametrize( - "weak_type", - ("old", "lineax"), -) -def test_halfsolver_term_compatible(dtype, weak_type): - if weak_type == "old": - _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) - elif weak_type == "lineax": - _lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) - else: - raise ValueError("Invalid weak_type") diff --git a/test/test_sde2.py b/test/test_sde2.py new file mode 100644 index 00000000..3b4a4628 --- /dev/null +++ b/test/test_sde2.py @@ -0,0 +1,154 @@ +import diffrax +import jax +import jax.numpy as jnp +import jax.random as jr +import jax.tree_util as jtu +import lineax as lx +import pytest +from diffrax import ControlTerm, MultiTerm, ODETerm, WeaklyDiagonalControlTerm + + +def _solvers(): + yield diffrax.SPaRK + yield diffrax.GeneralShARK + yield diffrax.SlowRK + yield diffrax.ShARK + yield diffrax.SRA1 + yield diffrax.SEA + + +# Define the SDE +def dict_drift(t, y, args): + pytree, _ = args + return jtu.tree_map(lambda _, x: -0.5 * x, pytree, y) + + +def dict_diffusion(t, y, args): + pytree, additive = args + + def get_matrix(y_leaf): + if additive: + return 2.0 * jnp.ones(y_leaf.shape + (3,), dtype=jnp.float64) + else: + return 2.0 * jnp.broadcast_to( + jnp.expand_dims(y_leaf, axis=y_leaf.ndim), y_leaf.shape + (3,) + ) + + return jtu.tree_map(get_matrix, y) + + +@pytest.mark.parametrize("shape", [(), (5, 2)]) +@pytest.mark.parametrize("solver_ctr", _solvers()) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_sde_solver_shape(shape, solver_ctr, dtype): + pytree = ({"a": 0, "b": [0, 0]}, 0, 0) + key = jr.PRNGKey(0) + y0 = jtu.tree_map(lambda _: jr.normal(key, shape, dtype=dtype), pytree) + t0, t1, dt0 = 0.0, 1.0, 0.3 + + # Some solvers only work with additive noise + additive = solver_ctr in [diffrax.ShARK, diffrax.SRA1, diffrax.SEA] + args = (pytree, additive) + solver = solver_ctr() + bmkey = jr.key(1) + struct = jax.ShapeDtypeStruct((3,), dtype) + bm_shape = jtu.tree_map(lambda _: struct, pytree) + bm = diffrax.VirtualBrownianTree( + t0, t1, 0.1, bm_shape, bmkey, diffrax.SpaceTimeLevyArea + ) + terms = MultiTerm(ODETerm(dict_drift), ControlTerm(dict_diffusion, bm)) + solution = diffrax.diffeqsolve( + terms, solver, t0, t1, dt0, y0, args, saveat=diffrax.SaveAt(t1=True) + ) + assert jtu.tree_structure(solution.ys) == jtu.tree_structure(y0) + for leaf in jtu.tree_leaves(solution.ys): + assert leaf[0].shape == shape + + +def _weakly_diagonal_noise_helper(solver, dtype): + w_shape = (3,) + args = (0.5, 1.2) + + def _diffusion(t, y, args): + a, b = args + return jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype) + + def _drift(t, y, args): + a, b = args + return -a * y + + y0 = jnp.ones(w_shape, dtype) + + bm = diffrax.VirtualBrownianTree( + 0.0, 1.0, 0.05, w_shape, jr.key(0), diffrax.SpaceTimeLevyArea + ) + + terms = MultiTerm(ODETerm(_drift), WeaklyDiagonalControlTerm(_diffusion, bm)) + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve( + terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat + ) + assert solution.ys is not None + assert solution.ys.shape == (1, 3) + + +def _lineax_weakly_diagonal_noise_helper(solver, dtype): + w_shape = (3,) + args = (0.5, 1.2) + + def _diffusion(t, y, args): + a, b = args + return lx.DiagonalLinearOperator(jnp.array([b, t, 1 / (t + 1.0)], dtype=dtype)) + + def _drift(t, y, args): + a, b = args + return -a * y + + y0 = jnp.ones(w_shape, dtype) + + bm = diffrax.VirtualBrownianTree( + 0.0, 1.0, 0.05, w_shape, jr.PRNGKey(0), diffrax.SpaceTimeLevyArea + ) + + terms = MultiTerm(ODETerm(_drift), ControlTerm(_diffusion, bm)) + saveat = diffrax.SaveAt(t1=True) + solution = diffrax.diffeqsolve( + terms, solver, 0.0, 1.0, 0.1, y0, args, saveat=saveat + ) + assert solution.ys is not None + assert solution.ys.shape == (1, 3) + + +@pytest.mark.parametrize("solver_ctr", _solvers()) +@pytest.mark.parametrize( + "dtype", + (jnp.float64, jnp.complex128), +) +@pytest.mark.parametrize( + "weak_type", + ("old", "lineax"), +) +def test_weakly_diagonal_noise(solver_ctr, dtype, weak_type): + if weak_type == "old": + _weakly_diagonal_noise_helper(solver_ctr(), dtype) + elif weak_type == "lineax": + _lineax_weakly_diagonal_noise_helper(solver_ctr(), dtype) + else: + raise ValueError("Invalid weak_type") + + +@pytest.mark.parametrize( + "dtype", + (jnp.float64, jnp.complex128), +) +@pytest.mark.parametrize( + "weak_type", + ("old", "lineax"), +) +def test_halfsolver_term_compatible(dtype, weak_type): + if weak_type == "old": + _weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) + elif weak_type == "lineax": + _lineax_weakly_diagonal_noise_helper(diffrax.HalfSolver(diffrax.SPaRK()), dtype) + else: + raise ValueError("Invalid weak_type")