Skip to content

Commit

Permalink
fixup! Fixed a major source of bugs: ControlTerms no longer broadcast.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jan 13, 2025
1 parent a828e5b commit 41b3cf1
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 1 deletion.
3 changes: 3 additions & 0 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,9 @@ def vector_field(t, y, args):
```
""" # noqa: E501

vector_field: Callable[[RealScalarLike, Y, Args], _VF]
control: AbstractPath[_Control]

def __init__(
self,
vector_field: Callable[[RealScalarLike, Y, Args], _VF],
Expand Down
3 changes: 2 additions & 1 deletion test/test_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def derivative(self, t, left=True):
return jr.normal(derivkey, (3,))

control = Control()
term = diffrax.WeaklyDiagonalControlTerm(vector_field, control)
with pytest.warns(match="`WeaklyDiagonalControlTerm` is now deprecated"):
term = diffrax.WeaklyDiagonalControlTerm(vector_field, control)
args = getkey()
dx = term.contr(0, 1)
y = jnp.array([1.0, 2.0, 3.0])
Expand Down

0 comments on commit 41b3cf1

Please sign in to comment.