From 8232818fb96e22ece49dc26a34e6b971b649d1f0 Mon Sep 17 00:00:00 2001 From: Object905 Date: Wed, 14 Feb 2024 22:50:12 +0500 Subject: [PATCH 1/2] wip: list boolean set operations --- .../polars-ops/src/chunked_array/list/sets.rs | 147 +++++++++++++----- .../polars-plan/src/dsl/function_expr/list.rs | 8 +- crates/polars-plan/src/dsl/list.rs | 7 + py-polars/polars/expr/list.py | 59 +++++++ py-polars/src/conversion/mod.rs | 3 +- py-polars/src/expr/list.rs | 1 + 6 files changed, 188 insertions(+), 37 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/sets.rs b/crates/polars-ops/src/chunked_array/list/sets.rs index 4a6f1f0466b4..d8c52c84dc42 100644 --- a/crates/polars-ops/src/chunked_array/list/sets.rs +++ b/crates/polars-ops/src/chunked_array/list/sets.rs @@ -2,13 +2,14 @@ use std::fmt::{Display, Formatter}; use std::hash::Hash; use arrow::array::{ - Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray, - PrimitiveArray, Utf8ViewArray, + Array, BinaryViewArray, BooleanArray, ListArray, MutableArray, MutableBooleanArray, + MutablePlBinary, MutablePrimitiveArray, PrimitiveArray, Utf8ViewArray, }; use arrow::bitmap::Bitmap; use arrow::compute::utils::combine_validities_and; use arrow::offset::OffsetsBuffer; use arrow::types::NativeType; +use either::Either; use polars_core::prelude::*; use polars_core::with_match_physical_numeric_type; use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap}; @@ -53,6 +54,7 @@ fn set_operation( a: I, b: J, out: &mut R, + bool_out: &mut MutableBooleanArray, set_op: SetOperation, broadcast_rhs: bool, ) -> usize @@ -99,6 +101,15 @@ where set.extend(a); out.extend_buf(set.symmetric_difference(set2).copied()) }, + SetOperation::IsDisjoint => { + // If broadcast `set2` should already be filled. + if !broadcast_rhs { + set2.clear(); + set2.extend(b); + } + bool_out.push(Some(!a.into_iter().any(|val| set2.contains(&val)))); + bool_out.len() + }, } } @@ -115,6 +126,7 @@ pub enum SetOperation { Union, Difference, SymmetricDifference, + IsDisjoint, } impl Display for SetOperation { @@ -124,11 +136,21 @@ impl Display for SetOperation { SetOperation::Union => "union", SetOperation::Difference => "difference", SetOperation::SymmetricDifference => "symmetric_difference", + SetOperation::IsDisjoint => "is_disjoint", }; write!(f, "{s}") } } +impl SetOperation { + fn is_boolean(&self) -> bool { + match self { + SetOperation::IsDisjoint => true, + _ => false, + } + } +} + fn primitive( a: &PrimitiveArray, b: &PrimitiveArray, @@ -136,7 +158,7 @@ fn primitive( offsets_b: &[i64], set_op: SetOperation, validity: Option, -) -> PolarsResult> +) -> PolarsResult, BooleanArray>> where T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd, as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy, @@ -147,10 +169,18 @@ where let mut set = Default::default(); let mut set2: PlIndexSet< as ToTotalOrd>::TotalOrdItem> = Default::default(); - let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max( - *offsets_a.last().unwrap(), - *offsets_b.last().unwrap(), - ) as usize); + let needed_capacity = + std::cmp::max(*offsets_a.last().unwrap(), *offsets_b.last().unwrap()) as usize; + let mut values_out; + let mut bool_values_out; + if set_op.is_boolean() { + values_out = MutablePrimitiveArray::new(); + bool_values_out = MutableBooleanArray::with_capacity(needed_capacity); + } else { + values_out = MutablePrimitiveArray::with_capacity(needed_capacity); + bool_values_out = MutableBooleanArray::new(); + } + let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len())); offsets.push(0i64); @@ -199,6 +229,7 @@ where a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, true, ) @@ -221,6 +252,7 @@ where a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, false, ) @@ -243,6 +275,7 @@ where a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, false, ) @@ -250,11 +283,21 @@ where offsets.push(offset as i64); } - let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; - let dtype = ListArray::::default_datatype(values_out.data_type().clone()); - let values: PrimitiveArray = values_out.into(); - Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) + if set_op.is_boolean() { + Ok(Either::Right(bool_values_out.into())) + } else { + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + let dtype = ListArray::::default_datatype(values_out.data_type().clone()); + + let values: PrimitiveArray = values_out.into(); + Ok(Either::Left(ListArray::new( + dtype, + offsets, + values.boxed(), + validity, + ))) + } } fn binary( @@ -265,16 +308,24 @@ fn binary( set_op: SetOperation, validity: Option, as_utf8: bool, -) -> PolarsResult> { +) -> PolarsResult, BooleanArray>> { let broadcast_lhs = offsets_a.len() == 2; let broadcast_rhs = offsets_b.len() == 2; let mut set = Default::default(); let mut set2: PlIndexSet> = Default::default(); - let mut values_out = MutablePlBinary::with_capacity(std::cmp::max( - *offsets_a.last().unwrap(), - *offsets_b.last().unwrap(), - ) as usize); + let needed_capacity = + std::cmp::max(*offsets_a.last().unwrap(), *offsets_b.last().unwrap()) as usize; + let mut values_out; + let mut bool_values_out; + if set_op.is_boolean() { + values_out = MutablePlBinary::new(); + bool_values_out = MutableBooleanArray::with_capacity(needed_capacity); + } else { + values_out = MutablePlBinary::with_capacity(needed_capacity); + bool_values_out = MutableBooleanArray::new(); + } + let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len())); offsets.push(0i64); @@ -310,6 +361,7 @@ fn binary( a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, true, ) @@ -322,6 +374,7 @@ fn binary( a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, false, ) @@ -335,22 +388,38 @@ fn binary( a_iter, b_iter, &mut values_out, + &mut bool_values_out, set_op, false, ) }; offsets.push(offset as i64); } - let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; - let values = values_out.freeze(); - if as_utf8 { - let values = unsafe { values.to_utf8view_unchecked() }; - let dtype = ListArray::::default_datatype(values.data_type().clone()); - Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) + if set_op.is_boolean() { + Ok(Either::Right(bool_values_out.into())) } else { - let dtype = ListArray::::default_datatype(values.data_type().clone()); - Ok(ListArray::new(dtype, offsets, values.boxed(), validity)) + let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) }; + let values = values_out.freeze(); + + if as_utf8 { + let values = unsafe { values.to_utf8view_unchecked() }; + let dtype = ListArray::::default_datatype(values.data_type().clone()); + Ok(Either::Left(ListArray::new( + dtype, + offsets, + values.boxed(), + validity, + ))) + } else { + let dtype = ListArray::::default_datatype(values.data_type().clone()); + Ok(Either::Left(ListArray::new( + dtype, + offsets, + values.boxed(), + validity, + ))) + } } } @@ -358,7 +427,7 @@ fn array_set_operation( a: &ListArray, b: &ListArray, set_op: SetOperation, -) -> PolarsResult> { +) -> PolarsResult, BooleanArray>> { let offsets_a = a.offsets().as_slice(); let offsets_b = b.offsets().as_slice(); @@ -407,7 +476,7 @@ pub fn list_set_operation( a: &ListChunked, b: &ListChunked, set_op: SetOperation, -) -> PolarsResult { +) -> PolarsResult> { polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match"); polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype()); let mut a = a.clone(); @@ -429,14 +498,22 @@ pub fn list_set_operation( (a, b) = make_list_categoricals_compatible(a, b)?; } - // we use the unsafe variant because we want to keep the nested logical types type. - unsafe { - arity::try_binary_unchecked_same_type( - &a, - &b, - |a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()), - false, - false, - ) + if set_op.is_boolean() { + arity::try_binary(&a, &b, |a, b| { + array_set_operation(a, b, set_op).map(|arr| arr.unwrap_right()) + }) + .map(Either::Right) + } else { + // we use the unsafe variant because we want to keep the nested logical types type. + unsafe { + arity::try_binary_unchecked_same_type( + &a, + &b, + |a, b| array_set_operation(a, b, set_op).map(|arr| arr.unwrap_left().boxed()), + false, + false, + ) + .map(Either::Left) + } } } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index e68b080d17f1..bbd9c6647e6b 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -590,6 +590,8 @@ pub(super) fn unique(s: &Series, is_stable: bool) -> PolarsResult { #[cfg(feature = "list_sets")] pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResult { + use arrow::Either; + let s0 = &s[0]; let s1 = &s[1]; @@ -610,10 +612,14 @@ pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResul Ok(s0.clone()) } }, + SetOperation::IsDisjoint => Ok(Series::new(s0.name(), [true])), }; } - list_set_operation(s0.list()?, s1.list()?, set_type).map(|ca| ca.into_series()) + list_set_operation(s0.list()?, s1.list()?, set_type).map(|ca| match ca { + Either::Left(values) => values.into_series(), + Either::Right(boolean) => boolean.into_series(), + }) } #[cfg(feature = "list_any_all")] diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 87691f263757..c97923fabc40 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -401,4 +401,11 @@ impl ListNameSpace { let other = other.into(); self.set_operation(other, SetOperation::SymmetricDifference) } + + /// Return true if the set has no elements in common with other. Sets are disjoint if and only if their intersection is the empty set. + #[cfg(feature = "list_sets")] + pub fn is_disjoint>(self, other: E) -> Expr { + let other = other.into(); + self.set_operation(other, SetOperation::IsDisjoint) + } } diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 390904997697..d6bfcc72b54e 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1358,3 +1358,62 @@ def set_symmetric_difference(self, other: IntoExpr) -> Expr: """ # noqa: W505. other = parse_into_expression(other, str_as_lit=False) return wrap_expr(self._pyexpr.list_set_operation(other, "symmetric_difference")) + + def is_disjoint(self, other: IntoExpr) -> Expr: + """DOCS TODO.""" + other = parse_as_expression(other, str_as_lit=False) + return wrap_expr(self._pyexpr.list_set_operation(other, "is_disjoint")) + + @deprecate_renamed_function("count_matches", version="0.19.3") + def count_match(self, element: IntoExpr) -> Expr: + """ + Count how often the value produced by `element` occurs. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`count_matches`. + + Parameters + ---------- + element + An expression that produces a single value + """ + return self.count_matches(element) + + @deprecate_renamed_function("len", version="0.19.8") + def lengths(self) -> Expr: + """ + Return the number of elements in each list. + + .. deprecated:: 0.19.8 + This method has been renamed to :func:`len`. + """ + return self.len() + + @deprecate_renamed_function("gather", version="0.19.14") + @deprecate_renamed_parameter("index", "indices", version="0.19.14") + def take( + self, + indices: Expr | Series | list[int] | list[list[int]], + *, + null_on_oob: bool = False, + ) -> Expr: + """ + Take sublists by multiple indices. + + The indices may be defined in a single column, or by sublists in another + column of dtype `List`. + + .. deprecated:: 0.19.14 + This method has been renamed to :func:`gather`. + + Parameters + ---------- + indices + Indices to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error + Note that defaulting to raising an error is much cheaper + """ + return self.gather(indices) diff --git a/py-polars/src/conversion/mod.rs b/py-polars/src/conversion/mod.rs index 6056103b3251..2933dd3fd5fc 100644 --- a/py-polars/src/conversion/mod.rs +++ b/py-polars/src/conversion/mod.rs @@ -1098,9 +1098,10 @@ impl<'py> FromPyObject<'py> for Wrap { "difference" => SetOperation::Difference, "intersection" => SetOperation::Intersection, "symmetric_difference" => SetOperation::SymmetricDifference, + "is_disjoint" => SetOperation::IsDisjoint, v => { return Err(PyValueError::new_err(format!( - "set operation must be one of {{'union', 'difference', 'intersection', 'symmetric_difference'}}, got {v}", + "set operation must be one of {{'union', 'difference', 'intersection', 'symmetric_difference', 'is_disjoint'}}, got {v}", ))) } }; diff --git a/py-polars/src/expr/list.rs b/py-polars/src/expr/list.rs index 9ab917918b83..3e079313ff1a 100644 --- a/py-polars/src/expr/list.rs +++ b/py-polars/src/expr/list.rs @@ -250,6 +250,7 @@ impl PyExpr { SetOperation::Difference => e.set_difference(other.inner), SetOperation::Union => e.union(other.inner), SetOperation::SymmetricDifference => e.set_symmetric_difference(other.inner), + SetOperation::IsDisjoint => e.is_disjoint(other.inner), } .into() } From 549079d4206a775636f4eaee82de8bd4ac99ffd2 Mon Sep 17 00:00:00 2001 From: Object905 Date: Sat, 17 Feb 2024 15:35:29 +0500 Subject: [PATCH 2/2] add docs, add is_disjoint to Series.list, simple test for empty lists --- py-polars/polars/expr/list.py | 34 +++++++++++- py-polars/polars/series/list.py | 85 ++++++++++++++++++++++++++++++ py-polars/tests/unit/test_empty.py | 8 +++ 3 files changed, 126 insertions(+), 1 deletion(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index d6bfcc72b54e..3cd1522fd74e 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1360,7 +1360,39 @@ def set_symmetric_difference(self, other: IntoExpr) -> Expr: return wrap_expr(self._pyexpr.list_set_operation(other, "symmetric_difference")) def is_disjoint(self, other: IntoExpr) -> Expr: - """DOCS TODO.""" + """ + Return true if list has no elements in common with `other`. + + Lists are disjoint if and only if their intersection is empty. + + Parameters + ---------- + other + Right hand side of the set operation. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "a": [[1, 2, 3], [None, 3], [None, 3], [6, 7], [None], [None, 1]], + ... "b": [[3, 4, 5], [3], [3, 4, None], [8, 9], [1], [None, 3]], + ... } + ... ) + >>> df.with_columns(disjoint=pl.col("b").list.is_disjoint("a")) + shape: (6, 3) + ┌───────────┬──────────────┬──────────┐ + │ a ┆ b ┆ disjoint │ + │ --- ┆ --- ┆ --- │ + │ list[i64] ┆ list[i64] ┆ bool │ + ╞═══════════╪══════════════╪══════════╡ + │ [1, 2, 3] ┆ [3, 4, 5] ┆ false │ + │ [null, 3] ┆ [3] ┆ false │ + │ [null, 3] ┆ [3, 4, null] ┆ false │ + │ [5, 6, 7] ┆ [8, 9] ┆ true │ + │ [null] ┆ [1] ┆ true │ + │ [null, 1] ┆ [null, 3] ┆ false │ + └───────────┴──────────────┴──────────┘ + """ other = parse_as_expression(other, str_as_lit=False) return wrap_expr(self._pyexpr.list_set_operation(other, "is_disjoint")) diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 53bcb22ce6f3..a5c0780fd396 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -1052,3 +1052,88 @@ def set_symmetric_difference(self, other: Series) -> Series: [5, 7, 8] ] """ # noqa: W505 + + def is_disjoint(self, other: Series) -> Expr: + """ + Return true if list has no elements in common with `other`. + + Lists are disjoint if and only if their intersection is empty. + + Parameters + ---------- + other + Right hand side of the set operation. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "a": [[1, 2, 3], [None, 3], [None, 3], [6, 7], [None], [None, 1]], + ... "b": [[3, 4, 5], [3], [3, 4, None], [8, 9], [1], [None, 3]], + ... } + ... ) + >>> df.with_columns(disjoint=pl.col("b").list.is_disjoint("a")) + shape: (6, 3) + ┌───────────┬──────────────┬──────────┐ + │ a ┆ b ┆ disjoint │ + │ --- ┆ --- ┆ --- │ + │ list[i64] ┆ list[i64] ┆ bool │ + ╞═══════════╪══════════════╪══════════╡ + │ [1, 2, 3] ┆ [3, 4, 5] ┆ false │ + │ [null, 3] ┆ [3] ┆ false │ + │ [null, 3] ┆ [3, 4, null] ┆ false │ + │ [5, 6, 7] ┆ [8, 9] ┆ true │ + │ [null] ┆ [1] ┆ true │ + │ [null, 1] ┆ [null, 3] ┆ false │ + └───────────┴──────────────┴──────────┘ + """ + + @deprecate_renamed_function("count_matches", version="0.19.3") + def count_match( + self, element: float | str | bool | int | date | datetime | time | Expr + ) -> Expr: + """ + Count how often the value produced by `element` occurs. + + .. deprecated:: 0.19.3 + This method has been renamed to :func:`count_matches`. + + Parameters + ---------- + element + An expression that produces a single value + """ + + @deprecate_renamed_function("len", version="0.19.8") + def lengths(self) -> Series: + """ + Return the number of elements in each list. + + .. deprecated:: 0.19.8 + This method has been renamed to :func:`len`. + """ + + @deprecate_renamed_function("gather", version="0.19.14") + @deprecate_renamed_parameter("index", "indices", version="0.19.14") + def take( + self, + indices: Series | list[int] | list[list[int]], + *, + null_on_oob: bool = False, + ) -> Series: + """ + Take sublists by multiple indices. + + .. deprecated:: 0.19.14 + This method has been renamed to :func:`gather`. + + Parameters + ---------- + indices + Indices to return per sublist + null_on_oob + Behavior if an index is out of bounds: + True -> set as null + False -> raise an error + Note that defaulting to raising an error is much cheaper + """ diff --git a/py-polars/tests/unit/test_empty.py b/py-polars/tests/unit/test_empty.py index 8b5826c25380..70e3c77652b1 100644 --- a/py-polars/tests/unit/test_empty.py +++ b/py-polars/tests/unit/test_empty.py @@ -113,6 +113,14 @@ def test_empty_set_symteric_difference() -> None: assert_series_equal(full.rename("empty"), empty.list.set_symmetric_difference(full)) +def test_empty_is_disjoint() -> None: + s1 = pl.Series("s1", [[1, 2], [], [None]], pl.List(pl.UInt32)) + s2 = pl.Series("s2", [[], [1, 2], []], pl.List(pl.UInt32)) + expected = pl.Series("s1", [True, True, True]) + + assert_series_equal(expected, s1.list.is_disjoint(s2)) + + @pytest.mark.parametrize("name", ["sort", "unique", "head", "tail", "shift", "reverse"]) def test_empty_list_namespace_output_9585(name: str) -> None: dtype = pl.List(pl.String)