Skip to content

Commit

Permalink
feat: Expressify list.shift
Browse files Browse the repository at this point in the history
  • Loading branch information
reswqa committed Sep 26, 2023
1 parent 11c23b0 commit 8707b65
Show file tree
Hide file tree
Showing 10 changed files with 127 additions and 16 deletions.
41 changes: 41 additions & 0 deletions crates/polars-core/src/chunked_array/list/iterator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,47 @@ impl ListChunked {
unsafe { self.amortized_iter().for_each(f) }
}

/// Pair every sublist of `self` with the value at the same position in `ca` and collect
/// the results of calling `f` on each pair into a new `ListChunked`.
///
/// `f` receives the optional sublist and the optional value and returns the optional
/// replacement sublist. The two inputs are combined with `zip`, so iteration stops at the
/// shorter one. When `self` is empty it is returned as a clone and `f` is never called.
///
/// The result is renamed to the name of `self`. `set_fast_explode` is called on it only
/// when `self` has no nulls and every call to `f` returned a non-empty `Some`.
#[must_use]
pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray<T>, mut f: F) -> Self
where
    T: PolarsDataType,
    &'a ChunkedArray<T>: IntoIterator<IntoIter = I>,
    I: TrustedLen<Item = Option<T::Physical<'a>>>,
    F: FnMut(Option<UnstableSeries<'a>>, Option<T::Physical<'a>>) -> Option<Series>,
{
    if self.is_empty() {
        return self.clone();
    }

    // Remains `true` only while every produced entry is a non-empty `Some`.
    let mut can_fast_explode = self.null_count() == 0;

    // SAFETY: unstable series never lives longer than the iterator.
    let mut result: ListChunked = unsafe {
        self.amortized_iter()
            .zip(ca)
            .map(|(opt_sublist, opt_value)| {
                let produced = f(opt_sublist, opt_value);
                // A `None` or an empty series both rule out the fast path.
                let is_non_empty = produced.as_ref().map_or(false, |s| !s.is_empty());
                if !is_non_empty {
                    can_fast_explode = false;
                }
                produced
            })
            .collect_trusted()
    };

    result.rename(self.name());
    if can_fast_explode {
        result.set_fast_explode();
    }
    result
}

/// Apply a closure `F` elementwise.
#[must_use]
pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self
Expand Down
28 changes: 25 additions & 3 deletions crates/polars-ops/src/chunked_array/list/namespace.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,32 @@ pub trait ListNameSpaceImpl: AsList {
ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior))
}

fn lst_shift(&self, periods: i64) -> ListChunked {
fn lst_shift(&self, periods: &Series) -> PolarsResult<ListChunked> {
let ca = self.as_list();
let out = ca.apply_amortized(|s| s.as_ref().shift(periods));
self.same_type(out)

match periods.len() {
1 => {
if let Some(periods) = periods.get(0)?.extract::<i64>() {
Ok(ca.apply_amortized(|s| s.as_ref().shift(periods)))
} else {
Ok(ListChunked::full_null_with_dtype(
ca.name(),
ca.len(),
&ca.inner_dtype(),
))
}
},
_ => {
let periods_s = periods.cast(&DataType::Int64)?;
let periods = periods_s.i64()?;
Ok(ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| {
match (opt_s, opt_periods) {
(Some(s), Some(periods)) => Some(s.as_ref().shift(periods)),
_ => None,
}
}))
},
}
}

