Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add list.index_of_in() to Expr and Series #21192

Open
wants to merge 31 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
1a09f7b
Add the boilerplate for a new expression method `list.index_of_in()`.
pythonspeed Jan 15, 2025
3083803
Sketch of implementation, initial tests pass.
pythonspeed Jan 15, 2025
9fc4d5f
More explanations
pythonspeed Jan 15, 2025
85b26ff
Alternative implementation in Python.
pythonspeed Jan 28, 2025
2365f46
Remove much slower Python version.
pythonspeed Feb 3, 2025
d82565a
Merge remote-tracking branch 'origin/main' into 20626-index_of_in
pythonspeed Feb 3, 2025
6e4b005
Always check both literal and non-literal.
pythonspeed Feb 3, 2025
372ba35
Start making more general tests, make eager API work.
pythonspeed Feb 3, 2025
255ccc9
Integer tests pass.
pythonspeed Feb 4, 2025
1a41257
Support and tests for floats.
pythonspeed Feb 4, 2025
61e3d1c
More test coverage.
pythonspeed Feb 5, 2025
29a4068
Start of nested lists working.
pythonspeed Feb 5, 2025
a11b77a
Expand testing of nested lists.
pythonspeed Feb 5, 2025
07f35a0
Working arrays, do casting in the IR.
pythonspeed Feb 6, 2025
93075a0
Minimal support for multiple chunks in the needle.
pythonspeed Feb 6, 2025
4a96f99
Drop debug prints.
pythonspeed Feb 6, 2025
46feae9
Mismatched length.
pythonspeed Feb 6, 2025
23230e0
Expand.
pythonspeed Feb 6, 2025
5d40c9c
Optimize the scalar case of list_index_of_in().
pythonspeed Feb 7, 2025
ddbd9a4
Optimize a bunch of code paths.
pythonspeed Feb 7, 2025
a2d3e76
Bit more testing.
pythonspeed Feb 7, 2025
119d4e9
Lints and cleanups.
pythonspeed Feb 7, 2025
c1dfb3d
Validate input types.
pythonspeed Feb 10, 2025
86ed067
Documentation.
pythonspeed Feb 10, 2025
1ab116c
Cover another edge case.
pythonspeed Feb 10, 2025
da11595
Lint.
pythonspeed Feb 10, 2025
1df81f3
Type checks.
pythonspeed Feb 11, 2025
5c8ad6d
Reformat.
pythonspeed Feb 11, 2025
00887ce
Unnecessary.
pythonspeed Feb 11, 2025
36556d8
Add dtype feature conditions.
pythonspeed Feb 11, 2025
d6fa1d4
Feature list_index_of_in requires index_of feature.
pythonspeed Feb 11, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/polars-lazy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,7 @@ sign = ["polars-plan/sign"]
timezones = ["polars-plan/timezones"]
list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"]
list_count = ["polars-ops/list_count", "polars-plan/list_count"]
list_index_of_in = ["polars-ops/list_index_of_in", "polars-plan/list_index_of_in"]
array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"]
true_div = ["polars-plan/true_div"]
extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"]
Expand Down Expand Up @@ -376,6 +377,7 @@ features = [
"list_drop_nulls",
"list_eval",
"list_gather",
"list_index_of_in",
"list_sample",
"list_sets",
"list_to_struct",
Expand Down
1 change: 1 addition & 0 deletions crates/polars-ops/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -146,3 +146,4 @@ abs = []
cov = []
gather = []
replace = ["is_in"]
list_index_of_in = ["index_of"]
159 changes: 159 additions & 0 deletions crates/polars-ops/src/chunked_array/list/index_of_in.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
use polars_core::match_arrow_dtype_apply_macro_ca;

use super::*;
use crate::series::{index_of, index_of_null};

/// Expands to a boxed iterator that converts every item produced by
/// `$ca.iter()` into an `AnyValue` via `AnyValue::from`.
macro_rules! to_anyvalue_iterator {
    ($ca:expr) => {
        Box::new($ca.iter().map(polars_core::prelude::AnyValue::from))
    };
}

/// Build a boxed, exact-size iterator over the elements of `series`, each
/// converted to an `AnyValue`.
///
/// The dtypes listed in the `match` below are dispatched by hand through their
/// typed accessors; every other dtype is handed to
/// `match_arrow_dtype_apply_macro_ca!`.
fn series_to_anyvalue_iter(series: &Series) -> Box<dyn ExactSizeIterator<Item = AnyValue> + '_> {
    match series.dtype() {
        #[cfg(feature = "dtype-date")]
        DataType::Date => return to_anyvalue_iterator!(series.date().unwrap()),
        #[cfg(feature = "dtype-datetime")]
        DataType::Datetime(_, _) => return to_anyvalue_iterator!(series.datetime().unwrap()),
        #[cfg(feature = "dtype-time")]
        DataType::Time => return to_anyvalue_iterator!(series.time().unwrap()),
        #[cfg(feature = "dtype-duration")]
        DataType::Duration(_) => return to_anyvalue_iterator!(series.duration().unwrap()),
        DataType::Binary => return to_anyvalue_iterator!(series.binary().unwrap()),
        #[cfg(feature = "dtype-decimal")]
        DataType::Decimal(_, _) => return to_anyvalue_iterator!(series.decimal().unwrap()),
        // Anything else falls through to the generic dispatch below.
        _ => {},
    }
    match_arrow_dtype_apply_macro_ca!(
        series,
        to_anyvalue_iterator,
        to_anyvalue_iterator,
        to_anyvalue_iterator
    )
}

