diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index 94e93c59839f..ceaafc56b4b4 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -174,6 +174,7 @@ sign = ["polars-plan/sign"] timezones = ["polars-plan/timezones"] list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"] list_count = ["polars-ops/list_count", "polars-plan/list_count"] +list_index_of_in = ["polars-ops/list_index_of_in", "polars-plan/list_index_of_in"] array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"] true_div = ["polars-plan/true_div"] extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"] @@ -376,6 +377,7 @@ features = [ "list_drop_nulls", "list_eval", "list_gather", + "list_index_of_in", "list_sample", "list_sets", "list_to_struct", diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 1034bf0db3f4..f1513588dc41 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -146,3 +146,4 @@ abs = [] cov = [] gather = [] replace = ["is_in"] +list_index_of_in = ["index_of"] diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs new file mode 100644 index 000000000000..f0146bcd7778 --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -0,0 +1,159 @@ +use polars_core::match_arrow_dtype_apply_macro_ca; + +use super::*; +use crate::series::{index_of, index_of_null}; + +macro_rules! 
to_anyvalue_iterator { + ($ca:expr) => {{ + use polars_core::prelude::AnyValue; + Box::new($ca.iter().map(AnyValue::from)) + }}; +} + +fn series_to_anyvalue_iter(series: &Series) -> Box + '_> { + let dtype = series.dtype(); + match dtype { + #[cfg(feature = "dtype-date")] + DataType::Date => { + return to_anyvalue_iterator!(series.date().unwrap()); + }, + #[cfg(feature = "dtype-datetime")] + DataType::Datetime(_, _) => { + return to_anyvalue_iterator!(series.datetime().unwrap()); + }, + #[cfg(feature = "dtype-time")] + DataType::Time => { + return to_anyvalue_iterator!(series.time().unwrap()); + }, + #[cfg(feature = "dtype-duration")] + DataType::Duration(_) => { + return to_anyvalue_iterator!(series.duration().unwrap()); + }, + DataType::Binary => { + return to_anyvalue_iterator!(series.binary().unwrap()); + }, + #[cfg(feature = "dtype-decimal")] + DataType::Decimal(_, _) => { + return to_anyvalue_iterator!(series.decimal().unwrap()); + }, + _ => (), + }; + match_arrow_dtype_apply_macro_ca!( + series, + to_anyvalue_iterator, + to_anyvalue_iterator, + to_anyvalue_iterator + ) +} + +/// Given a needle, or needles, find the corresponding needle in each value of a +/// ListChunked. +pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { + // Handle scalar case separately, since we can do some optimizations given + // the extra knowledge we have. 
+ if needles.len() == 1 { + let needle = needles.get(0).unwrap(); + return list_index_of_in_for_scalar(ca, needle); + } + + polars_ensure!( + ca.len() == needles.len(), + ComputeError: "shapes don't match: expected {} elements in 'index_of_in' comparison, got {}", + ca.len(), + needles.len() + ); + let dtype = needles.dtype(); + let owned; + let needle_iter = if dtype.is_list() + || dtype.is_array() + || dtype.is_enum() + || dtype.is_categorical() + || dtype.is_struct() + { + owned = needles.rechunk(); + Box::new(owned.iter()) + } else { + // Optimized versions: + series_to_anyvalue_iter(needles) + }; + + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + ca.amortized_iter() + .zip(needle_iter) + .for_each(|(opt_series, needle)| match (opt_series, needle) { + (None, _) => builder.append_null(), + (Some(subseries), needle) => { + let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); + builder.append_option( + index_of(subseries.as_ref(), needle) + .unwrap() + .map(|v| v.try_into().unwrap()), + ); + }, + }); + + Ok(builder.finish().into()) +} + +macro_rules! 
process_series_for_numeric_value { + ($extractor:ident, $needle:ident) => {{ + use arrow::array::PrimitiveArray; + + use crate::series::index_of_value; + + let needle = $needle.extract::<$extractor>().unwrap(); + Box::new(move |subseries| { + index_of_value::<_, PrimitiveArray<$extractor>>(subseries.$extractor().unwrap(), needle) + }) + }}; +} + +#[allow(clippy::type_complexity)] // For the Box +fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult { + polars_ensure!( + ca.dtype().inner_dtype().unwrap() == &needle.dtype() || needle.dtype().is_null(), + ComputeError: "dtypes didn't match: series values have dtype {} and needle has dtype {}", + ca.dtype().inner_dtype().unwrap(), + needle.dtype() + ); + + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + let needle = needle.into_static(); + let inner_dtype = ca.dtype().inner_dtype().unwrap(); + let needle_dtype = needle.dtype(); + + let process_series: Box Option> = match needle_dtype { + DataType::Null => Box::new(index_of_null), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => process_series_for_numeric_value!(u8, needle), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => process_series_for_numeric_value!(u16, needle), + DataType::UInt32 => process_series_for_numeric_value!(u32, needle), + DataType::UInt64 => process_series_for_numeric_value!(u64, needle), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => process_series_for_numeric_value!(i8, needle), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => process_series_for_numeric_value!(i16, needle), + DataType::Int32 => process_series_for_numeric_value!(i32, needle), + DataType::Int64 => process_series_for_numeric_value!(i64, needle), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => process_series_for_numeric_value!(i128, needle), + DataType::Float32 => process_series_for_numeric_value!(f32, needle), + DataType::Float64 => process_series_for_numeric_value!(f64, needle), + // Just use the 
general purpose index_of() function: + _ => Box::new(|subseries| { + let needle = Scalar::new(inner_dtype.clone(), needle.clone()); + index_of(subseries, needle).unwrap() + }), + }; + + ca.amortized_iter().for_each(|opt_series| { + if let Some(subseries) = opt_series { + builder + .append_option(process_series(subseries.as_ref()).map(|v| v.try_into().unwrap())); + } else { + builder.append_null(); + } + }); + Ok(builder.finish().into()) +} diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs index a93b1ed7e2b3..76f8b8414fc8 100644 --- a/crates/polars-ops/src/chunked_array/list/mod.rs +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -6,6 +6,8 @@ mod count; mod dispersion; #[cfg(feature = "hash")] pub(crate) mod hash; +#[cfg(feature = "list_index_of_in")] +mod index_of_in; mod min_max; mod namespace; #[cfg(feature = "list_sets")] @@ -18,6 +20,8 @@ mod to_struct; pub use count::*; #[cfg(not(feature = "list_count"))] use count::*; +#[cfg(feature = "list_index_of_in")] +pub use index_of_in::*; pub use namespace::*; #[cfg(feature = "list_sets")] pub use sets::*; diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs index 6c2536e263ec..29d74f64f3b4 100644 --- a/crates/polars-ops/src/series/ops/index_of.rs +++ b/crates/polars-ops/src/series/ops/index_of.rs @@ -5,7 +5,10 @@ use polars_utils::total_ord::TotalEq; use row_encode::encode_rows_unordered; /// Find the index of the value, or ``None`` if it can't be found. -fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray
, value: AR::ValueT<'a>) -> Option +pub(crate) fn index_of_value<'a, DT, AR>( + ca: &'a ChunkedArray
, + value: AR::ValueT<'a>, +) -> Option where DT: PolarsDataType, AR: StaticArray, @@ -51,14 +54,31 @@ macro_rules! try_index_of_numeric_ca { ($ca:expr, $value:expr) => {{ let ca = $ca; let value = $value; - // extract() returns None if casting failed, so consider an extract() - // failure as not finding the value. Nulls should have been handled - // earlier. + // extract() returns None if casting failed, and by this point Nulls + // have been handled, and everything should have been cast to matching + // dtype otherwise. let value = value.value().extract().unwrap(); index_of_numeric_value(ca, value) }}; } +/// Find the index of nulls within a Series. +pub(crate) fn index_of_null(series: &Series) -> Option { + let mut index = 0; + for chunk in series.chunks() { + let length = chunk.len(); + if let Some(bitmap) = chunk.validity() { + let leading_ones = bitmap.leading_ones(); + if leading_ones < length { + return Some(index + leading_ones); + } + } else { + index += length; + } + } + None +} + /// Find the index of a given value (the first and only entry in `value_series`) /// within the series. 
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> { @@ -80,19 +100,7 @@ pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> // Series is not null, and the value is null: if needle.is_null() { - let mut index = 0; - for chunk in series.chunks() { - let length = chunk.len(); - if let Some(bitmap) = chunk.validity() { - let leading_ones = bitmap.leading_ones(); - if leading_ones < length { - return Ok(Some(index + leading_ones)); - } - } else { - index += length; - } - } - return Ok(None); + return Ok(index_of_null(series)); } if series.dtype().is_primitive_numeric() { diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 22c0aaf61a2f..dead8de98374 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -98,6 +98,7 @@ dtype-struct = ["polars-core/dtype-struct"] object = ["polars-core/object"] list_gather = ["polars-ops/list_gather"] list_count = ["polars-ops/list_count"] +list_index_of_in = ["polars-ops/list_index_of_in"] array_count = ["polars-ops/array_count", "dtype-array"] trigonometry = [] sign = [] @@ -293,6 +294,7 @@ features = [ "streaming", "true_div", "sign", + "list_index_of_in", ] # defines the configuration attribute `docsrs` rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index e6b1468f4f82..7e374031b905 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -58,6 +58,8 @@ pub enum ListFunction { ToArray(usize), #[cfg(feature = "list_to_struct")] ToStruct(ListToStructArgs), + #[cfg(feature = "list_index_of_in")] + IndexOfIn, } impl ListFunction { @@ -107,6 +109,8 @@ impl ListFunction { NUnique => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "list_to_struct")] ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), + #[cfg(feature = "list_index_of_in")] + IndexOfIn => mapper.with_dtype(IDX_DTYPE), 
/// Column-level entry point for `list.index_of_in`.
///
/// `args[0]` is the list column to search in and `args[1]` holds the
/// needle(s). Fails if the first column is not a list column, and forwards
/// any error from `list_index_of_in`.
#[cfg(feature = "list_index_of_in")]
pub(super) fn index_of_in(args: &[Column]) -> PolarsResult<Column> {
    let (haystack, needles) = (&args[0], &args[1]);
    let lists = haystack.list()?;
    let found = list_index_of_in(lists, needles.as_materialized_series())?;
    Ok(Column::from(found))
}
+ pub fn index_of_in>(self, needle: N) -> Expr { + Expr::Function { + input: vec![self.0, needle.into()], + function: FunctionExpr::ListExpr(ListFunction::IndexOfIn), + options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default(), + cast_options: Some(CastingRules::FirstArgInnerLossless), + ..Default::default() + }, + } + } + #[cfg(feature = "list_sets")] fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr { Expr::Function { diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index f8b3d802c5f1..0b660d146d70 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -295,6 +295,17 @@ impl OptimizationRule for TypeCoercionRule { } } }, + CastingRules::FirstArgInnerLossless => { + if super_type.leaf_dtype().is_integer() { + for other in &input[1..] { + let other_dtype = + other.dtype(&input_schema, Context::Default, expr_arena)?; + if other_dtype.leaf_dtype().is_float() { + polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other_dtype) + } + } + } + }, } if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { @@ -310,7 +321,23 @@ impl OptimizationRule for TypeCoercionRule { _ => {}, } - for (e, dtype) in input.iter_mut().zip(dtypes) { + for (i, (e, dtype)) in input.iter_mut().zip(dtypes).enumerate() { + let new_super_type = if matches!( + casting_rules, + CastingRules::FirstArgInnerLossless + ) && (i > 0) + { + if let Some(inner_type) = super_type.inner_dtype() { + inner_type + } else { + polars_bail!( + InvalidOperation: + "FirstArgInnerLossless only makes sense for types like list or array" + ); + } + } else { + &super_type + }; match super_type { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, _) if dtype.is_string() => { @@ -319,7 +346,7 @@ impl OptimizationRule for 
TypeCoercionRule { _ => cast_expr_ir( e, &dtype, - &super_type, + new_super_type, expr_arena, CastOptions::NonStrict, )?, diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index c2b654c72834..f16590b66f15 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -210,6 +210,9 @@ pub enum CastingRules { /// whereas int to int is considered lossless. /// Overflowing is not considered in this flag, that's handled in `strict` casting FirstArgLossless, + /// Cast (in a lossless way) to the inner dtype of the first argument, + /// presumably a list or array. + FirstArgInnerLossless, Supertype(SuperTypeOptions), } diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index c5a63f830c6f..0f5cd8151835 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -176,6 +176,7 @@ new_streaming = ["polars-lazy/new_streaming"] bitwise = ["polars/bitwise"] approx_unique = ["polars/approx_unique"] string_normalize = ["polars/string_normalize"] +list_index_of_in = ["polars/list_index_of_in"] dtype-i8 = [] dtype-i16 = [] @@ -211,6 +212,7 @@ operations = [ "list_any_all", "list_drop_nulls", "list_sample", + "list_index_of_in", "cutqcut", "rle", "extract_groups", diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index b8f10fc60c3e..b39b8cf41387 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -55,6 +55,11 @@ impl PyExpr { .into() } + #[cfg(feature = "list_index_of_in")] + fn list_index_of_in(&self, value: PyExpr) -> Self { + self.inner.clone().list().index_of_in(value.inner).into() + } + fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { self.inner .clone() diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index 2ddfd3bc0289..0b519156ab48 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -236,6 +236,7 @@ 
trigonometry = ["polars-lazy?/trigonometry"] true_div = ["polars-lazy?/true_div"] unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"] zip_with = ["polars-core/zip_with"] +list_index_of_in = ["polars-ops/list_index_of_in", "polars-lazy?/list_index_of_in"] bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"] polars_cloud = ["polars-lazy?/polars_cloud"] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index b6cc0d32b77f..ead974d79d85 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -71,6 +71,7 @@ list_any_all = ["polars-python/list_any_all"] array_any_all = ["polars-python/array_any_all"] list_drop_nulls = ["polars-python/list_drop_nulls"] list_sample = ["polars-python/list_sample"] +list_index_of_in = ["polars-python/list_index_of_in"] cutqcut = ["polars-python/cutqcut"] rle = ["polars-python/rle"] extract_groups = ["polars-python/extract_groups"] diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index de98cf5869ce..84f9db6a8fb4 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1059,6 +1059,39 @@ def count_matches(self, element: IntoExpr) -> Expr: element = parse_into_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_count_matches(element)) + def index_of_in(self, needles: IntoExpr) -> Expr: + """ + For each List, return the index of the first value equal to a needle. + + Parameters + ---------- + needles + The value(s) to search for. + + Examples + -------- + >>> df = pl.DataFrame( + ... { + ... "lists": [[1, 2, 3], [], [None, 3], [5, 6, 7]], + ... "needles": [3, 0, 3, 7], + ... } + ... 
) + >>> df.select(pl.col("lists").list.index_of_in(pl.col("needles"))) + shape: (4, 1) + ┌───────┐ + │ lists │ + │ --- │ + │ u32 │ + ╞═══════╡ + │ 2 │ + │ null │ + │ 1 │ + │ 2 │ + └───────┘ + """ + element = parse_into_expression(needles, str_as_lit=True, list_as_series=False) + return wrap_expr(self._pyexpr.list_index_of_in(element)) + def to_array(self, width: int) -> Expr: """ Convert a List column into an Array column with the same inner data type. diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 72c435b157d7..dfa8f855f60f 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -1054,3 +1054,26 @@ def set_symmetric_difference(self, other: Series) -> Series: [5, 7, 8] ] """ # noqa: W505 + + def index_of_in(self, needles: IntoExpr) -> Series: + """ + For each list in the series, return the index of the first value equal to the corresponding needle. + + Parameters + ---------- + needles + The value(s) to search for. + + Examples + -------- + >>> a = pl.Series([[1, 2, 3], [], [None, 3], [5, 6, 7]]) + >>> a.list.index_of_in(pl.Series([3, 0, 3, 7])) + shape: (4,) + Series: '' [u32] + [ + 2 + null + 1 + 2 + ] + """ # noqa: W505 diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py new file mode 100644 index 000000000000..3242ad0a44e7 --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -0,0 +1,319 @@ +"""Tests for ``.list.index_of_in()``.""" + +from __future__ import annotations + +from datetime import date, datetime, time, timedelta +from decimal import Decimal +from typing import TYPE_CHECKING + +import numpy as np +import pytest + +import polars as pl +from polars.exceptions import ComputeError, InvalidOperationError +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES +from 
tests.unit.operations.test_index_of import get_expected_index + +if TYPE_CHECKING: + from polars._typing import PythonLiteral + +IdxType = pl.get_index_type() + + +def assert_index_of_in_from_scalar( + list_series: pl.Series, value: PythonLiteral | None +) -> None: + expected_indexes = [ + None if sub_series is None else get_expected_index(sub_series, value) + for sub_series in list_series + ] + + original_value = value + del value + for updated_value in (original_value, pl.lit(original_value)): + # Eager API: + assert_series_equal( + list_series.list.index_of_in(updated_value), + pl.Series(list_series.name, expected_indexes, dtype=IdxType), + ) + # Lazy API: + assert_frame_equal( + pl.LazyFrame({"lists": list_series}) + .select(pl.col("lists").list.index_of_in(updated_value)) + .collect(), + pl.DataFrame({"lists": expected_indexes}, schema={"lists": IdxType}), + ) + + +def assert_index_of_in_from_series( + list_series: pl.Series, + values: pl.Series, +) -> None: + expected_indexes = [ + None if sub_series is None else get_expected_index(sub_series, value) + for (sub_series, value) in zip(list_series, values) + ] + + # Eager API: + assert_series_equal( + list_series.list.index_of_in(values), + pl.Series(list_series.name, expected_indexes, dtype=IdxType), + ) + # Lazy API: + assert_frame_equal( + pl.LazyFrame({"lists": list_series, "values": values}) + .select(pl.col("lists").list.index_of_in(pl.col("values"))) + .collect(), + pl.DataFrame({"lists": expected_indexes}, schema={"lists": IdxType}), + ) + + +def test_index_of_in_from_scalar() -> None: + list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) + assert_index_of_in_from_scalar(list_series, 1) + + +def test_index_of_in_from_series() -> None: + list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) + values = pl.Series([1, 2, 6]) + assert_index_of_in_from_series(list_series, values) + + +@pytest.mark.parametrize("lists_dtype", INTEGER_DTYPES) +@pytest.mark.parametrize("values_dtype", INTEGER_DTYPES) +def 
test_integer(lists_dtype: pl.NumericType, values_dtype: pl.NumericType) -> None: + def to_int(expr: pl.Expr) -> int: + return pl.select(expr).item() # type: ignore[no-any-return] + + lists = [ + [51, 3], + [None, 4], + None, + [to_int(lists_dtype.max()), 3], + [6, to_int(lists_dtype.min())], + ] + lists_series = pl.Series(lists, dtype=pl.List(lists_dtype)) + chunked_series = pl.concat( + [pl.Series([[100, 7]], dtype=pl.List(lists_dtype)), lists_series], rechunk=False + ) + values: list[None | PythonLiteral] = [ + to_int(v) for v in [lists_dtype.max() - 1, lists_dtype.min() + 1] + ] + for sublist in lists: + if sublist is None: + values.append(None) + else: + values.extend(sublist) # type: ignore[arg-type] + + # Scalars: + for s in [lists_series, chunked_series]: + for value in values: + assert_index_of_in_from_scalar(s, value) + + # Series + search_series = pl.Series([3, 4, 7, None, 6], dtype=values_dtype) + assert_index_of_in_from_series(lists_series, search_series) + search_series = pl.Series([17, 3, 4, 7, None, 6], dtype=values_dtype) + assert_index_of_in_from_series(chunked_series, search_series) + + +def test_no_lossy_numeric_casts() -> None: + list_series = pl.Series([[3]], dtype=pl.List(pl.Int8())) + for will_be_lossy in [np.float32(3.1), np.float64(3.1), 50.9]: + with pytest.raises(InvalidOperationError, match="cannot cast lossless"): + list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] + + for will_be_lossy in [300, -300, pl.lit(300, dtype=pl.Int16)]: + with pytest.raises(InvalidOperationError, match="conversion from"): + list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] + + +def test_multichunk_needles() -> None: + series = pl.Series([[1, 3], [3, 2], [4, 5, 3]]) + needles = pl.concat([pl.Series([3, 1]), pl.Series([3])]) + assert series.list.index_of_in(needles).to_list() == [1, None, 2] + + +def test_mismatched_length() -> None: + """ + Mismatched lengths result in an error. 
def all_values(list_series: pl.Series) -> list[object]:
    """Collect the elements of every non-null sublist into one flat list."""
    return [
        item
        for sublist in list_series.to_list()
        if sublist is not None
        for item in sublist
    ]
+ ), + [timedelta(minutes=7)], + ), + ( + pl.Series( + [[Decimal(12), None, Decimal(3)], [Decimal(500), None, Decimal(16)]] + ), + [Decimal(4)], + ), + ( + pl.Series([[[1, 2], None], [[4, 5], [6]], None, [[None, 3, 5]], [None]]), + [[5, 7], []], + ), + ( + pl.Series( + [ + [[[1, 2], None], [[4, 5], [6]]], + [[[None, 3, 5]]], + None, + [None], + [[None]], + [[[None]]], + ] + ), + [[[5, 7]], [[]], [None]], + ), + ( + pl.Series( + [[[1, 2]], [[4, 5]], [[None, 3]], [None], None], + dtype=pl.List(pl.Array(pl.Int64(), 2)), + ), + [[5, 7]], + ), + ( + pl.Series( + [ + [{"a": 1, "b": 2}, None], + [{"a": 3, "b": 4}, {"a": None, "b": 2}], + None, + ], + dtype=pl.List(pl.Struct({"a": pl.Int64(), "b": pl.Int64()})), + ), + [{"a": 7, "b": None}, {"a": 6, "b": 4}], + ), + ( + pl.Series( + [["a", "c"], [None, "b"], ["b", "a", "a", "c"], None, [None]], + dtype=pl.List(pl.Enum(["c", "b", "a"])), + ), + [], + ), + ], +) +def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) -> None: + needles_series = pl.Series( + [ + None if sublist is None else sublist[i % len(sublist)] + for (i, sublist) in enumerate(list_series) + ], + dtype=list_series.dtype.inner, # type: ignore[attr-defined] + ) + assert_index_of_in_from_series(list_series, needles_series) + + values = all_values(list_series) + extra_values + [None] + for value in values: + assert_index_of_in_from_scalar(list_series, value) # type: ignore [arg-type] + + +@pytest.mark.xfail(reason="Depends on Series.index_of supporting Categoricals") +def test_categorical() -> None: + # When this starts passing, convert to test_other_types entry above. 
+ series = pl.Series( + [["a", "c"], [None, "b"], ["b", "a", "a", "c"], None, [None]], + dtype=pl.List(pl.Categorical), + ) + assert series.list.index_of_in("b").to_list() == [None, 1, 0, None, None] + + +def test_nulls() -> None: + series = pl.Series([[None, None], None], dtype=pl.List(pl.Null)) + assert series.list.index_of_in(None).to_list() == [0, None] + + series = pl.Series([None, [None, None]], dtype=pl.List(pl.Int64)) + assert series.list.index_of_in(None).to_list() == [None, 0] + assert series.list.index_of_in(1).to_list() == [None, None] + + +def test_wrong_type() -> None: + series = pl.Series([[1, 2, 3], [4, 5]]) + with pytest.raises( + ComputeError, + match=r"dtypes didn't match: series values have dtype i64 and needle has dtype list\[i64\]", + ): + # Searching for a list won't work: + series.list.index_of_in([1, 2]) diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index 7301d6dfd311..b725f0dff468 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -25,12 +25,14 @@ def isnan(value: object) -> bool: return np.isnan(value) # type: ignore[no-any-return] -def assert_index_of( - series: pl.Series, - value: IntoExpr, - convert_to_literal: bool = False, -) -> None: - """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" +def to_python(maybe_series: object) -> object: + if isinstance(maybe_series, pl.Series): + return [to_python(sub) for sub in maybe_series.to_list()] + else: + return maybe_series + + +def get_expected_index(series: pl.Series, value: IntoExpr) -> int | None: if isnan(value): expected_index = None for i, o in enumerate(series.to_list()): @@ -39,21 +41,27 @@ def assert_index_of( break else: try: - expected_index = series.to_list().index(value) + expected_index = to_python(series).index( # type: ignore[attr-defined] + to_python(value) + ) except ValueError: expected_index = None if 
expected_index == -1: expected_index = None + return expected_index - if convert_to_literal: - value = pl.lit(value, dtype=series.dtype) - # Eager API: - assert series.index_of(value) == expected_index - # Lazy API: - assert pl.LazyFrame({"series": series}).select( - pl.col("series").index_of(value) - ).collect().get_column("series").to_list() == [expected_index] +def assert_index_of(series: pl.Series, value: IntoExpr) -> None: + """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" + expected_index = get_expected_index(series, value) + orig_value = value + for value in (orig_value, pl.lit(orig_value, dtype=series.dtype)): + # Eager API: + assert series.index_of(value) == expected_index + # Lazy API: + assert pl.LazyFrame({"series": series}).select( + pl.col("series").index_of(value) + ).collect().get_column("series").to_list() == [expected_index] @pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) @@ -74,8 +82,7 @@ def test_float(dtype: pl.DataType) -> None: ] for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: for value in values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) for value in extra_values: # type: ignore[assignment] assert_index_of(s, value) @@ -133,11 +140,9 @@ def test_integer(dtype: pl.DataType) -> None: for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: value: IntoExpr for value in values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) for value in extra_values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) # Can't cast floats: for f in [np.float32(3.1), np.float64(3.1), 50.9]: @@ -270,8 +275,7 @@ def test_other_types( ) for s in series_variants: for value in expected_values: - assert_index_of(s, value, 
convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) # Extra values may not be expressible as literal of correct dtype, so # don't try: for value in extra_values: @@ -302,14 +306,7 @@ def test_error_on_multiple_values() -> None: pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3])) -@pytest.mark.parametrize( - "convert_to_literal", - [ - True, - False, - ], -) -def test_enum(convert_to_literal: bool) -> None: +def test_enum() -> None: series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"])) expected_values = series.to_list() for s in [ @@ -319,27 +316,16 @@ def test_enum(convert_to_literal: bool) -> None: series.sort(descending=True), ]: for value in expected_values: - assert_index_of(s, value, convert_to_literal=convert_to_literal) + assert_index_of(s, value) -@pytest.mark.parametrize( - "convert_to_literal", - [ - pytest.param( - True, - marks=pytest.mark.xfail( - reason="https://github.com/pola-rs/polars/issues/20318" - ), - ), - pytest.param( - False, - marks=pytest.mark.xfail( - reason="https://github.com/pola-rs/polars/issues/20171" - ), - ), - ], +@pytest.mark.xfail( + reason=( + "https://github.com/pola-rs/polars/issues/20318 and " + + "https://github.com/pola-rs/polars/issues/20171" + ) ) -def test_categorical(convert_to_literal: bool) -> None: +def test_categorical() -> None: series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical) expected_values = series.to_list() for s in [ @@ -349,4 +335,4 @@ def test_categorical(convert_to_literal: bool) -> None: series.sort(descending=True), ]: for value in expected_values: - assert_index_of(s, value, convert_to_literal=convert_to_literal) + assert_index_of(s, value)