Skip to content

Commit

Permalink
Fix backward_hermite_coefficients to work in the case where ys=[]. Ad…
Browse files Browse the repository at this point in the history
…d unit test case to test_interpolation_classes
  • Loading branch information
allen-adastra authored and patrick-kidger committed Jan 29, 2024
1 parent 984d55c commit 8aafefc
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 6 deletions.
3 changes: 2 additions & 1 deletion diffrax/_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import ClassVar as AbstractVar
else:
from equinox import AbstractVar

from equinox.internal import ω
from jaxtyping import Array, ArrayLike, PyTree, Real, Shaped

Expand Down Expand Up @@ -760,7 +761,7 @@ def backward_hermite_coefficients(
fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts)

if len(jtu.tree_leaves(ys)) == 0:
return ({}, {}, {}, {})
return tuple(ys for _ in range(4))

if deriv0 is None:
if replace_nans_at_start is None:
Expand Down
6 changes: 1 addition & 5 deletions test/test_global_interpolation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,7 @@ def test_interpolation_classes(mode, getkey):
jnp.array([0.0, 2.0, 3.0, 3.1, 4.0, 4.1, 5.0, 5.1]),
]
_make = lambda: jr.normal(getkey(), (length, num_channels))
ys_ = [
_make(),
[_make(), {"a": _make(), "b": _make()}],
{}
]
ys_ = [_make(), [_make(), {"a": _make(), "b": _make()}], {}, []]
for ts in ts_:
assert len(ts) == length
for ys in ys_:
Expand Down

0 comments on commit 8aafefc

Please sign in to comment.