Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Boolean list set operations #14558

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 112 additions & 35 deletions crates/polars-ops/src/chunked_array/list/sets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@ use std::fmt::{Display, Formatter};
use std::hash::Hash;

use arrow::array::{
Array, BinaryViewArray, ListArray, MutableArray, MutablePlBinary, MutablePrimitiveArray,
PrimitiveArray, Utf8ViewArray,
Array, BinaryViewArray, BooleanArray, ListArray, MutableArray, MutableBooleanArray,
MutablePlBinary, MutablePrimitiveArray, PrimitiveArray, Utf8ViewArray,
};
use arrow::bitmap::Bitmap;
use arrow::compute::utils::combine_validities_and;
use arrow::offset::OffsetsBuffer;
use arrow::types::NativeType;
use either::Either;
use polars_core::prelude::*;
use polars_core::with_match_physical_numeric_type;
use polars_utils::total_ord::{ToTotalOrd, TotalEq, TotalHash, TotalOrdWrap};
Expand Down Expand Up @@ -53,6 +54,7 @@ fn set_operation<K, I, J, R>(
a: I,
b: J,
out: &mut R,
bool_out: &mut MutableBooleanArray,
set_op: SetOperation,
broadcast_rhs: bool,
) -> usize
Expand Down Expand Up @@ -99,6 +101,15 @@ where
set.extend(a);
out.extend_buf(set.symmetric_difference(set2).copied())
},
SetOperation::IsDisjoint => {
// When the rhs is broadcast, `set2` should already be filled; only rebuild it otherwise.
if !broadcast_rhs {
set2.clear();
set2.extend(b);
Object905 marked this conversation as resolved.
Show resolved Hide resolved
}
bool_out.push(Some(!a.into_iter().any(|val| set2.contains(&val))));
bool_out.len()
},
}
}

Expand All @@ -115,6 +126,7 @@ pub enum SetOperation {
Union,
Difference,
SymmetricDifference,
IsDisjoint,
}

impl Display for SetOperation {
Expand All @@ -124,19 +136,29 @@ impl Display for SetOperation {
SetOperation::Union => "union",
SetOperation::Difference => "difference",
SetOperation::SymmetricDifference => "symmetric_difference",
SetOperation::IsDisjoint => "is_disjoint",
};
write!(f, "{s}")
}
}

impl SetOperation {
    /// Returns `true` for operations whose result is a boolean array
    /// instead of a list array.
    ///
    /// Currently only [`SetOperation::IsDisjoint`] is boolean-valued; every
    /// other variant returns `false`.
    fn is_boolean(&self) -> bool {
        matches!(self, SetOperation::IsDisjoint)
    }
}

fn primitive<T>(
a: &PrimitiveArray<T>,
b: &PrimitiveArray<T>,
offsets_a: &[i64],
offsets_b: &[i64],
set_op: SetOperation,
validity: Option<Bitmap>,
) -> PolarsResult<ListArray<i64>>
) -> PolarsResult<Either<ListArray<i64>, BooleanArray>>
where
T: NativeType + TotalHash + TotalEq + Copy + ToTotalOrd,
<Option<T> as ToTotalOrd>::TotalOrdItem: Hash + Eq + Copy,
Expand All @@ -147,10 +169,18 @@ where
let mut set = Default::default();
let mut set2: PlIndexSet<<Option<T> as ToTotalOrd>::TotalOrdItem> = Default::default();

let mut values_out = MutablePrimitiveArray::with_capacity(std::cmp::max(
*offsets_a.last().unwrap(),
*offsets_b.last().unwrap(),
) as usize);
let needed_capacity =
std::cmp::max(*offsets_a.last().unwrap(), *offsets_b.last().unwrap()) as usize;
let mut values_out;
let mut bool_values_out;
if set_op.is_boolean() {
values_out = MutablePrimitiveArray::new();
bool_values_out = MutableBooleanArray::with_capacity(needed_capacity);
} else {
values_out = MutablePrimitiveArray::with_capacity(needed_capacity);
bool_values_out = MutableBooleanArray::new();
}

let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
offsets.push(0i64);

Expand Down Expand Up @@ -199,6 +229,7 @@ where
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
true,
)
Expand All @@ -221,6 +252,7 @@ where
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
false,
)
Expand All @@ -243,18 +275,29 @@ where
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
false,
)
};

offsets.push(offset as i64);
}
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let dtype = ListArray::<i64>::default_datatype(values_out.data_type().clone());

let values: PrimitiveArray<T> = values_out.into();
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
if set_op.is_boolean() {
Ok(Either::Right(bool_values_out.into()))
} else {
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let dtype = ListArray::<i64>::default_datatype(values_out.data_type().clone());

let values: PrimitiveArray<T> = values_out.into();
Ok(Either::Left(ListArray::new(
dtype,
offsets,
values.boxed(),
validity,
)))
}
}