/// Given a needle, or needles, find the index of the corresponding needle in
/// each list of a `ListChunked`.
///
/// * If `needles` has exactly one element, that single value is searched for
///   in every list (see `list_index_of_in_for_scalar`).
/// * Otherwise `needles` must have the same length as `ca`, and the i-th
///   needle is searched for in the i-th list.
///
/// The result has one entry per list: the index at which the needle was found,
/// or null if the list itself is null or the needle is not present.
///
/// # Errors
///
/// Returns a `ComputeError` if `needles` has more than one element and its
/// length differs from `ca.len()`. Errors reported by `index_of` are
/// propagated to the caller instead of panicking.
pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult<Series> {
    // Handle scalar case separately, since we can do some optimizations given
    // the extra knowledge we have.
    if needles.len() == 1 {
        // Index 0 is in bounds: the length was checked on the line above.
        let needle = needles.get(0).unwrap();
        return list_index_of_in_for_scalar(ca, needle);
    }

    polars_ensure!(
        ca.len() == needles.len(),
        ComputeError: "shapes don't match: expected {} elements in 'index_of_in' comparison, got {}",
        ca.len(),
        needles.len()
    );
    let dtype = needles.dtype();
    // Declared up front so the rechunked copy outlives the iterator borrowing it.
    let owned;
    let needle_iter = if dtype.is_list()
        || dtype.is_array()
        || dtype.is_enum()
        || dtype.is_categorical()
        || dtype.is_struct()
    {
        owned = needles.rechunk();
        Box::new(owned.iter())
    } else {
        // Optimized versions:
        series_to_anyvalue_iter(needles)
    };

    let mut builder = PrimitiveChunkedBuilder::<IdxType>::new(ca.name().clone(), ca.len());
    for (opt_series, needle) in ca.amortized_iter().zip(needle_iter) {
        match opt_series {
            // A null list yields a null result regardless of the needle.
            None => builder.append_null(),
            Some(subseries) => {
                let needle = Scalar::new(dtype.clone(), needle.into_static());
                // Propagate lookup failures to the caller rather than panicking.
                let position = index_of(subseries.as_ref(), needle)?;
                // NOTE(review): assumes a position inside a sub-list always
                // fits the index type; the conversion panics otherwise.
                builder.append_option(position.map(|v| v.try_into().unwrap()));
            },
        }
    }

    Ok(builder.finish().into())
}

/// Expands to a boxed closure that looks up a numeric needle in a sub-series.
///
/// `$extractor` is used both as the primitive type the needle is extracted to
/// (e.g. `u8`) and as the name of the accessor method called on the sub-series
/// (e.g. `.u8()`). `$needle` names the variable holding the needle value.
macro_rules! process_series_for_numeric_value {
    ($extractor:ident, $needle:ident) => {{
        // Extract once, outside the closure, so each call only does the lookup.
        let target: $extractor = $needle.extract().unwrap();
        Box::new(move |subseries| {
            crate::series::index_of_value::<_, arrow::array::PrimitiveArray<$extractor>>(
                subseries.$extractor().unwrap(),
                target,
            )
        })
    }};
}

