From 41b3cf177f8f8145dce65a7ae166cda6cbce4314 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 13 Jan 2025 18:31:58 +0100 Subject: [PATCH] fixup! Fixed a major source of bugs: ControlTerms no longer broadcast. --- diffrax/_term.py | 3 +++ test/test_term.py | 3 ++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/diffrax/_term.py b/diffrax/_term.py index 4d215d58..191b179b 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -404,6 +404,9 @@ def vector_field(t, y, args): ``` """ # noqa: E501 + vector_field: Callable[[RealScalarLike, Y, Args], _VF] + control: AbstractPath[_Control] + def __init__( self, vector_field: Callable[[RealScalarLike, Y, Args], _VF], diff --git a/test/test_term.py b/test/test_term.py index adef3635..a0477a85 100644 --- a/test/test_term.py +++ b/test/test_term.py @@ -85,7 +85,8 @@ def derivative(self, t, left=True): return jr.normal(derivkey, (3,)) control = Control() - term = diffrax.WeaklyDiagonalControlTerm(vector_field, control) + with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"): + term = diffrax.WeaklyDiagonalControlTerm(vector_field, control) args = getkey() dx = term.contr(0, 1) y = jnp.array([1.0, 2.0, 3.0])