From 8707b65abd91e0aefbb94c1af59191b771977a20 Mon Sep 17 00:00:00 2001 From: Weijie Guo Date: Tue, 26 Sep 2023 16:04:26 +0800 Subject: [PATCH] feat: Expressify list.shift --- .../src/chunked_array/list/iterator.rs | 41 +++++++++++++++++++ .../src/chunked_array/list/namespace.rs | 28 +++++++++++-- .../polars-plan/src/dsl/function_expr/list.rs | 9 ++++ .../polars-plan/src/dsl/function_expr/mod.rs | 1 + .../src/dsl/function_expr/schema.rs | 1 + crates/polars-plan/src/dsl/list.rs | 14 +++---- py-polars/polars/expr/list.py | 10 ++++- py-polars/polars/series/list.py | 9 +++- py-polars/src/expr/list.rs | 4 +- py-polars/tests/unit/namespaces/test_list.py | 26 ++++++++++++ 10 files changed, 127 insertions(+), 16 deletions(-) diff --git a/crates/polars-core/src/chunked_array/list/iterator.rs b/crates/polars-core/src/chunked_array/list/iterator.rs index 6bf9ed2f60d68..2dc2c7eb85590 100644 --- a/crates/polars-core/src/chunked_array/list/iterator.rs +++ b/crates/polars-core/src/chunked_array/list/iterator.rs @@ -188,6 +188,47 @@ impl ListChunked { unsafe { self.amortized_iter().for_each(f) } } + /// Zip with a `ChunkedArray` then apply a binary function `F` elementwise. + #[must_use] + pub fn zip_and_apply_amortized<'a, T, I, F>(&'a self, ca: &'a ChunkedArray, mut f: F) -> Self + where + T: PolarsDataType, + &'a ChunkedArray: IntoIterator, + I: TrustedLen>>, + F: FnMut(Option>, Option>) -> Option, + { + if self.is_empty() { + return self.clone(); + } + let mut fast_explode = self.null_count() == 0; + // SAFETY: unstable series never lives longer than the iterator. + let mut out: ListChunked = unsafe { + self.amortized_iter() + .zip(ca) + .map(|(opt_s, opt_v)| { + let out = f(opt_s, opt_v); + match out { + Some(out) if out.is_empty() => { + fast_explode = false; + Some(out) + }, + None => { + fast_explode = false; + out + }, + _ => out, + } + }) + .collect_trusted() + }; + + out.rename(self.name()); + if fast_explode { + out.set_fast_explode(); + } + out + } + /// Apply a closure `F` elementwise. #[must_use] pub fn apply_amortized<'a, F>(&'a self, mut f: F) -> Self diff --git a/crates/polars-ops/src/chunked_array/list/namespace.rs b/crates/polars-ops/src/chunked_array/list/namespace.rs index 663e710e19df6..4e91d6f4c0332 100644 --- a/crates/polars-ops/src/chunked_array/list/namespace.rs +++ b/crates/polars-ops/src/chunked_array/list/namespace.rs @@ -259,10 +259,32 @@ pub trait ListNameSpaceImpl: AsList { ca.try_apply_amortized(|s| s.as_ref().diff(n, null_behavior)) } - fn lst_shift(&self, periods: i64) -> ListChunked { + fn lst_shift(&self, periods: &Series) -> PolarsResult { let ca = self.as_list(); - let out = ca.apply_amortized(|s| s.as_ref().shift(periods)); - self.same_type(out) + + match periods.len() { + 1 => { + if let Some(periods) = periods.get(0)?.extract::() { + Ok(ca.apply_amortized(|s| s.as_ref().shift(periods))) + } else { + Ok(ListChunked::full_null_with_dtype( + ca.name(), + ca.len(), + &ca.inner_dtype(), + )) + } + }, + _ => { + let periods_s = periods.cast(&DataType::Int64)?; + let periods = periods_s.i64()?; + Ok(ca.zip_and_apply_amortized(periods, |opt_s, opt_periods| { + match (opt_s, opt_periods) { + (Some(s), Some(periods)) => Some(s.as_ref().shift(periods)), + _ => None, + } + })) + }, + } } fn lst_slice(&self, offset: i64, length: usize) -> ListChunked { diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index fdcdef9f039ff..f4b89c36e43af 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -12,6 +12,7 @@ pub enum ListFunction { #[cfg(feature = "list_drop_nulls")] DropNulls, Slice, + Shift, Get, #[cfg(feature = "list_take")] Take(bool), @@ -45,6 +46,7 @@ impl Display for ListFunction { #[cfg(feature = "list_drop_nulls")] DropNulls => "drop_nulls", Slice => "slice", + Shift => "shift", Get => "get", #[cfg(feature = "list_take")] Take(_) => "take", @@ -104,6 +106,13 @@ fn check_slice_arg_shape(slice_len: usize, ca_len: usize, name: &str) -> PolarsR Ok(()) } +pub(super) fn shift(s: &[Series]) -> PolarsResult { + let list = s[0].list()?; + let periods = &s[1]; + + list.lst_shift(periods).map(|ok| ok.into_series()) +} + pub(super) fn slice(args: &mut [Series]) -> PolarsResult> { let s = &args[0]; let list_ca = s.list()?; diff --git a/crates/polars-plan/src/dsl/function_expr/mod.rs b/crates/polars-plan/src/dsl/function_expr/mod.rs index f6b668442b161..f6674b1a55197 100644 --- a/crates/polars-plan/src/dsl/function_expr/mod.rs +++ b/crates/polars-plan/src/dsl/function_expr/mod.rs @@ -555,6 +555,7 @@ impl From for SpecialEq> { #[cfg(feature = "list_drop_nulls")] DropNulls => map!(list::drop_nulls), Slice => wrap!(list::slice), + Shift => map_as_slice!(list::shift), Get => wrap!(list::get), #[cfg(feature = "list_take")] Take(null_ob_oob) => map_as_slice!(list::take, null_ob_oob), diff --git a/crates/polars-plan/src/dsl/function_expr/schema.rs b/crates/polars-plan/src/dsl/function_expr/schema.rs index ecdcc3b1ce496..063531a9c634a 100644 --- a/crates/polars-plan/src/dsl/function_expr/schema.rs +++ b/crates/polars-plan/src/dsl/function_expr/schema.rs @@ -103,6 +103,7 @@ impl FunctionExpr { #[cfg(feature = "list_drop_nulls")] DropNulls => mapper.with_same_dtype(), Slice => mapper.with_same_dtype(), + Shift => mapper.with_same_dtype(), Get => mapper.map_to_list_inner_dtype(), #[cfg(feature = "list_take")] Take(_) => mapper.with_same_dtype(), diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index d3186206a4c89..86261e43a82db 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -167,13 +167,13 @@ impl ListNameSpace { } /// Shift every sublist. - pub fn shift(self, periods: i64) -> Expr { - self.0 - .map( - move |s| Ok(Some(s.list()?.lst_shift(periods).into_series())), - GetOutput::same_type(), - ) - .with_fmt("list.shift") + pub fn shift(self, periods: Expr) -> Expr { + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::Shift), + &[periods], + false, + false, + ) } /// Slice every sublist. diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index c726d8996d923..2ea3d1291ed68 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -13,7 +13,12 @@ from datetime import date, datetime, time from polars import Expr, Series - from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy + from polars.type_aliases import ( + IntoExpr, + IntoExprColumn, + NullBehavior, + ToStructStrategy, + ) class ExprListNameSpace: @@ -642,7 +647,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Expr: """ return wrap_expr(self._pyexpr.list_diff(n, null_behavior)) - def shift(self, periods: int = 1) -> Expr: + def shift(self, periods: int | IntoExprColumn = 1) -> Expr: """ Shift values by the given period. @@ -663,6 +668,7 @@ def shift(self, periods: int = 1) -> Expr: ] """ + periods = parse_as_expression(periods) return wrap_expr(self._pyexpr.list_shift(periods)) def slice( diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 184672aa8dac7..8f607ee1e42ea 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -12,7 +12,12 @@ from polars import Expr, Series from polars.polars import PySeries - from polars.type_aliases import IntoExpr, NullBehavior, ToStructStrategy + from polars.type_aliases import ( + IntoExpr, + IntoExprColumn, + NullBehavior, + ToStructStrategy, + ) @expr_dispatch @@ -333,7 +338,7 @@ def diff(self, n: int = 1, null_behavior: NullBehavior = "ignore") -> Series: """ - def shift(self, periods: int = 1) -> Series: + def shift(self, periods: int | IntoExprColumn = 1) -> Series: """ Shift values by the given period. diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 7653f877f80f1..accd5529cf0bc 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -78,8 +78,8 @@ impl PyExpr { self.inner.clone().list().reverse().into() } - fn list_shift(&self, periods: i64) -> Self { - self.inner.clone().list().shift(periods).into() + fn list_shift(&self, periods: PyExpr) -> Self { + self.inner.clone().list().shift(periods.inner).into() } fn list_slice(&self, offset: PyExpr, length: Option) -> Self { diff --git a/py-polars/tests/unit/namespaces/test_list.py b/py-polars/tests/unit/namespaces/test_list.py index cfc2cfde15233..9eb39b89353f1 100644 --- a/py-polars/tests/unit/namespaces/test_list.py +++ b/py-polars/tests/unit/namespaces/test_list.py @@ -141,6 +141,32 @@ def test_list_shift() -> None: expected = pl.Series("a", [[None, 1], [None, 3, 2]]) assert s.list.shift().to_list() == expected.to_list() + df = pl.DataFrame( + { + "values": [ + [1, 2, None], + [1, 2, 3], + [None, 1, 2], + [None, None, None], + [1, 2], + ], + "shift": [1, -2, 3, 2, None], + } + ) + df = df.select(pl.col("values").list.shift(pl.col("shift"))) + expected_df = pl.DataFrame( + { + "values": [ + [None, 1, 2], + [3, None, None], + [None, None, None], + [None, None, None], + None, + ] + } + ) + assert_frame_equal(df, expected_df) + def test_list_drop_nulls() -> None: s = pl.Series("values", [[1, None, 2, None], [None, None], [1, 2], None])