diff --git a/py-polars/tests/unit/operations/test_interpolate_by.py b/py-polars/tests/unit/operations/test_interpolate_by.py index 4161a2f1c04c..7e672890cbc4 100644 --- a/py-polars/tests/unit/operations/test_interpolate_by.py +++ b/py-polars/tests/unit/operations/test_interpolate_by.py @@ -118,8 +118,10 @@ def test_interpolate_by_leading_nulls() -> None: assert_frame_equal(result, expected) -def test_interpolate_by_trailing_nulls() -> None: - df = pl.DataFrame( +@pytest.mark.parametrize("dataset", ["floats", "dates"]) +def test_interpolate_by_trailing_nulls(dataset: str) -> None: + input_data = { + "dates": pl.DataFrame( { "times": [ date(2020, 1, 1), @@ -130,10 +132,25 @@ def test_interpolate_by_trailing_nulls() -> None: date(2020, 1, 13), ], "values": [1, None, None, 5, None, None], + }), + "floats": pl.DataFrame( + { + "times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1], + "values": [1, None, None, 5, None, None], } - ) + ) + } + + expected_data = { + "dates": pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}), + "floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}) + } + + df = input_data[dataset] + expected = expected_data[dataset] + result = df.select(pl.col("values").interpolate_by("times")) - expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}) + assert_frame_equal(result, expected) result = ( df.sort("times", descending=True)