Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Allow float in interpolate_by by column #18015

Merged
merged 11 commits into from
Aug 18, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ fn interpolate_impl_by_sorted<T, F, I>(
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
I: Fn(T::Native, T::Native, &[F::Native], &mut Vec<T::Native>),
{
// This implementation differs from pandas as that boundary None's are not removed.
Expand Down Expand Up @@ -169,7 +169,7 @@ fn interpolate_impl_by<T, F, I>(
) -> PolarsResult<ChunkedArray<T>>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
I: Fn(T::Native, T::Native, &[F::Native], &mut [T::Native], &[IdxSize]),
{
// This implementation differs from pandas as that boundary None's are not removed.
Expand Down Expand Up @@ -273,7 +273,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
) -> PolarsResult<Series>
where
T: PolarsNumericType,
F: PolarsIntegerType,
F: PolarsNumericType,
ChunkedArray<T>: IntoSeries,
{
if is_sorted {
Expand All @@ -290,6 +290,18 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
}

match (s.dtype(), by.dtype()) {
(DataType::Float64, DataType::Float64) => {
func(s.f64().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Float32) => {
func(s.f64().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float64) => {
func(s.f32().unwrap(), by.f64().unwrap(), by_is_sorted)
},
(DataType::Float32, DataType::Float32) => {
func(s.f32().unwrap(), by.f32().unwrap(), by_is_sorted)
},
(DataType::Float64, DataType::Int64) => {
func(s.f64().unwrap(), by.i64().unwrap(), by_is_sorted)
},
Expand Down Expand Up @@ -326,7 +338,7 @@ pub fn interpolate_by(s: &Series, by: &Series, by_is_sorted: bool) -> PolarsResu
_ => {
polars_bail!(InvalidOperation: "expected series to be Float64, Float32, \
Int64, Int32, UInt64, UInt32, and `by` to be Date, Datetime, Int64, Int32, \
UInt64, or UInt32")
UInt64, UInt32, Float32 or Float64")
},
}
}
85 changes: 65 additions & 20 deletions py-polars/tests/unit/operations/test_interpolate_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
pl.Int32,
pl.UInt64,
pl.UInt32,
pl.Float32,
pl.Float64,
],
)
@pytest.mark.parametrize(
Expand Down Expand Up @@ -116,22 +118,42 @@ def test_interpolate_by_leading_nulls() -> None:
assert_frame_equal(result, expected)


def test_interpolate_by_trailing_nulls() -> None:
df = pl.DataFrame(
{
"times": [
date(2020, 1, 1),
date(2020, 1, 3),
date(2020, 1, 10),
date(2020, 1, 11),
date(2020, 1, 12),
date(2020, 1, 13),
],
"values": [1, None, None, 5, None, None],
}
)
@pytest.mark.parametrize("dataset", ["floats", "dates"])
def test_interpolate_by_trailing_nulls(dataset: str) -> None:
input_data = {
"dates": pl.DataFrame(
{
"times": [
date(2020, 1, 1),
date(2020, 1, 3),
date(2020, 1, 10),
date(2020, 1, 11),
date(2020, 1, 12),
date(2020, 1, 13),
],
"values": [1, None, None, 5, None, None],
}
),
"floats": pl.DataFrame(
{
"times": [0.2, 0.4, 0.5, 0.6, 0.9, 1.1],
"values": [1, None, None, 5, None, None],
}
),
}

expected_data = {
"dates": pl.DataFrame(
{"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]}
),
"floats": pl.DataFrame({"values": [1.0, 3.0, 4.0, 5.0, None, None]}),
}

df = input_data[dataset]
expected = expected_data[dataset]

result = df.select(pl.col("values").interpolate_by("times"))
expected = pl.DataFrame({"values": [1.0, 1.7999999999999998, 4.6, 5.0, None, None]})

assert_frame_equal(result, expected)
result = (
df.sort("times", descending=True)
Expand All @@ -142,16 +164,28 @@ def test_interpolate_by_trailing_nulls() -> None:
assert_frame_equal(result, expected)


@given(data=st.data())
def test_interpolate_vs_numpy(data: st.DataObject) -> None:
@given(data=st.data(), x_dtype=st.sampled_from([pl.Date, pl.Float64]))
def test_interpolate_vs_numpy(data: st.DataObject, x_dtype: pl.DataType) -> None:
if x_dtype == pl.Float64:
by_strategy = st.floats(
min_value=-1e150,
max_value=1e150,
allow_nan=False,
allow_infinity=False,
allow_subnormal=False,
)
else:
by_strategy = None

dataframe = (
data.draw(
dataframes(
[
column(
"ts",
dtype=pl.Date,
dtype=x_dtype,
allow_null=False,
strategy=by_strategy,
),
column(
"value",
Expand All @@ -166,13 +200,24 @@ def test_interpolate_vs_numpy(data: st.DataObject) -> None:
.fill_nan(None)
.unique("ts")
)

if x_dtype == pl.Float64:
assume(not dataframe["ts"].is_nan().any())
assume(not dataframe["ts"].is_null().any())
assume(not dataframe["ts"].is_in([float("-inf"), float("inf")]).any())

assume(not dataframe["value"].is_null().all())
assume(not dataframe["value"].is_in([float("-inf"), float("inf")]).any())

dataframe = dataframe.sort("ts")

result = dataframe.select(pl.col("value").interpolate_by("ts"))["value"]

mask = dataframe["value"].is_not_null()
x = dataframe["ts"].to_numpy().astype("int64")
xp = dataframe["ts"].filter(mask).to_numpy().astype("int64")

np_dtype = "int64" if x_dtype == pl.Date else "float64"
x = dataframe["ts"].to_numpy().astype(np_dtype)
xp = dataframe["ts"].filter(mask).to_numpy().astype(np_dtype)
yp = dataframe["value"].filter(mask).to_numpy().astype("float64")
interp = np.interp(x, xp, yp)
# Polars preserves nulls on boundaries, but NumPy doesn't.
Expand Down