From 8aafefcd081d51240ac865662a1767ca999d2455 Mon Sep 17 00:00:00 2001 From: Allen Wang Date: Fri, 26 Jan 2024 13:55:50 -0500 Subject: [PATCH] Fix backward_hermite_coefficients to work in the case where ys=[]. Add unit test case to test_interpolation_classes --- diffrax/_global_interpolation.py | 3 ++- test/test_global_interpolation.py | 6 +----- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/diffrax/_global_interpolation.py b/diffrax/_global_interpolation.py index c7ef513f..f10a21fd 100644 --- a/diffrax/_global_interpolation.py +++ b/diffrax/_global_interpolation.py @@ -14,6 +14,7 @@ from typing import ClassVar as AbstractVar else: from equinox import AbstractVar + from equinox.internal import ω from jaxtyping import Array, ArrayLike, PyTree, Real, Shaped @@ -760,7 +761,7 @@ def backward_hermite_coefficients( fn = ft.partial(_backward_hermite_coefficients, fill_forward_nans_at_end, ts) if len(jtu.tree_leaves(ys)) == 0: - return ({}, {}, {}, {}) + return tuple(ys for _ in range(4)) if deriv0 is None: if replace_nans_at_start is None: diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index 7adc5151..5103000f 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -258,11 +258,7 @@ def test_interpolation_classes(mode, getkey): jnp.array([0.0, 2.0, 3.0, 3.1, 4.0, 4.1, 5.0, 5.1]), ] _make = lambda: jr.normal(getkey(), (length, num_channels)) - ys_ = [ - _make(), - [_make(), {"a": _make(), "b": _make()}], - {} - ] + ys_ = [_make(), [_make(), {"a": _make(), "b": _make()}], {}, []] for ts in ts_: assert len(ts) == length for ys in ys_: