feat: Expressify list.shift #11320

Merged (2 commits) on Sep 27, 2023
41 changes: 41 additions & 0 deletions crates/polars-core/src/chunked_array/list/iterator.rs
@@ -188,6 +188,47 @@ impl ListChunked {
         unsafe { self.amortized_iter().for_each(f) }
     }

+    /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise.
+    #[must_use]
+    pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray<T>, mut f: F) -> Self
@reswqa (Collaborator, Author) commented on Sep 26, 2023:

I don't know what name would be more appropriate, but I assume that having a safe version would be good.

+    where
+        T: PolarsDataType,
+        &'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
+        I: TrustedLen<Item = Option<T::Physical<'a>>>,
+        F: FnMut(Option<UnstableSeries<'a>>, Option<T::Physical<'a>>) -> Option<Series>,
+    {
+        if self.is_empty() {
+            return self.clone();
+        }
+        let mut fast_explode = self.null_count() == 0;
+        // SAFETY: unstable series never lives longer than the iterator.
+        let mut out: ListChunked = unsafe {
+            self.amortized_iter()
+                .zip(ca)
+                .map(|(opt_s, opt_v)| {
+                    let out = f(opt_s, opt_v);
+                    match out {
+                        Some(out) if out.is_empty() => {
+                            fast_explode = false;
+                            Some(out)
+                        },
+                        None => {
+                            fast_explode = false;
+                            out
+                        },
+                        _ => out,
+                    }
+                })
+                .collect_trusted()
Member commented:

I think we should make this a bit more resilient to ambiguous dtypes. We can use `collect_ca_trusted_with_dtype`, accept the list's inner type as an argument of this function, and construct the `DataType::List<inner>` from it.

In a later PR we can do the same with `apply_amortized`.

@reswqa (Collaborator, Author) commented on Sep 27, 2023:

pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray<T>, mut f: F) -> Self

This function returns `Self`, so maybe we don't need to accept an inner type as an argument and can pass `self.dtype().clone()` to `collect_ca_trusted_with_dtype` directly. What do you think?

> In a later PR we can do the same with apply_amortized.

I would prefer to do this for `apply_amortized` in this PR as well. After that, I think we can safely remove the call to `Ok(self.same_type(out))` in `lst_shift`.

Member commented:

> This function will return Self, maybe we don't need accept an inner type as argument, but pass self.dtype().clone() to collect_ca_trusted_with_dtype directly? What do you think?

Yes, that would be sufficient.

@reswqa (Collaborator, Author) commented on Sep 27, 2023:

Bad news: it seems that our implementation of `collect_ca_trusted_with_dtype` forwards to `collect_arr_trusted_with_dtype` and just wraps the resulting single `Array` in a `ChunkedArray`.

So it requires an `ArrayFromIterDtype` implementation for `ListArray` over `Option<Series>` or `Series` (because the closure of `apply_amortized` is `F: FnMut(UnstableSeries<'a>) -> Series`), but we don't have one. Perhaps we don't want to implement it either; after all, a `Series` may contain multiple arrays if it is not rechunked.

Member commented:

Ok, let's leave it for now. Maybe we can take a look later. Want to have this in the release. :)
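
Note: one way to read the "ambiguous dtypes" concern in this thread is that a list column collected from results alone has nothing to infer its inner type from when every result is null or empty. The Python sketch below illustrates that reading only. It is not how Polars infers dtypes, and `infer_inner_dtype` and `collect_with_dtype` are hypothetical names.

```python
from typing import List, Optional


def infer_inner_dtype(rows: List[Optional[list]]) -> Optional[str]:
    """Guess the inner dtype from the first non-null element, if there is one."""
    for row in rows:
        for value in row or []:
            if value is not None:
                return type(value).__name__
    # Every row was null, empty, or all-null: nothing to infer from.
    return None


def collect_with_dtype(rows: List[Optional[list]], inner_dtype: str) -> dict:
    """Hypothetical counterpart: the caller supplies the inner dtype up front."""
    return {"dtype": f"list[{inner_dtype}]", "rows": rows}


ambiguous = [[None], None, []]
print(infer_inner_dtype(ambiguous))
print(collect_with_dtype(ambiguous, "int")["dtype"])
```

Passing the dtype of the input column through, as suggested in the thread, sidesteps the inference step entirely.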

+        };
+
+        out.rename(self.name());
+        if fast_explode {
+            out.set_fast_explode();
+        }
+        out
+    }
+
     /// Apply a closure `F` elementwise.
     #[must_use]
     pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
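
Note: the flag bookkeeping in `zip_and_apply_amortized` can be modelled without Polars. The following is a rough Python sketch, not the Polars implementation: plain lists stand in for `Series`, `None` for nulls, and the name `zip_and_apply` and its `(rows, fast_explode)` return value are assumptions made for illustration. The early return for empty input is left out.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Row = Optional[list]


def zip_and_apply(
    rows: Sequence[Row],
    other: Sequence[Optional[int]],
    f: Callable[[Row, Optional[int]], Row],
) -> Tuple[List[Row], bool]:
    """Apply `f` pairwise and track the fast-explode flag as the diff does."""
    # The flag starts as "the input has no null rows" ...
    fast_explode = all(row is not None for row in rows)
    out: List[Row] = []
    for row, value in zip(rows, other):
        result = f(row, value)
        # ... and is cleared as soon as a result is null or an empty sublist.
        if result is None or len(result) == 0:
            fast_explode = False
        out.append(result)
    return out, fast_explode


def head(row: Row, n: Optional[int]) -> Row:
    """Example binary function: keep the first `n` items, null if either side is null."""
    if row is None or n is None:
        return None
    return row[:n]


kept, flag = zip_and_apply([[1, 2], [3]], [1, 1], head)
print(kept, flag)
```

In this model the flag only ever moves from true to false, so a single null or empty result is enough to clear it for the whole output.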
22 changes: 19 additions & 3 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
@@ -259,10 +259,26 @@ pub trait ListNameSpaceImpl: AsList {
         ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior))
     }

-    fn lst_shift(&self, periods: i64) -> ListChunked {
+    fn lst_shift(&self, periods: &Series) -> PolarsResult<ListChunked> {
         let ca = self.as_list();
-        let out = ca.apply_amortized(|s| s.as_ref().shift(periods));
-        self.same_type(out)
+        let periods_s = periods.cast(&DataType::Int64)?;
+        let periods = periods_s.i64()?;
+        let out = match periods.len() {
+            1 => {
+                if let Some(periods) = periods.get(0) {
+                    ca.apply_amortized(|s| s.as_ref().shift(periods))
+                } else {
+                    ListChunked::full_null_with_dtype(ca.name(), ca.len(), &ca.inner_dtype())
+                }
+            },
+            _ => ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| {
+                match (opt_s, opt_periods) {
+                    (Some(s), Some(periods)) => Some(s.as_ref().shift(periods)),
+                    _ => None,
+                }
+            }),
+        };
+        Ok(self.same_type(out))
     }

     fn lst_slice(&self, offset: i64, length: usize) -> ListChunked {
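
Note: the dispatch in the new `lst_shift` (a single period is broadcast to every row, otherwise there is one period per row) can be sketched in plain Python. This is a model under stated assumptions, not Polars code: lists stand in for `Series`, `None` for nulls, and `shift_row` is a hypothetical helper that keeps the sublist length and fills vacated slots with `None`.

```python
from typing import List, Optional

Row = Optional[List[Optional[int]]]


def shift_row(row: List[Optional[int]], periods: int) -> List[Optional[int]]:
    """Shift one sublist; positive periods move values towards the end."""
    n = len(row)
    if periods == 0 or n == 0:
        return list(row)
    if abs(periods) >= n:
        return [None] * n
    if periods > 0:
        return [None] * periods + row[: n - periods]
    return row[-periods:] + [None] * (-periods)


def lst_shift(rows: List[Row], periods: List[Optional[int]]) -> List[Row]:
    """Model of the match on `periods.len()` in the diff."""
    if len(periods) == 1:
        p = periods[0]
        if p is None:
            # A null scalar period produces an all-null column.
            return [None] * len(rows)
        return [None if row is None else shift_row(row, p) for row in rows]
    # Per-row periods: a null on either side gives a null row.
    return [
        None if row is None or p is None else shift_row(row, p)
        for row, p in zip(rows, periods)
    ]


values = [[1, 2, None], [1, 2, 3], [None, 1, 2], [None, None, None], [1, 2]]
print(lst_shift(values, [1, -2, 3, 2, None]))
```

The `values` and period inputs are the ones used in the unit test added by this PR, so the printed rows can be compared against that test's `expected_df`.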
9 changes: 9 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
@@ -12,6 +12,7 @@ pub enum ListFunction {
     #[cfg(feature = "list_drop_nulls")]
     DropNulls,
     Slice,
+    Shift,
     Get,
     #[cfg(feature = "list_take")]
     Take(bool),
@@ -45,6 +46,7 @@ impl Display for ListFunction {
             #[cfg(feature = "list_drop_nulls")]
             DropNulls => "drop_nulls",
             Slice => "slice",
+            Shift => "shift",
             Get => "get",
             #[cfg(feature = "list_take")]
             Take(_) => "take",
@@ -104,6 +106,13 @@ fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsR
     Ok(())
 }

+pub(super) fn shift(s: &[Series]) -> PolarsResult<Series> {
+    let list = s[0].list()?;
+    let periods = &s[1];
+
+    list.lst_shift(periods).map(|ok| ok.into_series())
+}
+
 pub(super) fn slice(args: &mut [Series]) -> PolarsResult<Option<Series>> {
     let s = &args[0];
     let list_ca = s.list()?;
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
@@ -555,6 +555,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
                     #[cfg(feature = "list_drop_nulls")]
                     DropNulls => map!(list::drop_nulls),
                     Slice => wrap!(list::slice),
+                    Shift => map_as_slice!(list::shift),
                     Get => wrap!(list::get),
                     #[cfg(feature = "list_take")]
                     Take(null_ob_oob) => map_as_slice!(list::take, null_ob_oob),
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
@@ -103,6 +103,7 @@ impl FunctionExpr {
                     #[cfg(feature = "list_drop_nulls")]
                     DropNulls => mapper.with_same_dtype(),
                     Slice => mapper.with_same_dtype(),
+                    Shift => mapper.with_same_dtype(),
                     Get => mapper.map_to_list_inner_dtype(),
                     #[cfg(feature = "list_take")]
                     Take(_) => mapper.with_same_dtype(),
14 changes: 7 additions & 7 deletions crates/polars-plan/src/dsl/list.rs
@@ -167,13 +167,13 @@ impl ListNameSpace {
     }

     /// Shift every sublist.
-    pub fn shift(self, periods: i64) -> Expr {
-        self.0
-            .map(
-                move |s| Ok(Some(s.list()?.lst_shift(periods).into_series())),
-                GetOutput::same_type(),
-            )
-            .with_fmt("list.shift")
+    pub fn shift(self, periods: Expr) -> Expr {
+        self.0.map_many_private(
+            FunctionExpr::ListExpr(ListFunction::Shift),
+            &[periods],
+            false,
+            false,
+        )
     }

     /// Slice every sublist.
10 changes: 8 additions & 2 deletions py-polars/polars/expr/list.py
@@ -13,7 +13,12 @@
     from datetime import date, datetime, time

     from polars import Expr, Series
-    from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy
+    from polars.type_aliases import (
+        IntoExpr,
+        IntoExprColumn,
+        NullBehavior,
+        ToStructStrategy,
+    )


 class ExprListNameSpace:
@@ -642,7 +647,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Expr:
         """
         return wrap_expr(self._pyexpr.list_diff(n, null_behavior))

-    def shift(self, periods: int = 1) -> Expr:
+    def shift(self, periods: int | IntoExprColumn = 1) -> Expr:
         """
         Shift values by the given period.

@@ -663,6 +668,7 @@ def shift(self, periods: int = 1) -> Expr:
         ]

         """
+        periods = parse_as_expression(periods)
         return wrap_expr(self._pyexpr.list_shift(periods))

     def slice(
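
Note: with this change `periods` may be an integer or a column-like expression. A toy Python model of that normalisation step is sketched here. `Col`, `Lit` and `to_periods_expr` are hypothetical stand-ins, not Polars classes, and the rule that a bare string names a column is an assumption about how `parse_as_expression` treats `IntoExprColumn` inputs.

```python
from dataclasses import dataclass
from typing import Union


@dataclass(frozen=True)
class Col:
    """Stand-in for a column reference such as pl.col("shift")."""
    name: str


@dataclass(frozen=True)
class Lit:
    """Stand-in for a literal expression wrapping a plain integer."""
    value: int


ExprLike = Union[Col, Lit]


def to_periods_expr(periods: Union[int, str, Col, Lit] = 1) -> ExprLike:
    """Normalise the accepted input kinds to a single expression type."""
    if isinstance(periods, (Col, Lit)):
        return periods  # already an expression: pass through unchanged
    if isinstance(periods, str):
        return Col(periods)  # assumption: strings are column names
    return Lit(periods)  # plain integers become literals


print(to_periods_expr())
print(to_periods_expr("shift"))
```

Normalising at the Python boundary means the Rust binding only has to accept one argument type, which matches `list_shift` taking a `PyExpr` in this PR.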
9 changes: 7 additions & 2 deletions py-polars/polars/series/list.py
@@ -12,7 +12,12 @@

     from polars import Expr, Series
     from polars.polars import PySeries
-    from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy
+    from polars.type_aliases import (
+        IntoExpr,
+        IntoExprColumn,
+        NullBehavior,
+        ToStructStrategy,
+    )


 @expr_dispatch
@@ -333,7 +338,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series:

         """

-    def shift(self, periods: int = 1) -> Series:
+    def shift(self, periods: int | IntoExprColumn = 1) -> Series:
         """
         Shift values by the given period.

4 changes: 2 additions & 2 deletions py-polars/src/expr/list.rs
@@ -78,8 +78,8 @@ impl PyExpr {
         self.inner.clone().list().reverse().into()
     }

-    fn list_shift(&self, periods: i64) -> Self {
-        self.inner.clone().list().shift(periods).into()
+    fn list_shift(&self, periods: PyExpr) -> Self {
+        self.inner.clone().list().shift(periods.inner).into()
     }

     fn list_slice(&self, offset: PyExpr, length: Option<PyExpr>) -> Self {
26 changes: 26 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
@@ -141,6 +141,32 @@ def test_list_shift() -> None:
     expected = pl.Series("a", [[None, 1], [None, 3, 2]])
     assert s.list.shift().to_list() == expected.to_list()

+    df = pl.DataFrame(
+        {
+            "values": [
+                [1, 2, None],
+                [1, 2, 3],
+                [None, 1, 2],
+                [None, None, None],
+                [1, 2],
+            ],
+            "shift": [1, -2, 3, 2, None],
+        }
+    )
+    df = df.select(pl.col("values").list.shift(pl.col("shift")))
+    expected_df = pl.DataFrame(
+        {
+            "values": [
+                [None, 1, 2],
+                [3, None, None],
+                [None, None, None],
+                [None, None, None],
+                None,
+            ]
+        }
+    )
+    assert_frame_equal(df, expected_df)
+


 def test_list_drop_nulls() -> None:
     s = pl.Series("values", [[1, None, 2, None], [None, None], [1, 2], None])