fn lst_slice(&self, offset: i64, length: usize) -> ListChunked {
Expand Down
9 changes: 9 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ pub enum ListFunction {
#[cfg(feature = "list_drop_nulls")]
DropNulls,
Slice,
Shift,
Get,
#[cfg(feature = "list_take")]
Take(bool),
Expand Down Expand Up @@ -45,6 +46,7 @@ impl Display for ListFunction {
#[cfg(feature = "list_drop_nulls")]
DropNulls => "drop_nulls",
Slice => "slice",
Shift => "shift",
Get => "get",
#[cfg(feature = "list_take")]
Take(_) => "take",
Expand Down Expand Up @@ -104,6 +106,13 @@ fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsR
Ok(())
}

/// Entry point for `ListFunction::Shift`.
///
/// Expects `s[0]` to be the list column and `s[1]` to hold the shift periods; indexing
/// panics if fewer than two series are supplied.
///
/// # Errors
/// Returns an error when `s[0]` is not a list column or when `lst_shift` fails.
pub(super) fn shift(s: &[Series]) -> PolarsResult<Series> {
    let shifted = s[0].list()?.lst_shift(&s[1])?;
    Ok(shifted.into_series())
}

pub(super) fn slice(args: &mut [Series]) -> PolarsResult<Option<Series>> {
let s = &args[0];
let list_ca = s.list()?;
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,7 @@ impl From<FunctionExpr> for SpecialEq<Arc<dyn SeriesUdf>> {
#[cfg(feature = "list_drop_nulls")]
DropNulls => map!(list::drop_nulls),
Slice => wrap!(list::slice),
Shift => map_as_slice!(list::shift),
Get => wrap!(list::get),
#[cfg(feature = "list_take")]
Take(null_ob_oob) => map_as_slice!(list::take, null_ob_oob),
Expand Down
1 change: 1 addition & 0 deletions crates/polars-plan/src/dsl/function_expr/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ impl FunctionExpr {
#[cfg(feature = "list_drop_nulls")]
DropNulls => mapper.with_same_dtype(),
Slice => mapper.with_same_dtype(),
Shift => mapper.with_same_dtype(),
Get => mapper.map_to_list_inner_dtype(),
#[cfg(feature = "list_take")]
Take(_) => mapper.with_same_dtype(),
Expand Down
14 changes: 7 additions & 7 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ impl ListNameSpace {
}

/// Shift every sublist.
pub fn shift(self, periods: i64) -> Expr {
self.0
.map(
move |s| Ok(Some(s.list()?.lst_shift(periods).into_series())),
GetOutput::same_type(),
)
.with_fmt("list.shift")
pub fn shift(self, periods: Expr) -> Expr {
self.0.map_many_private(
FunctionExpr::ListExpr(ListFunction::Shift),
&[periods],
false,
false,
)
}

/// Slice every sublist.
Expand Down
10 changes: 8 additions & 2 deletions py-polars/polars/expr/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,12 @@
from datetime import date, datetime, time

from polars import Expr, Series
from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy
from polars.type_aliases import (
IntoExpr,
IntoExprColumn,
NullBehavior,
ToStructStrategy,
)


class ExprListNameSpace:
Expand Down Expand Up @@ -642,7 +647,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Expr:
"""
return wrap_expr(self._pyexpr.list_diff(n, null_behavior))

def shift(self, periods: int = 1) -> Expr:
def shift(self, periods: int | IntoExprColumn = 1) -> Expr:
"""
Shift values by the given period.
Expand All @@ -663,6 +668,7 @@ def shift(self, periods: int = 1) -> Expr:
]
"""
periods = parse_as_expression(periods)
return wrap_expr(self._pyexpr.list_shift(periods))

def slice(
Expand Down
9 changes: 7 additions & 2 deletions py-polars/polars/series/list.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from polars import Expr, Series
from polars.polars import PySeries
from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy
from polars.type_aliases import (
IntoExpr,
IntoExprColumn,
NullBehavior,
ToStructStrategy,
)


@expr_dispatch
Expand Down Expand Up @@ -333,7 +338,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series:
"""

def shift(self, periods: int = 1) -> Series:
def shift(self, periods: int | IntoExprColumn = 1) -> Series:
"""
Shift values by the given period.
Expand Down
4 changes: 2 additions & 2 deletions py-polars/src/expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,8 @@ impl PyExpr {
self.inner.clone().list().reverse().into()
}

fn list_shift(&self, periods: i64) -> Self {
self.inner.clone().list().shift(periods).into()
fn list_shift(&self, periods: PyExpr) -> Self {
self.inner.clone().list().shift(periods.inner).into()
}

fn list_slice(&self, offset: PyExpr, length: Option<PyExpr>) -> Self {
Expand Down
26 changes: 26 additions & 0 deletions py-polars/tests/unit/namespaces/test_list.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,32 @@ def test_list_shift() -> None:
expected = pl.Series("a", [[None, 1], [None, 3, 2]])
assert s.list.shift().to_list() == expected.to_list()

df = pl.DataFrame(
{
"values": [
[1, 2, None],
[1, 2, 3],
[None, 1, 2],
[None, None, None],
[1, 2],
],
"shift": [1, -2, 3, 2, None],
}
)
df = df.select(pl.col("values").list.shift(pl.col("shift")))
expected_df = pl.DataFrame(
{
"values": [
[None, 1, 2],
[3, None, None],
[None, None, None],
[None, None, None],
None,
]
}
)
assert_frame_equal(df, expected_df)


def test_list_drop_nulls() -> None:
s = pl.Series("values", [[1, None, 2, None], [None, None], [1, 2], None])
Expand Down

0 comments on commit 8707b65

Please sign in to comment.