fn binary(
Expand All @@ -265,16 +308,24 @@ fn binary(
set_op: SetOperation,
validity: Option<Bitmap>,
as_utf8: bool,
) -> PolarsResult<ListArray<i64>> {
) -> PolarsResult<Either<ListArray<i64>, BooleanArray>> {
let broadcast_lhs = offsets_a.len() == 2;
let broadcast_rhs = offsets_b.len() == 2;
let mut set = Default::default();
let mut set2: PlIndexSet<Option<&[u8]>> = Default::default();

let mut values_out = MutablePlBinary::with_capacity(std::cmp::max(
*offsets_a.last().unwrap(),
*offsets_b.last().unwrap(),
) as usize);
let needed_capacity =
std::cmp::max(*offsets_a.last().unwrap(), *offsets_b.last().unwrap()) as usize;
let mut values_out;
let mut bool_values_out;
if set_op.is_boolean() {
values_out = MutablePlBinary::new();
bool_values_out = MutableBooleanArray::with_capacity(needed_capacity);
} else {
values_out = MutablePlBinary::with_capacity(needed_capacity);
bool_values_out = MutableBooleanArray::new();
}

let mut offsets = Vec::with_capacity(std::cmp::max(offsets_a.len(), offsets_b.len()));
Object905 marked this conversation as resolved.
Show resolved Hide resolved
offsets.push(0i64);

Expand Down Expand Up @@ -310,6 +361,7 @@ fn binary(
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
true,
)
Expand All @@ -322,6 +374,7 @@ fn binary(
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
false,
)
Expand All @@ -335,30 +388,46 @@ fn binary(
a_iter,
b_iter,
&mut values_out,
&mut bool_values_out,
set_op,
false,
)
};
offsets.push(offset as i64);
}
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let values = values_out.freeze();

if as_utf8 {
let values = unsafe { values.to_utf8view_unchecked() };
let dtype = ListArray::<i64>::default_datatype(values.data_type().clone());
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
if set_op.is_boolean() {
Ok(Either::Right(bool_values_out.into()))
} else {
let dtype = ListArray::<i64>::default_datatype(values.data_type().clone());
Ok(ListArray::new(dtype, offsets, values.boxed(), validity))
let offsets = unsafe { OffsetsBuffer::new_unchecked(offsets.into()) };
let values = values_out.freeze();

if as_utf8 {
let values = unsafe { values.to_utf8view_unchecked() };
let dtype = ListArray::<i64>::default_datatype(values.data_type().clone());
Ok(Either::Left(ListArray::new(
dtype,
offsets,
values.boxed(),
validity,
)))
} else {
let dtype = ListArray::<i64>::default_datatype(values.data_type().clone());
Ok(Either::Left(ListArray::new(
dtype,
offsets,
values.boxed(),
validity,
)))
}
}
}

fn array_set_operation(
a: &ListArray<i64>,
b: &ListArray<i64>,
set_op: SetOperation,
) -> PolarsResult<ListArray<i64>> {
) -> PolarsResult<Either<ListArray<i64>, BooleanArray>> {
let offsets_a = a.offsets().as_slice();
let offsets_b = b.offsets().as_slice();

Expand Down Expand Up @@ -407,7 +476,7 @@ pub fn list_set_operation(
a: &ListChunked,
b: &ListChunked,
set_op: SetOperation,
) -> PolarsResult<ListChunked> {
) -> PolarsResult<Either<ListChunked, BooleanChunked>> {
polars_ensure!(a.len() == b.len() || b.len() == 1 || a.len() == 1, ShapeMismatch: "column lengths don't match");
polars_ensure!(a.dtype() == b.dtype(), InvalidOperation: "cannot do 'set' operation on dtypes: {} and {}", a.dtype(), b.dtype());
let mut a = a.clone();
Expand All @@ -429,14 +498,22 @@ pub fn list_set_operation(
(a, b) = make_list_categoricals_compatible(a, b)?;
}

// we use the unsafe variant because we want to keep the nested logical types type.
unsafe {
arity::try_binary_unchecked_same_type(
&a,
&b,
|a, b| array_set_operation(a, b, set_op).map(|arr| arr.boxed()),
false,
false,
)
if set_op.is_boolean() {
arity::try_binary(&a, &b, |a, b| {
array_set_operation(a, b, set_op).map(|arr| arr.unwrap_right())
})
.map(Either::Right)
} else {
// We use the unsafe variant because we want to keep the nested logical type.
unsafe {
arity::try_binary_unchecked_same_type(
&a,
&b,
|a, b| array_set_operation(a, b, set_op).map(|arr| arr.unwrap_left().boxed()),
false,
false,
)
.map(Either::Left)
}
}
}
8 changes: 7 additions & 1 deletion crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -590,6 +590,8 @@ pub(super) fn unique(s: &Series, is_stable: bool) -> PolarsResult<Series> {

#[cfg(feature = "list_sets")]
pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResult<Series> {
use arrow::Either;

let s0 = &s[0];
let s1 = &s[1];

Expand All @@ -610,10 +612,14 @@ pub(super) fn set_operation(s: &[Series], set_type: SetOperation) -> PolarsResul
Ok(s0.clone())
}
},
SetOperation::IsDisjoint => Ok(Series::new(s0.name(), [true])),
};
}

list_set_operation(s0.list()?, s1.list()?, set_type).map(|ca| ca.into_series())
list_set_operation(s0.list()?, s1.list()?, set_type).map(|ca| match ca {
Either::Left(values) => values.into_series(),
Either::Right(boolean) => boolean.into_series(),
})
}

#[cfg(feature = "list_any_all")]
Expand Down
7 changes: 7 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -401,4 +401,11 @@ impl ListNameSpace {
let other = other.into();
self.set_operation(other, SetOperation::SymmetricDifference)
}

/// Evaluate to true when the set has no elements in common with `other`.
///
/// Two sets are disjoint if and only if their intersection is the empty set.
#[cfg(feature = "list_sets")]
pub fn is_disjoint<E: Into<Expr>>(self, other: E) -> Expr {
    self.set_operation(other.into(), SetOperation::IsDisjoint)
}
}
Loading
Loading