/// Find the index of a single `needle` in every list of `ca`.
///
/// The needle's dtype must equal the list's inner dtype, or be null;
/// otherwise a `ComputeError` is returned. The result has one entry per list:
/// the index of the needle, or null if the list is null or the needle is not
/// found.
#[allow(clippy::type_complexity)] // For the Box<dyn Fn()>
fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult<Series> {
    polars_ensure!(
        ca.dtype().inner_dtype().unwrap() == &needle.dtype() || needle.dtype().is_null(),
        ComputeError: "dtypes didn't match: series values have dtype {} and needle has dtype {}",
        ca.dtype().inner_dtype().unwrap(),
        needle.dtype()
    );

    let needle = needle.into_static();
    let inner_dtype = ca.dtype().inner_dtype().unwrap();
    let scalar_dtype = needle.dtype();

    // Choose the per-list lookup once, based on the needle's dtype, so the
    // loop below only has to call it.
    let process_series: Box<dyn Fn(&Series) -> Option<usize>> = match scalar_dtype {
        DataType::Null => Box::new(index_of_null),
        #[cfg(feature = "dtype-u8")]
        DataType::UInt8 => process_series_for_numeric_value!(u8, needle),
        #[cfg(feature = "dtype-u16")]
        DataType::UInt16 => process_series_for_numeric_value!(u16, needle),
        DataType::UInt32 => process_series_for_numeric_value!(u32, needle),
        DataType::UInt64 => process_series_for_numeric_value!(u64, needle),
        #[cfg(feature = "dtype-i8")]
        DataType::Int8 => process_series_for_numeric_value!(i8, needle),
        #[cfg(feature = "dtype-i16")]
        DataType::Int16 => process_series_for_numeric_value!(i16, needle),
        DataType::Int32 => process_series_for_numeric_value!(i32, needle),
        DataType::Int64 => process_series_for_numeric_value!(i64, needle),
        #[cfg(feature = "dtype-i128")]
        DataType::Int128 => process_series_for_numeric_value!(i128, needle),
        DataType::Float32 => process_series_for_numeric_value!(f32, needle),
        DataType::Float64 => process_series_for_numeric_value!(f64, needle),
        // Just use the general purpose index_of() function:
        _ => Box::new(|subseries| {
            let scalar = Scalar::new(inner_dtype.clone(), needle.clone());
            index_of(subseries, scalar).unwrap()
        }),
    };

    let mut out = PrimitiveChunkedBuilder::<IdxType>::new(ca.name().clone(), ca.len());
    for opt_series in ca.amortized_iter() {
        match opt_series {
            Some(subseries) => {
                let found = process_series(subseries.as_ref());
                out.append_option(found.map(|v| v.try_into().unwrap()));
            },
            // A null list yields a null result.
            None => out.append_null(),
        }
    }
    Ok(out.finish().into())
}
4 changes: 4 additions & 0 deletions crates/polars-ops/src/chunked_array/list/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ mod count;
mod dispersion;
#[cfg(feature = "hash")]
pub(crate) mod hash;
#[cfg(feature = "list_index_of_in")]
mod index_of_in;
mod min_max;
mod namespace;
#[cfg(feature = "list_sets")]
Expand All @@ -18,6 +20,8 @@ mod to_struct;
pub use count::*;
#[cfg(not(feature = "list_count"))]
use count::*;
#[cfg(feature = "list_index_of_in")]
pub use index_of_in::*;
pub use namespace::*;
#[cfg(feature = "list_sets")]
pub use sets::*;
Expand Down
42 changes: 25 additions & 17 deletions crates/polars-ops/src/series/ops/index_of.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@ use polars_utils::total_ord::TotalEq;
use row_encode::encode_rows_unordered;

/// Find the index of the value, or ``None`` if it can't be found.
fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray<DT>, value: AR::ValueT<'a>) -> Option<usize>
pub(crate) fn index_of_value<'a, DT, AR>(
ca: &'a ChunkedArray<DT>,
value: AR::ValueT<'a>,
) -> Option<usize>
where
DT: PolarsDataType,
AR: StaticArray,
Expand Down Expand Up @@ -51,14 +54,31 @@ macro_rules! try_index_of_numeric_ca {
($ca:expr, $value:expr) => {{
let ca = $ca;
let value = $value;
// extract() returns None if casting failed, so consider an extract()
// failure as not finding the value. Nulls should have been handled
// earlier.
// extract() returns None if casting failed, and by this point Nulls
// have been handled, and everything should have been cast to matching
// dtype otherwise.
let value = value.value().extract().unwrap();
index_of_numeric_value(ca, value)
}};
}

/// Find the index of the first null within a Series, or `None` if no chunk
/// reports a null.
///
/// The returned index is relative to the start of the whole series, i.e. it
/// accounts for the lengths of all preceding chunks.
pub(crate) fn index_of_null(series: &Series) -> Option<usize> {
    // Number of elements in the chunks already scanned.
    let mut offset = 0;
    for chunk in series.chunks() {
        let length = chunk.len();
        if let Some(bitmap) = chunk.validity() {
            // The first unset validity bit marks the first null in this chunk.
            let leading_ones = bitmap.leading_ones();
            if leading_ones < length {
                return Some(offset + leading_ones);
            }
        }
        // No null in this chunk. Advance the offset whether or not the chunk
        // carries a validity bitmap; skipping this for chunks with an all-valid
        // bitmap would make indices found in later chunks too small.
        offset += length;
    }
    None
}

