diff --git a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs index 674cbab514e9..c77d2ad6f157 100644 --- a/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs +++ b/crates/polars-ops/src/series/ops/interpolation/interpolate_by.rs @@ -87,7 +87,7 @@ fn interpolate_impl_by_sorted( ) -> PolarsResult> where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, I: Fn(T::Native, T::Native, &[F::Native], &mut Vec), { // This implementation differs from pandas as that boundary None's are not removed. @@ -169,7 +169,7 @@ fn interpolate_impl_by( ) -> PolarsResult> where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]), { // This implementation differs from pandas as that boundary None's are not removed. @@ -273,7 +273,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu ) -> PolarsResult where T: PolarsNumericType, - F: PolarsIntegerType, + F: PolarsNumericType, ChunkedArray: IntoSeries, { if is_sorted { @@ -290,6 +290,18 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu } match (s.dtype(), by.dtype()) { + (DataType::Float64, DataType::Float64) => { + func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float64, DataType::Float32) => { + func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float64) => { + func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted) + }, + (DataType::Float32, DataType::Float32) => { + func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted) + }, (DataType::Float64, DataType::Int64) => { func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted) }, @@ -326,7 +338,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu _ => { polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \ Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \ - UInt64, or UInt32") + UInt64, UInt32, Float32 or Float64") }, } } diff --git a/py-polars/tests/unit/operations/test_interpolate_by.py b/py-polars/tests/unit/operations/test_interpolate_by.py index 423992abeadd..98ee656fdaed 100644 --- a/py-polars/tests/unit/operations/test_interpolate_by.py +++ b/py-polars/tests/unit/operations/test_interpolate_by.py @@ -28,6 +28,8 @@ pl.Int32, pl.UInt64, pl.UInt32, + pl.Float32, + pl.Float64, ], ) @pytest.mark.parametrize( @@ -116,22 +118,42 @@ def test_interpolate_by_leading_nulls() -> None: assert_frame_equal(result, expected) -def test_interpolate_by_trailing_nulls() -> None: - df = pl.DataFrame( - { - "times": [ - date(2020, 1, 1), - date(2020, 1, 3), - date(2020, 1, 10), - date(2020, 1, 11), - date(2020, 1, 12), - date(2020, 1, 13), - ], - "values": [1, None, None, 5, None, None], - } - ) +@pytest.mark.parametrize("dataset", ["floats", "dates"]) +def test_interpolate_by_trailing_nulls(dataset: str) -> None: + input_data = { + "dates": pl.DataFrame( + { + "times": [ + date(2020, 1, 1), + date(2020, 1, 3), + date(2020, 1, 10), + date(2020, 1, 11), + date(2020, 1, 12), + date(2020, 1, 13), + ], + "values": [1, None, None, 5, None, None], + } + ), + "floats": pl.DataFrame( + { + "times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1], + "values": [1, None, None, 5, None, None], + } + ), + } + + expected_data = { + "dates": pl.DataFrame( + {"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]} + ), + "floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}), + } + + df = input_data[dataset] + expected = expected_data[dataset] + result = df.select(pl.col("values").interpolate_by("times")) - expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}) + assert_frame_equal(result, expected) result = ( df.sort("times", descending=True) @@ -142,16 +164,28 @@ def test_interpolate_by_trailing_nulls() -> None: assert_frame_equal(result, expected) -@given(data=st.data()) -def test_interpolate_vs_numpy(data: st.DataObject) -> None: +@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64])) +def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None: + if x_dtype == pl.Float64: + by_strategy = st.floats( + min_value=-1e150, + max_value=1e150, + allow_nan=False, + allow_infinity=False, + allow_subnormal=False, + ) + else: + by_strategy = None + dataframe = ( data.draw( dataframes( [ column( "ts", - dtype=pl.Date, + dtype=x_dtype, allow_null=False, + strategy=by_strategy, ), column( "value", @@ -166,13 +200,24 @@ def test_interpolate_vs_numpy(data: st.DataObject) -> None: .fill_nan(None) .unique("ts") ) + + if x_dtype == pl.Float64: + assume(not dataframe["ts"].is_nan().any()) + assume(not dataframe["ts"].is_null().any()) + assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any()) + assume(not dataframe["value"].is_null().all()) assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any()) + + dataframe = dataframe.sort("ts") + result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"] mask = dataframe["value"].is_not_null() - x = dataframe["ts"].to_numpy().astype("int64") - xp = dataframe["ts"].filter(mask).to_numpy().astype("int64") + + np_dtype = "int64" if x_dtype == pl.Date else "float64" + x = dataframe["ts"].to_numpy().astype(np_dtype) + xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype) yp = dataframe["value"].filter(mask).to_numpy().astype("float64") interp = np.interp(x, xp, yp) # Polars preserves nulls on boundaries, but NumPy doesn't.