From d6d09dcb172955cee121447932ab080dbc5bdd84 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 29 Jun 2024 15:52:00 +0200 Subject: [PATCH] Tweaked warnings --- diffrax/_integrate.py | 1 - diffrax/_step_size_controller/adaptive.py | 2 +- test/test_event.py | 6 +++--- test/test_term.py | 5 +---- 4 files changed, 5 insertions(+), 9 deletions(-) diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py index 4d08f855..615eafe9 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -927,7 +927,6 @@ def diffeqsolve( "`diffrax.diffeqsolve(..., discrete_terminating_event=...)` is deprecated " "in favour of the more general `diffrax.diffeqsolve(..., event=...)` " "interface. This will be removed in some future version of Diffrax.", - category=DeprecationWarning, stacklevel=2, ) if event is None: diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index eb14d712..9d181c95 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -610,7 +610,7 @@ def _scale(_y0, _y1_candidate, _y_error): # a grad API boundary as part of a larger model.) factor = lax.stop_gradient(factor) factor = eqxi.nondifferentiable(factor) - dt = prev_dt * factor.astype(prev_dt) + dt = prev_dt * factor.astype(jnp.result_type(prev_dt)) # E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0, # so factor=factormin, and we shrunk our step. diff --git a/test/test_event.py b/test/test_event.py index dbfcd255..80f0102c 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -24,7 +24,7 @@ def event_fn(state, **kwargs): return state.tprev > 10 event = diffrax.DiscreteTerminatingEvent(event_fn) - with pytest.warns(DeprecationWarning, match="discrete_terminating_event"): + with pytest.warns(match="discrete_terminating_event"): sol = diffrax.diffeqsolve( term, solver, @@ -51,7 +51,7 @@ def event_fn(state, **kwargs): return state.tprev > 10 event = diffrax.DiscreteTerminatingEvent(event_fn) - with pytest.warns(DeprecationWarning, match="discrete_terminating_event"): + with pytest.warns(match="discrete_terminating_event"): sol = diffrax.diffeqsolve( term, solver, @@ -82,7 +82,7 @@ def event_fn(state, **kwargs): @jax.jit @jax.grad def run(y0): - with pytest.warns(DeprecationWarning, match="discrete_terminating_event"): + with pytest.warns(match="discrete_terminating_event"): sol = diffrax.diffeqsolve( term, solver, diff --git a/test/test_term.py b/test/test_term.py index 05f858fe..5260db2c 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -154,10 +154,7 @@ def __call__(self, t, y, args): def test_weaklydiagonal_deprecate(): - with pytest.warns( - DeprecationWarning, - match="WeaklyDiagonalControlTerm is pending deprecation", - ): + with pytest.warns(match="WeaklyDiagonalControlTerm"): _ = diffrax.WeaklyDiagonalControlTerm( lambda t, y, args: 0.0, lambda t0, t1: jnp.array(t1 - t0) )