/// Find the index of a given value (`needle`) within the series, or `None`
/// if it can't be found.
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>> {
Expand All @@ -80,19 +100,7 @@ pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult<Option<usize>>

// Series is not null, and the value is null:
if needle.is_null() {
let mut index = 0;
for chunk in series.chunks() {
let length = chunk.len();
if let Some(bitmap) = chunk.validity() {
let leading_ones = bitmap.leading_ones();
if leading_ones < length {
return Ok(Some(index + leading_ones));
}
} else {
index += length;
}
}
return Ok(None);
return Ok(index_of_null(series));
}

if series.dtype().is_primitive_numeric() {
Expand Down
2 changes: 2 additions & 0 deletions crates/polars-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ dtype-struct = ["polars-core/dtype-struct"]
object = ["polars-core/object"]
list_gather = ["polars-ops/list_gather"]
list_count = ["polars-ops/list_count"]
list_index_of_in = ["polars-ops/list_index_of_in"]
array_count = ["polars-ops/array_count", "dtype-array"]
trigonometry = []
sign = []
Expand Down Expand Up @@ -293,6 +294,7 @@ features = [
"streaming",
"true_div",
"sign",
"list_index_of_in",
]
# defines the configuration attribute `docsrs`
rustdoc-args = ["--cfg", "docsrs"]
16 changes: 16 additions & 0 deletions crates/polars-plan/src/dsl/function_expr/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ pub enum ListFunction {
ToArray(usize),
#[cfg(feature = "list_to_struct")]
ToStruct(ListToStructArgs),
#[cfg(feature = "list_index_of_in")]
IndexOfIn,
}

impl ListFunction {
Expand Down Expand Up @@ -107,6 +109,8 @@ impl ListFunction {
NUnique => mapper.with_dtype(IDX_DTYPE),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)),
#[cfg(feature = "list_index_of_in")]
IndexOfIn => mapper.with_dtype(IDX_DTYPE),
}
}
}
Expand Down Expand Up @@ -180,6 +184,8 @@ impl Display for ListFunction {
ToArray(_) => "to_array",
#[cfg(feature = "list_to_struct")]
ToStruct(_) => "to_struct",
#[cfg(feature = "list_index_of_in")]
IndexOfIn => "index_of_in",
};
write!(f, "list.{name}")
}
Expand Down Expand Up @@ -243,6 +249,8 @@ impl From<ListFunction> for SpecialEq<Arc<dyn ColumnsUdf>> {
NUnique => map!(n_unique),
#[cfg(feature = "list_to_struct")]
ToStruct(args) => map!(to_struct, &args),
#[cfg(feature = "list_index_of_in")]
IndexOfIn => map_as_slice!(index_of_in),
}
}
}
Expand Down Expand Up @@ -547,6 +555,14 @@ pub(super) fn count_matches(args: &[Column]) -> PolarsResult<Column> {
list_count_matches(ca, element.get(0).unwrap()).map(Column::from)
}

/// Column-level entry point for `list.index_of_in`.
///
/// `args[0]` is the list column to search in and `args[1]` holds the
/// needle(s). Returns an error if `args[0]` is not a list column.
#[cfg(feature = "list_index_of_in")]
pub(super) fn index_of_in(args: &[Column]) -> PolarsResult<Column> {
    let (haystack, needles) = (&args[0], &args[1]);
    let lists = haystack.list()?;
    let indices = list_index_of_in(lists, needles.as_materialized_series())?;
    Ok(Column::from(indices))
}

pub(super) fn sum(s: &Column) -> PolarsResult<Column> {
s.list()?.lst_sum().map(Column::from)
}
Expand Down
15 changes: 15 additions & 0 deletions crates/polars-plan/src/dsl/list.rs
Original file line number Diff line number Diff line change
Expand Up @@ -322,6 +322,21 @@ impl ListNameSpace {
)
}

/// Find the index of a needle in the list.
///
/// Builds an element-wise `ListFunction::IndexOfIn` function expression whose
/// inputs are this list expression followed by `needle`.
#[cfg(feature = "list_index_of_in")]
pub fn index_of_in<N: Into<Expr>>(self, needle: N) -> Expr {
    let input = vec![self.0, needle.into()];
    let options = FunctionOptions {
        collect_groups: ApplyOptions::ElementWise,
        flags: FunctionFlags::default(),
        cast_options: Some(CastingRules::FirstArgInnerLossless),
        ..Default::default()
    };
    Expr::Function {
        input,
        function: FunctionExpr::ListExpr(ListFunction::IndexOfIn),
        options,
    }
}

#[cfg(feature = "list_sets")]
fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr {
Expr::Function {
Expand Down
Loading
Loading