From 1a09f7b2be7ec1bbe08fea52cc08babb108e6f67 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 15 Jan 2025 10:34:34 -0500 Subject: [PATCH 01/30] Add the boilerplate for a new expression method `list.index_of_in()`. --- crates/polars-lazy/Cargo.toml | 2 ++ crates/polars-ops/Cargo.toml | 1 + .../src/chunked_array/list/index_of_in.rs | 5 +++++ .../polars-ops/src/chunked_array/list/mod.rs | 4 ++++ crates/polars-plan/Cargo.toml | 2 ++ .../polars-plan/src/dsl/function_expr/list.rs | 17 ++++++++++++++++ crates/polars-plan/src/dsl/list.rs | 6 ++++++ crates/polars-python/Cargo.toml | 2 ++ crates/polars-python/src/expr/list.rs | 5 +++++ crates/polars/Cargo.toml | 1 + py-polars/Cargo.toml | 1 + py-polars/polars/expr/list.py | 16 +++++++++++++++ .../namespaces/list/test_index_of_in.py | 20 +++++++++++++++++++ 13 files changed, 82 insertions(+) create mode 100644 crates/polars-ops/src/chunked_array/list/index_of_in.rs create mode 100644 py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py diff --git a/crates/polars-lazy/Cargo.toml b/crates/polars-lazy/Cargo.toml index ed172a178cb1..17c82bd1fd66 100644 --- a/crates/polars-lazy/Cargo.toml +++ b/crates/polars-lazy/Cargo.toml @@ -172,6 +172,7 @@ sign = ["polars-plan/sign"] timezones = ["polars-plan/timezones"] list_gather = ["polars-ops/list_gather", "polars-plan/list_gather"] list_count = ["polars-ops/list_count", "polars-plan/list_count"] +list_index_of_in = ["polars-ops/list_index_of_in", "polars-plan/list_index_of_in"] array_count = ["polars-ops/array_count", "polars-plan/array_count", "dtype-array"] true_div = ["polars-plan/true_div"] extract_jsonpath = ["polars-plan/extract_jsonpath", "polars-ops/extract_jsonpath"] @@ -377,6 +378,7 @@ features = [ "list_drop_nulls", "list_eval", "list_gather", + "list_index_of_in", "list_sample", "list_sets", "list_to_struct", diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index aca43eec83ad..16460f588b54 100644 --- 
a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -145,3 +145,4 @@ abs = [] cov = [] gather = [] replace = ["is_in"] +list_index_of_in = [] diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs new file mode 100644 index 000000000000..902fb48bfd1b --- /dev/null +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -0,0 +1,5 @@ +use super::*; + +pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { + todo!("Implement me"); +} diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs index a93b1ed7e2b3..f6b40ecf98e0 100644 --- a/crates/polars-ops/src/chunked_array/list/mod.rs +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -13,6 +13,8 @@ mod sets; mod sum_mean; #[cfg(feature = "list_to_struct")] mod to_struct; +#[cfg(feature = "list_index_of_in")] +mod index_of_in; #[cfg(feature = "list_count")] pub use count::*; @@ -23,6 +25,8 @@ pub use namespace::*; pub use sets::*; #[cfg(feature = "list_to_struct")] pub use to_struct::*; +#[cfg(feature = "list_index_of_in")] +pub use index_of_in::*; pub trait AsList { fn as_list(&self) -> &ListChunked; diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index 8f93faeebae8..2962122e5311 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -98,6 +98,7 @@ dtype-struct = ["polars-core/dtype-struct"] object = ["polars-core/object"] list_gather = ["polars-ops/list_gather"] list_count = ["polars-ops/list_count"] +list_index_of_in = ["polars-ops/list_index_of_in"] array_count = ["polars-ops/array_count", "dtype-array"] trigonometry = [] sign = [] @@ -295,6 +296,7 @@ features = [ "streaming", "true_div", "sign", + "list_index_of_in", ] # defines the configuration attribute `docsrs` rustdoc-args = ["--cfg", "docsrs"] diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs 
b/crates/polars-plan/src/dsl/function_expr/list.rs index e6b1468f4f82..49e96d4cc6c9 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -58,6 +58,8 @@ pub enum ListFunction { ToArray(usize), #[cfg(feature = "list_to_struct")] ToStruct(ListToStructArgs), + #[cfg(feature = "list_index_of_in")] + IndexOfIn, } impl ListFunction { @@ -107,6 +109,8 @@ impl ListFunction { NUnique => mapper.with_dtype(IDX_DTYPE), #[cfg(feature = "list_to_struct")] ToStruct(args) => mapper.try_map_dtype(|x| args.get_output_dtype(x)), + #[cfg(feature = "list_index_of_in")] + IndexOfIn => mapper.with_dtype(IDX_DTYPE), } } } @@ -180,6 +184,8 @@ impl Display for ListFunction { ToArray(_) => "to_array", #[cfg(feature = "list_to_struct")] ToStruct(_) => "to_struct", + #[cfg(feature = "list_index_of_in")] + IndexOfIn => "index_of_in", }; write!(f, "list.{name}") } @@ -243,6 +249,8 @@ impl From for SpecialEq> { NUnique => map!(n_unique), #[cfg(feature = "list_to_struct")] ToStruct(args) => map!(to_struct, &args), + #[cfg(feature = "list_index_of_in")] + IndexOfIn => map_as_slice!(index_of_in), } } } @@ -547,6 +555,15 @@ pub(super) fn count_matches(args: &[Column]) -> PolarsResult { list_count_matches(ca, element.get(0).unwrap()).map(Column::from) } +#[cfg(feature = "list_index_of_in")] +pub(super) fn index_of_in(args: &[Column]) -> PolarsResult { + let s = &args[0]; + let needles = &args[1]; + let ca = s.list()?; + todo!("Implement me"); + //list_count_matches(ca, needles).map(Column::from) +} + pub(super) fn sum(s: &Column) -> PolarsResult { s.list()?.lst_sum().map(Column::from) } diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index d5c2622b5afb..087609c6c299 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -322,6 +322,12 @@ impl ListNameSpace { ) } + #[cfg(feature = "list_index_of_in")] + /// Find the index of needle in the list. 
+ pub fn index_of_in>(self, needle: N) -> Expr { + todo!("Implement me"); + } + #[cfg(feature = "list_sets")] fn set_operation(self, other: Expr, set_operation: SetOperation) -> Expr { Expr::Function { diff --git a/crates/polars-python/Cargo.toml b/crates/polars-python/Cargo.toml index 14c918b44097..fc5358b176c6 100644 --- a/crates/polars-python/Cargo.toml +++ b/crates/polars-python/Cargo.toml @@ -172,6 +172,7 @@ new_streaming = ["polars-lazy/new_streaming"] bitwise = ["polars/bitwise"] approx_unique = ["polars/approx_unique"] string_normalize = ["polars/string_normalize"] +list_index_of_in = ["polars/list_index_of_in"] dtype-i8 = [] dtype-i16 = [] @@ -207,6 +208,7 @@ operations = [ "list_any_all", "list_drop_nulls", "list_sample", + "list_index_of_in", "cutqcut", "rle", "extract_groups", diff --git a/crates/polars-python/src/expr/list.rs b/crates/polars-python/src/expr/list.rs index b8f10fc60c3e..b39b8cf41387 100644 --- a/crates/polars-python/src/expr/list.rs +++ b/crates/polars-python/src/expr/list.rs @@ -55,6 +55,11 @@ impl PyExpr { .into() } + #[cfg(feature = "list_index_of_in")] + fn list_index_of_in(&self, value: PyExpr) -> Self { + self.inner.clone().list().index_of_in(value.inner).into() + } + fn list_join(&self, separator: PyExpr, ignore_nulls: bool) -> Self { self.inner .clone() diff --git a/crates/polars/Cargo.toml b/crates/polars/Cargo.toml index d256582503cf..38dd02a22e27 100644 --- a/crates/polars/Cargo.toml +++ b/crates/polars/Cargo.toml @@ -235,6 +235,7 @@ trigonometry = ["polars-lazy?/trigonometry"] true_div = ["polars-lazy?/true_div"] unique_counts = ["polars-ops/unique_counts", "polars-lazy?/unique_counts"] zip_with = ["polars-core/zip_with"] +list_index_of_in = ["polars-ops/list_index_of_in", "polars-lazy?/list_index_of_in"] bigidx = ["polars-core/bigidx", "polars-lazy?/bigidx", "polars-ops/big_idx", "polars-utils/bigidx"] polars_cloud = ["polars-lazy?/polars_cloud"] diff --git a/py-polars/Cargo.toml b/py-polars/Cargo.toml index 
f09fe3952f6f..72e3f67f6bbe 100644 --- a/py-polars/Cargo.toml +++ b/py-polars/Cargo.toml @@ -70,6 +70,7 @@ list_any_all = ["polars-python/list_any_all"] array_any_all = ["polars-python/array_any_all"] list_drop_nulls = ["polars-python/list_drop_nulls"] list_sample = ["polars-python/list_sample"] +list_index_of_in = ["polars-python/list_index_of_in"] cutqcut = ["polars-python/cutqcut"] rle = ["polars-python/rle"] extract_groups = ["polars-python/extract_groups"] diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index de98cf5869ce..26d1d75b02cb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1059,6 +1059,22 @@ def count_matches(self, element: IntoExpr) -> Expr: element = parse_into_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_count_matches(element)) + def index_of_in(self, element: IntoExpr) -> Expr: + """ + TODO + + Parameters + ---------- + needles + TODO + + Examples + -------- + TODO + """ + element = parse_into_expression(element, str_as_lit=True) + return wrap_expr(self._pyexpr.list_index_of_in(element)) + def to_array(self, width: int) -> Expr: """ Convert a List column into an Array column with the same inner data type. 
diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py new file mode 100644 index 000000000000..0cb406507ebf --- /dev/null +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -0,0 +1,20 @@ +"""Tests for ``.list.index_of_in()``.""" + +import polars as pl +from polars.testing import assert_frame_equal + + +def test_index_of_in_from_constant() -> None: + df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]}) + assert_frame_equal( + df.select(pl.col("lists").list.index_of_in(1)), + pl.DataFrame({"lists": [1, None, 2]}), + ) + + +def test_index_of_in_from_column() -> None: + df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]}) + assert_frame_equal( + df.select(pl.col("lists").list.index_of_in(pl.col("values"))), + pl.DataFrame({"lists": [1, None, 2]}), + ) From 30838035b1949d4540ec2f11d6a371893b4f14f0 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 15 Jan 2025 12:11:24 -0500 Subject: [PATCH 02/30] Sketch of implementation, initial tests pass. --- .../src/chunked_array/list/index_of_in.rs | 46 ++++++++++++++++++- .../polars-plan/src/dsl/function_expr/list.rs | 3 +- crates/polars-plan/src/dsl/list.rs | 9 +++- .../namespaces/list/test_index_of_in.py | 4 +- 4 files changed, 56 insertions(+), 6 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 902fb48bfd1b..a51e7764ed94 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -1,5 +1,49 @@ use super::*; +use crate::series::index_of; pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { - todo!("Implement me"); + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + if needles.len() == 1 { + // For some reason we need to do casting ourselves. 
+ let needle = needles.get(0).unwrap(); + let cast_needle = needle.cast(ca.dtype().inner_dtype().unwrap()); + if cast_needle != needle { + todo!("nicer error handling"); + } + let needle = Scalar::new( + cast_needle.dtype().clone(), + cast_needle.into_static(), + ); + ca.amortized_iter().for_each(|opt_series| { + if let Some(subseries) = opt_series { + // TODO justify why unwrap()s are ok + builder.append_option( + // TODO clone() sucks, maybe need to change the API for index_of? + index_of(subseries.as_ref(), needle.clone()) + .unwrap() + .map(|v| v.try_into().unwrap()), + ); + } else { + builder.append_null(); + } + }); + } else { + ca.amortized_iter() + .zip(needles.iter()) + .for_each(|(opt_series, needle)| { + match (opt_series, needle) { + (None, _) => builder.append_null(), + (Some(subseries), needle) => { + let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); + // TODO justify why unwrap()s are ok + builder.append_option( + index_of(subseries.as_ref(), needle) + .unwrap() + .map(|v| v.try_into().unwrap()), + ); + }, + } + }); + } + Ok(builder.finish().into()) } diff --git a/crates/polars-plan/src/dsl/function_expr/list.rs b/crates/polars-plan/src/dsl/function_expr/list.rs index 49e96d4cc6c9..7e374031b905 100644 --- a/crates/polars-plan/src/dsl/function_expr/list.rs +++ b/crates/polars-plan/src/dsl/function_expr/list.rs @@ -560,8 +560,7 @@ pub(super) fn index_of_in(args: &[Column]) -> PolarsResult { let s = &args[0]; let needles = &args[1]; let ca = s.list()?; - todo!("Implement me"); - //list_count_matches(ca, needles).map(Column::from) + list_index_of_in(ca, needles.as_materialized_series()).map(Column::from) } pub(super) fn sum(s: &Column) -> PolarsResult { diff --git a/crates/polars-plan/src/dsl/list.rs b/crates/polars-plan/src/dsl/list.rs index 087609c6c299..9f75089a6478 100644 --- a/crates/polars-plan/src/dsl/list.rs +++ b/crates/polars-plan/src/dsl/list.rs @@ -325,7 +325,14 @@ impl ListNameSpace { #[cfg(feature = 
"list_index_of_in")] /// Find the index of needle in the list. pub fn index_of_in>(self, needle: N) -> Expr { - todo!("Implement me"); + let other = needle.into(); + + self.0.map_many_private( + FunctionExpr::ListExpr(ListFunction::IndexOfIn), + &[other], + false, + None, + ) } #[cfg(feature = "list_sets")] diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 0cb406507ebf..2ff5db4e25ed 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -8,7 +8,7 @@ def test_index_of_in_from_constant() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]}) assert_frame_equal( df.select(pl.col("lists").list.index_of_in(1)), - pl.DataFrame({"lists": [1, None, 2]}), + pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}), ) @@ -16,5 +16,5 @@ def test_index_of_in_from_column() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]}) assert_frame_equal( df.select(pl.col("lists").list.index_of_in(pl.col("values"))), - pl.DataFrame({"lists": [1, None, 2]}), + pl.DataFrame({"lists": [1, 0, None]}, schema={"lists": pl.get_index_type()}), ) From 9fc4d5fdbb7cd26d9254d46a44b8d8b4a09b5c08 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 15 Jan 2025 12:25:01 -0500 Subject: [PATCH 03/30] More explanations --- .../polars-ops/src/chunked_array/list/index_of_in.rs | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index a51e7764ed94..55b4e447c04f 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -16,9 +16,10 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult 
instead of a Scalar + // which implies AnyValue<'static>? index_of(subseries.as_ref(), needle.clone()) .unwrap() .map(|v| v.try_into().unwrap()), @@ -29,13 +30,17 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult builder.append_null(), (Some(subseries), needle) => { let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); - // TODO justify why unwrap()s are ok builder.append_option( index_of(subseries.as_ref(), needle) .unwrap() From 85b26ffa56f70d144fbc759156db364e63687a60 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 28 Jan 2025 13:12:37 -0500 Subject: [PATCH 04/30] Alternative implementation in Python. --- py-polars/polars/expr/list.py | 14 ++++++++++++++ .../operations/namespaces/list/test_index_of_in.py | 4 ++-- 2 files changed, 16 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 26d1d75b02cb..f6445ea99343 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -6,6 +6,8 @@ from typing import TYPE_CHECKING, Any, Callable import polars._reexport as pl +import polars as pl_full + from polars import functions as F from polars._utils.parse import parse_into_expression from polars._utils.various import find_stacklevel @@ -1075,6 +1077,18 @@ def index_of_in(self, element: IntoExpr) -> Expr: element = parse_into_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_index_of_in(element)) + def index_of_in_py(self, element: IntoExpr) -> Expr: + element = wrap_expr(parse_into_expression(element, str_as_lit=True)) + return ( + self.concat(element) + .list.eval( + pl_full.element() + .slice(0, pl_full.element().len() - 1) + .index_of(pl_full.element().last()) + ) + .list.first() + ) + def to_array(self, width: int) -> Expr: """ Convert a List column into an Array column with the same inner data type. 
diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 2ff5db4e25ed..6a6569df9f42 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -7,7 +7,7 @@ def test_index_of_in_from_constant() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]}) assert_frame_equal( - df.select(pl.col("lists").list.index_of_in(1)), + df.select(pl.col("lists").list.index_of_in_py(1)), pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}), ) @@ -15,6 +15,6 @@ def test_index_of_in_from_constant() -> None: def test_index_of_in_from_column() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]}) assert_frame_equal( - df.select(pl.col("lists").list.index_of_in(pl.col("values"))), + df.select(pl.col("lists").list.index_of_in_py(pl.col("values"))), pl.DataFrame({"lists": [1, 0, None]}, schema={"lists": pl.get_index_type()}), ) From 2365f46dabb3e4e08ec06cdcba2f8046156977c3 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 3 Feb 2025 12:53:26 -0500 Subject: [PATCH 05/30] Remove much slower Python version. 
--- py-polars/polars/expr/list.py | 12 ------------ .../operations/namespaces/list/test_index_of_in.py | 4 ++-- 2 files changed, 2 insertions(+), 14 deletions(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index f6445ea99343..b0382a8b396d 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1077,18 +1077,6 @@ def index_of_in(self, element: IntoExpr) -> Expr: element = parse_into_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_index_of_in(element)) - def index_of_in_py(self, element: IntoExpr) -> Expr: - element = wrap_expr(parse_into_expression(element, str_as_lit=True)) - return ( - self.concat(element) - .list.eval( - pl_full.element() - .slice(0, pl_full.element().len() - 1) - .index_of(pl_full.element().last()) - ) - .list.first() - ) - def to_array(self, width: int) -> Expr: """ Convert a List column into an Array column with the same inner data type. diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 6a6569df9f42..2ff5db4e25ed 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -7,7 +7,7 @@ def test_index_of_in_from_constant() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]}) assert_frame_equal( - df.select(pl.col("lists").list.index_of_in_py(1)), + df.select(pl.col("lists").list.index_of_in(1)), pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}), ) @@ -15,6 +15,6 @@ def test_index_of_in_from_constant() -> None: def test_index_of_in_from_column() -> None: df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]}) assert_frame_equal( - df.select(pl.col("lists").list.index_of_in_py(pl.col("values"))), + df.select(pl.col("lists").list.index_of_in(pl.col("values"))), pl.DataFrame({"lists": [1, 0, None]}, 
schema={"lists": pl.get_index_type()}), ) From 6e4b0059076badae4ecc1927468b3510db2997bc Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 3 Feb 2025 13:23:38 -0500 Subject: [PATCH 06/30] Always check both literal and non-literal. --- .../tests/unit/operations/test_index_of.py | 71 ++++++------------- 1 file changed, 22 insertions(+), 49 deletions(-) diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index 7301d6dfd311..bd09959d88a4 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -25,11 +25,7 @@ def isnan(value: object) -> bool: return np.isnan(value) # type: ignore[no-any-return] -def assert_index_of( - series: pl.Series, - value: IntoExpr, - convert_to_literal: bool = False, -) -> None: +def assert_index_of(series: pl.Series, value: IntoExpr) -> None: """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" if isnan(value): expected_index = None @@ -45,15 +41,14 @@ def assert_index_of( if expected_index == -1: expected_index = None - if convert_to_literal: - value = pl.lit(value, dtype=series.dtype) - - # Eager API: - assert series.index_of(value) == expected_index - # Lazy API: - assert pl.LazyFrame({"series": series}).select( - pl.col("series").index_of(value) - ).collect().get_column("series").to_list() == [expected_index] + orig_value = value + for value in (orig_value, pl.lit(orig_value, dtype=series.dtype)): + # Eager API: + assert series.index_of(value) == expected_index + # Lazy API: + assert pl.LazyFrame({"series": series}).select( + pl.col("series").index_of(value) + ).collect().get_column("series").to_list() == [expected_index] @pytest.mark.parametrize("dtype", [pl.Float32, pl.Float64]) @@ -74,8 +69,7 @@ def test_float(dtype: pl.DataType) -> None: ] for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: for value in values: - assert_index_of(s, value, 
convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) for value in extra_values: # type: ignore[assignment] assert_index_of(s, value) @@ -133,11 +127,9 @@ def test_integer(dtype: pl.DataType) -> None: for s in [series, sorted_series_asc, sorted_series_desc, chunked_series]: value: IntoExpr for value in values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) for value in extra_values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) # Can't cast floats: for f in [np.float32(3.1), np.float64(3.1), 50.9]: @@ -270,8 +262,7 @@ def test_other_types( ) for s in series_variants: for value in expected_values: - assert_index_of(s, value, convert_to_literal=True) - assert_index_of(s, value, convert_to_literal=False) + assert_index_of(s, value) # Extra values may not be expressible as literal of correct dtype, so # don't try: for value in extra_values: @@ -302,14 +293,7 @@ def test_error_on_multiple_values() -> None: pl.Series("a", [1, 2, 3]).index_of(pl.Series([2, 3])) -@pytest.mark.parametrize( - "convert_to_literal", - [ - True, - False, - ], -) -def test_enum(convert_to_literal: bool) -> None: +def test_enum() -> None: series = pl.Series(["a", "c", None, "b"], dtype=pl.Enum(["c", "b", "a"])) expected_values = series.to_list() for s in [ @@ -319,27 +303,16 @@ def test_enum(convert_to_literal: bool) -> None: series.sort(descending=True), ]: for value in expected_values: - assert_index_of(s, value, convert_to_literal=convert_to_literal) + assert_index_of(s, value) -@pytest.mark.parametrize( - "convert_to_literal", - [ - pytest.param( - True, - marks=pytest.mark.xfail( - reason="https://github.com/pola-rs/polars/issues/20318" - ), - ), - pytest.param( - False, - marks=pytest.mark.xfail( - reason="https://github.com/pola-rs/polars/issues/20171" - ), - ), - 
], +@pytest.mark.xfail( + reason=( + "https://github.com/pola-rs/polars/issues/20318 and " + + "https://github.com/pola-rs/polars/issues/20171" + ) ) -def test_categorical(convert_to_literal: bool) -> None: +def test_categorical() -> None: series = pl.Series(["a", "c", None, "b"], dtype=pl.Categorical) expected_values = series.to_list() for s in [ @@ -349,4 +322,4 @@ def test_categorical(convert_to_literal: bool) -> None: series.sort(descending=True), ]: for value in expected_values: - assert_index_of(s, value, convert_to_literal=convert_to_literal) + assert_index_of(s, value) From 372ba35ad92a443b93066ae7a3f81af9f6f84089 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 3 Feb 2025 14:24:23 -0500 Subject: [PATCH 07/30] Start making more general tests, make eager API work. --- py-polars/polars/expr/list.py | 2 - py-polars/polars/series/list.py | 5 ++ .../namespaces/list/test_index_of_in.py | 85 ++++++++++++++++--- .../tests/unit/operations/test_index_of.py | 9 +- 4 files changed, 86 insertions(+), 15 deletions(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index b0382a8b396d..26d1d75b02cb 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -6,8 +6,6 @@ from typing import TYPE_CHECKING, Any, Callable import polars._reexport as pl -import polars as pl_full - from polars import functions as F from polars._utils.parse import parse_into_expression from polars._utils.various import find_stacklevel diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 72c435b157d7..54bcd6763c5e 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -1054,3 +1054,8 @@ def set_symmetric_difference(self, other: Series) -> Series: [5, 7, 8] ] """ # noqa: W505 + + def index_of_in(self, element: IntoExpr) -> Series: + """ + TODO + """ # noqa: W505 diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py 
b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 2ff5db4e25ed..745b87e9d9ea 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -1,20 +1,83 @@ """Tests for ``.list.index_of_in()``.""" +from __future__ import annotations + +from typing import TYPE_CHECKING + import polars as pl -from polars.testing import assert_frame_equal +from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.operations.test_index_of import get_expected_index -def test_index_of_in_from_constant() -> None: - df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]]}) - assert_frame_equal( - df.select(pl.col("lists").list.index_of_in(1)), - pl.DataFrame({"lists": [1, None, 2]}, schema={"lists": pl.get_index_type()}), - ) +if TYPE_CHECKING: + from polars._typing import IntoExpr, PythonLiteral + +IdxType = pl.get_index_type() + + +def assert_index_of_in_from_scalar( + list_series: pl.Series, value: PythonLiteral +) -> None: + expected_indexes = [ + get_expected_index(sub_series, value) for sub_series in list_series + ] + + original_value = value + del value + for updated_value in (original_value, pl.lit(original_value)): + # Eager API: + assert_series_equal( + list_series.list.index_of_in(updated_value), + pl.Series(list_series.name, expected_indexes, dtype=IdxType), + ) + # Lazy API: + assert_frame_equal( + pl.LazyFrame({"lists": list_series}) + .select(pl.col("lists").list.index_of_in(updated_value)) + .collect(), + pl.DataFrame({"lists": expected_indexes}, schema={"lists": IdxType}), + ) -def test_index_of_in_from_column() -> None: - df = pl.DataFrame({"lists": [[3, 1], [2, 4], [5, 3, 1]], "values": [1, 2, 6]}) +def assert_index_of_in_from_series( + list_series: pl.Series, + values: pl.Series, +) -> None: + expected_indexes = [ + get_expected_index(sub_series, value) + for (sub_series, value) in zip(list_series, values) + ] + + # Eager API: 
+ assert_series_equal( + list_series.list.index_of_in(values), + pl.Series(list_series.name, expected_indexes, dtype=IdxType), + ) + # Lazy API: assert_frame_equal( - df.select(pl.col("lists").list.index_of_in(pl.col("values"))), - pl.DataFrame({"lists": [1, 0, None]}, schema={"lists": pl.get_index_type()}), + pl.LazyFrame({"lists": list_series, "values": values}) + .select(pl.col("lists").list.index_of_in(pl.col("values"))) + .collect(), + pl.DataFrame({"lists": expected_indexes}, schema={"lists": IdxType}), ) + + +# Testing plan: +# - All integers +# - Both floats, with nans +# - Strings +# - datetime, date, time, timedelta +# - nested lists +# - something with hypothesis +# - error case: non-matching lengths + + +def test_index_of_in_from_scalar() -> None: + list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) + assert_index_of_in_from_scalar(list_series, 1) + + +def test_index_of_in_from_series() -> None: + list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) + values = pl.Series([1, 2, 6]) + assert_index_of_in_from_series(list_series, values) diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index bd09959d88a4..7b2bc88eb210 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -2,6 +2,7 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal +from os import get_exec_path from typing import TYPE_CHECKING, Any import numpy as np @@ -25,8 +26,7 @@ def isnan(value: object) -> bool: return np.isnan(value) # type: ignore[no-any-return] -def assert_index_of(series: pl.Series, value: IntoExpr) -> None: - """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" +def get_expected_index(series: pl.Series, value: IntoExpr) -> int | None: if isnan(value): expected_index = None for i, o in enumerate(series.to_list()): @@ -40,7 +40,12 @@ def assert_index_of(series: pl.Series, value: IntoExpr) -> 
None: expected_index = None if expected_index == -1: expected_index = None + return expected_index + +def assert_index_of(series: pl.Series, value: IntoExpr) -> None: + """``Series.index_of()`` returns the index, or ``None`` if it can't be found.""" + expected_index = get_expected_index(series, value) orig_value = value for value in (orig_value, pl.lit(orig_value, dtype=series.dtype)): # Eager API: From 255ccc94280a5cbeb9487765a6e188844bc04925 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 4 Feb 2025 15:46:25 -0500 Subject: [PATCH 08/30] Integer tests pass. --- .../src/chunked_array/list/index_of_in.rs | 58 ++++++++++------ .../namespaces/list/test_index_of_in.py | 66 ++++++++++++++++++- .../tests/unit/operations/test_index_of.py | 1 - 3 files changed, 101 insertions(+), 24 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 55b4e447c04f..ee95a1cde1c4 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -1,19 +1,38 @@ +use polars_core::chunked_array::cast::CastOptions; +use polars_utils::total_ord::TotalEq; + use super::*; use crate::series::index_of; +fn check_if_cast_lossless(dtype1: &DataType, dtype2: &DataType, result: bool) -> PolarsResult<()> { + polars_ensure!( + result, + InvalidOperation: "cannot cast lossless between {} and {}", + dtype1, dtype2, + ); + Ok(()) +} + pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + let inner_dtype = ca.dtype().inner_dtype().unwrap(); + let needle_dtype = needles.dtype(); + // We need to do casting ourselves, unless we grow a new CastingRules + // variant. 
+ check_if_cast_lossless( + &needle_dtype, + inner_dtype, + inner_dtype.leaf_dtype().is_float() == needle_dtype.is_float(), + )?; if needles.len() == 1 { - // For some reason we need to do casting ourselves. let needle = needles.get(0).unwrap(); - let cast_needle = needle.cast(ca.dtype().inner_dtype().unwrap()); - if cast_needle != needle { - todo!("nicer error handling"); + let cast_needle = needle.cast(inner_dtype); + check_if_cast_lossless(&needle_dtype, inner_dtype, cast_needle.tot_eq(&needle))?; + let mut needle_dtype = cast_needle.dtype().clone(); + if needle_dtype.is_null() { + needle_dtype = inner_dtype.clone(); } - let needle = Scalar::new( - cast_needle.dtype().clone(), - cast_needle.into_static(), - ); + let needle = Scalar::new(needle_dtype, cast_needle.into_static()); ca.amortized_iter().for_each(|opt_series| { if let Some(subseries) = opt_series { builder.append_option( @@ -29,6 +48,7 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult PolarsResult builder.append_null(), - (Some(subseries), needle) => { - let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); - builder.append_option( - index_of(subseries.as_ref(), needle) - .unwrap() - .map(|v| v.try_into().unwrap()), - ); - }, - } + .for_each(|(opt_series, needle)| match (opt_series, needle) { + (None, _) => builder.append_null(), + (Some(subseries), needle) => { + let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); + builder.append_option( + index_of(subseries.as_ref(), needle) + .unwrap() + .map(|v| v.try_into().unwrap()), + ); + }, }); } Ok(builder.finish().into()) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 745b87e9d9ea..874d1fc2b852 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -4,9 +4,14 @@ from typing import 
TYPE_CHECKING +import numpy as np +import pytest + import polars as pl +from polars.exceptions import InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal +from tests.unit.conftest import INTEGER_DTYPES from tests.unit.operations.test_index_of import get_expected_index if TYPE_CHECKING: @@ -19,7 +24,8 @@ def assert_index_of_in_from_scalar( list_series: pl.Series, value: PythonLiteral ) -> None: expected_indexes = [ - get_expected_index(sub_series, value) for sub_series in list_series + None if sub_series is None else get_expected_index(sub_series, value) + for sub_series in list_series ] original_value = value @@ -44,7 +50,7 @@ def assert_index_of_in_from_series( values: pl.Series, ) -> None: expected_indexes = [ - get_expected_index(sub_series, value) + None if sub_series is None else get_expected_index(sub_series, value) for (sub_series, value) in zip(list_series, values) ] @@ -63,7 +69,7 @@ def assert_index_of_in_from_series( # Testing plan: -# - All integers +# D All integers # - Both floats, with nans # - Strings # - datetime, date, time, timedelta @@ -81,3 +87,57 @@ def test_index_of_in_from_series() -> None: list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) values = pl.Series([1, 2, 6]) assert_index_of_in_from_series(list_series, values) + + +def to_int(expr: pl.Expr) -> int: + return pl.select(expr).item() + + +@pytest.mark.parametrize("lists_dtype", INTEGER_DTYPES) +@pytest.mark.parametrize("values_dtype", INTEGER_DTYPES) +def test_integer(lists_dtype: pl.DataType, values_dtype: pl.DataType) -> None: + lists = [ + [51, 3], + [None, 4], + None, + [to_int(lists_dtype.max()), 3], # type: ignore[attr-defined] + [6, to_int(lists_dtype.min())], # type: ignore[attr-defined] + ] + lists_series = pl.Series(lists, dtype=pl.List(lists_dtype)) + chunked_series = pl.concat( + [pl.Series([[100, 7]], dtype=pl.List(lists_dtype)), lists_series], rechunk=False + ) + values = [ + to_int(v) for v in [lists_dtype.max() - 1, 
lists_dtype.min() + 1] + ] # type: ignore[attr-defined] + for sublist in lists: + if sublist is None: + values.append(None) + else: + values.extend(sublist) + + # Scalars: + for s in [lists_series, chunked_series]: + value: IntoExpr + for value in values: + assert_index_of_in_from_scalar(s, value) + + # Series + search_series = pl.Series([3, 4, 7, None, 6], dtype=values_dtype) + assert_index_of_in_from_series(lists_series, search_series) + search_series = pl.Series([17, 3, 4, 7, None, 6], dtype=values_dtype) + assert_index_of_in_from_series(chunked_series, search_series) + + +def test_no_lossy_numeric_casts() -> None: + list_series = pl.Series([[3]], dtype=pl.List(pl.Int8())) + for will_be_lossy in [ + np.float32(3.1), + np.float64(3.1), + 50.9, + 300, + -300, + pl.lit(300, dtype=pl.Int16), + ]: + with pytest.raises(InvalidOperationError, match="cannot cast lossless"): + list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index 7b2bc88eb210..c5e3a5803703 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -2,7 +2,6 @@ from datetime import date, datetime, time, timedelta from decimal import Decimal -from os import get_exec_path from typing import TYPE_CHECKING, Any import numpy as np From 1a412573dfb8c0ee7c7bd8ac98325610aaaacad9 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 4 Feb 2025 16:54:38 -0500 Subject: [PATCH 09/30] Support and tests for floats. 
--- .../src/chunked_array/list/index_of_in.rs | 2 +- .../namespaces/list/test_index_of_in.py | 22 ++++++++++++++++++- 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index ee95a1cde1c4..cabe6f5098fa 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -22,7 +22,7 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult None: @@ -141,3 +142,22 @@ def test_no_lossy_numeric_casts() -> None: ]: with pytest.raises(InvalidOperationError, match="cannot cast lossless"): list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] + + +@pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) +def test_float(float_dtype: pl.DataType) -> None: + lists = [ + [1.5, np.nan, np.inf], + [3.0, None, -np.inf], + [0.0, -0.0, -np.nan], + ] + lists_series = pl.Series(lists, dtype=pl.List(float_dtype)) + + # Scalar + for value in sum(lists, []) + [3.5, np.float64(1.5), np.float32(3.0)]: + assert_index_of_in_from_scalar(lists_series, value) + + # Series + assert_index_of_in_from_series( + lists_series, pl.Series([1.5, -np.inf, -np.nan], dtype=float_dtype) + ) From 61e3d1c4f5ac93faba89f94b90d2404424a50f1f Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 5 Feb 2025 08:26:37 -0500 Subject: [PATCH 10/30] More test coverage. 
--- .../namespaces/list/test_index_of_in.py | 65 ++++++++++++++++++- 1 file changed, 64 insertions(+), 1 deletion(-) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 67e517a3aaba..97ced24016bd 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -2,6 +2,8 @@ from __future__ import annotations +from datetime import date, datetime, time, timedelta +from decimal import Decimal from typing import TYPE_CHECKING import numpy as np @@ -70,7 +72,7 @@ def assert_index_of_in_from_series( # Testing plan: # D All integers -# - Both floats, with nans +# D Both floats, with nans # - Strings # - datetime, date, time, timedelta # - nested lists @@ -161,3 +163,64 @@ def test_float(float_dtype: pl.DataType) -> None: assert_index_of_in_from_series( lists_series, pl.Series([1.5, -np.inf, -np.nan], dtype=float_dtype) ) + + +@pytest.mark.parametrize( + ("list_series", "extra_values"), + [ + (pl.Series([["abc", "def"], ["ghi", "zzz", "X"], ["Y"]]), ["foo"]), + (pl.Series([[b"abc", b"def"], [b"ghi", b"zzz", b"X"], [b"Y"]]), [b"foo"]), + (pl.Series([[True, None, False], [True, False]]), []), + ( + pl.Series( + [ + [datetime(1997, 12, 31), datetime(1996, 1, 1)], + [datetime(1997, 12, 30), datetime(1996, 1, 2)], + ] + ), + [datetime(2003, 1, 1)], + ), + ( + pl.Series( + [ + [date(1997, 12, 31), date(1996, 1, 1)], + [date(1997, 12, 30), date(1996, 1, 2)], + ] + ), + [date(2003, 1, 1)], + ), + ( + pl.Series( + [ + [time(16, 12, 31), None, time(11, 10, 53)], + [time(16, 11, 31), time(11, 10, 54)], + ] + ), + [time(12, 6, 7)], + ), + ( + pl.Series( + [ + [timedelta(hours=12), None, timedelta(minutes=3)], + [timedelta(hours=3), None, timedelta(hours=1)], + ], + ), + [timedelta(minutes=7)], + ), + ( + pl.Series( + [[Decimal(12), None, Decimal(3)], [Decimal(500), None, Decimal(16)]] + ), + 
[Decimal(4)], + ), + # TODO: nested lists, arrays, structs + ], +) +def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) -> None: + needles_series = pl.Series( + [sublist[i % len(sublist)] for (i, sublist) in enumerate(list_series)] + ) + assert_index_of_in_from_series(list_series, needles_series) + + for value in sum(list_series.to_list(), []) + extra_values + [None]: + assert_index_of_in_from_scalar(list_series, value) From 29a406818c298a5fd8ecea38dc72a14713b99a81 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 5 Feb 2025 09:15:12 -0500 Subject: [PATCH 11/30] Start of nested lists working. --- .../src/chunked_array/list/index_of_in.rs | 18 ++++++++++++------ .../namespaces/list/test_index_of_in.py | 8 +++++--- .../tests/unit/operations/test_index_of.py | 9 ++++++++- 3 files changed, 25 insertions(+), 10 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index cabe6f5098fa..63a9117fb08f 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -26,12 +26,18 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult None: ), [Decimal(4)], ), + (pl.Series([[[1, 2], None], [[4, 5], [6]], [[None, 3, 5]]]), [[5, 7], []]), # TODO: nested lists, arrays, structs ], ) diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index c5e3a5803703..1450ccbfc1c6 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -25,6 +25,13 @@ def isnan(value: object) -> bool: return np.isnan(value) # type: ignore[no-any-return] +def to_python(maybe_series: object) -> object: + if isinstance(maybe_series, pl.Series): + return [to_python(sub) for sub in maybe_series.to_list()] + else: + return maybe_series + + def get_expected_index(series: 
pl.Series, value: IntoExpr) -> int | None: if isnan(value): expected_index = None @@ -34,7 +41,7 @@ def get_expected_index(series: pl.Series, value: IntoExpr) -> int | None: break else: try: - expected_index = series.to_list().index(value) + expected_index = to_python(series).index(to_python(value)) except ValueError: expected_index = None if expected_index == -1: From a11b77a519c8adc9d93e0ab5d67c9f82a6542759 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Wed, 5 Feb 2025 10:46:58 -0500 Subject: [PATCH 12/30] Expand testing of nested lists. --- .../namespaces/list/test_index_of_in.py | 32 +++++++++++++++++-- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 059d6dfa722b..c393ddff7746 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -214,15 +214,41 @@ def test_float(float_dtype: pl.DataType) -> None: ), [Decimal(4)], ), - (pl.Series([[[1, 2], None], [[4, 5], [6]], [[None, 3, 5]]]), [[5, 7], []]), + ( + pl.Series([[[1, 2], None], [[4, 5], [6]], None, [[None, 3, 5]], [None]]), + [[5, 7], []], + ), + ( + pl.Series( + [ + [[[1, 2], None], [[4, 5], [6]]], + [[[None, 3, 5]]], + None, + [None], + [[None]], + [[[None]]], + ] + ), + [[[5, 7]], [[]], [None]], + ), # TODO: nested lists, arrays, structs ], ) def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) -> None: needles_series = pl.Series( - [sublist[i % len(sublist)] for (i, sublist) in enumerate(list_series)] + [ + None if sublist is None else sublist[i % len(sublist)] + for (i, sublist) in enumerate(list_series) + ], + dtype=list_series.dtype.inner, ) assert_index_of_in_from_series(list_series, needles_series) - for value in sum(list_series.to_list(), []) + extra_values + [None]: + values = [None] + for subseries in 
list_series.to_list(): + if subseries is not None: + values.extend(subseries) + values.extend(extra_values) + for value in values: assert_index_of_in_from_scalar(list_series, value) + From 07f35a0cdfd6cd2ecb8247a7beac0010897791d5 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 6 Feb 2025 11:14:01 -0500 Subject: [PATCH 13/30] Working arrays, do casting in the IR. --- crates/polars-expr/src/expressions/cast.rs | 1 + crates/polars-expr/src/planner.rs | 1 + .../src/chunked_array/list/index_of_in.rs | 34 +------------------ crates/polars-plan/src/dsl/list.rs | 20 ++++++----- .../src/plans/conversion/type_coercion/mod.rs | 24 +++++++++++-- crates/polars-plan/src/plans/expr_ir.rs | 4 +++ crates/polars-plan/src/plans/options.rs | 3 ++ py-polars/polars/expr/list.py | 2 +- .../namespaces/list/test_index_of_in.py | 29 +++++++++++----- 9 files changed, 64 insertions(+), 54 deletions(-) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index 828117603231..c7c468d49758 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -16,6 +16,7 @@ pub struct CastExpr { impl CastExpr { fn finish(&self, input: &Column) -> PolarsResult { + println!("CASTEXPR {input:?} to {}", self.dtype); input.cast_with_options(&self.dtype, self.options) } } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 8156000a5311..6be15b28cd5e 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -430,6 +430,7 @@ fn create_physical_expr_inner( dtype, options, } => { + println!("AExpr::Cast HAD DTYPE {dtype}"); let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; Ok(Arc::new(CastExpr { input: phys_expr, diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 63a9117fb08f..c915c35046a7 100644 --- 
a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -1,44 +1,13 @@ -use polars_core::chunked_array::cast::CastOptions; -use polars_utils::total_ord::TotalEq; - use super::*; use crate::series::index_of; -fn check_if_cast_lossless(dtype1: &DataType, dtype2: &DataType, result: bool) -> PolarsResult<()> { - polars_ensure!( - result, - InvalidOperation: "cannot cast lossless between {} and {}", - dtype1, dtype2, - ); - Ok(()) -} - pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); let inner_dtype = ca.dtype().inner_dtype().unwrap(); - let needle_dtype = needles.dtype(); - // We need to do casting ourselves, unless we grow a new CastingRules - // variant. - check_if_cast_lossless( - &needle_dtype, - inner_dtype, - (inner_dtype.leaf_dtype().is_float() == needle_dtype.is_float()) || needle_dtype.is_null(), - )?; if needles.len() == 1 { let needle = needles.get(0).unwrap(); - let cast_needle = if needle_dtype.leaf_dtype().is_null() { - needle - } else { - let cast_needle = needle.cast(inner_dtype); - check_if_cast_lossless( - &needle_dtype, - inner_dtype, - needle_dtype.leaf_dtype().is_null() || cast_needle.tot_eq(&needle), - )?; - cast_needle - }; let needle_dtype = inner_dtype.clone(); - let needle = Scalar::new(needle_dtype, cast_needle.into_static()); + let needle = Scalar::new(needle_dtype, needle.into_static()); ca.amortized_iter().for_each(|opt_series| { if let Some(subseries) = opt_series { builder.append_option( @@ -54,7 +23,6 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult>(self, needle: N) -> Expr { - let other = needle.into(); - - self.0.map_many_private( - FunctionExpr::ListExpr(ListFunction::IndexOfIn), - &[other], - false, - None, - ) + Expr::Function { + input: vec![self.0, needle.into()], + function: FunctionExpr::ListExpr(ListFunction::IndexOfIn), + 
options: FunctionOptions { + collect_groups: ApplyOptions::ElementWise, + flags: FunctionFlags::default(), + cast_options: Some(CastingRules::FirstArgInnerLossless), + ..Default::default() + }, + } } #[cfg(feature = "list_sets")] diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index f8b3d802c5f1..b144e02d792d 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -295,6 +295,17 @@ impl OptimizationRule for TypeCoercionRule { } } }, + CastingRules::FirstArgInnerLossless => { + if super_type.leaf_dtype().is_integer() { + for other in &input[1..] { + let other_dtype = + other.dtype(&input_schema, Context::Default, expr_arena)?; + if other_dtype.leaf_dtype().is_float() { + polars_bail!(InvalidOperation: "cannot cast lossless between {} and {}", super_type, other_dtype) + } + } + } + }, } if matches!(super_type, DataType::Unknown(UnknownKind::Any)) { @@ -310,7 +321,16 @@ impl OptimizationRule for TypeCoercionRule { _ => {}, } - for (e, dtype) in input.iter_mut().zip(dtypes) { + for (i, (e, dtype)) in input.iter_mut().zip(dtypes).enumerate() { + let new_super_type = + if matches!(casting_rules, CastingRules::FirstArgInnerLossless) + && (i > 0) + { + // TODO get rid of unwrap(), will fail if first item is not a list/array + &super_type.inner_dtype().unwrap() + } else { + &super_type + }; match super_type { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, _) if dtype.is_string() => { @@ -319,7 +339,7 @@ impl OptimizationRule for TypeCoercionRule { _ => cast_expr_ir( e, &dtype, - &super_type, + new_super_type, expr_arena, CastOptions::NonStrict, )?, diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index a8f062d74e30..38359e5f535b 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -262,6 
+262,10 @@ impl ExprIR { } } + pub fn output_dtype(&self) -> &OnceLock { + &self.output_dtype + } + pub fn field( &self, schema: &Schema, diff --git a/crates/polars-plan/src/plans/options.rs b/crates/polars-plan/src/plans/options.rs index c2b654c72834..f16590b66f15 100644 --- a/crates/polars-plan/src/plans/options.rs +++ b/crates/polars-plan/src/plans/options.rs @@ -210,6 +210,9 @@ pub enum CastingRules { /// whereas int to int is considered lossless. /// Overflowing is not considered in this flag, that's handled in `strict` casting FirstArgLossless, + /// Cast (in a lossless way) to the inner dtype of the first argument, + /// presumably a list or array. + FirstArgInnerLossless, Supertype(SuperTypeOptions), } diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 26d1d75b02cb..7b4d828f1406 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1072,7 +1072,7 @@ def index_of_in(self, element: IntoExpr) -> Expr: -------- TODO """ - element = parse_into_expression(element, str_as_lit=True) + element = parse_into_expression(element, str_as_lit=True, list_as_series=False) return wrap_expr(self._pyexpr.list_index_of_in(element)) def to_array(self, width: int) -> Expr: diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index c393ddff7746..24eea7a2bf99 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -33,6 +33,7 @@ def assert_index_of_in_from_scalar( original_value = value del value for updated_value in (original_value, pl.lit(original_value)): + print("PYTHON WANTS TO LOOK UP", updated_value) # Eager API: assert_series_equal( list_series.list.index_of_in(updated_value), @@ -135,17 +136,14 @@ def test_integer(lists_dtype: pl.DataType, values_dtype: pl.DataType) -> None: def test_no_lossy_numeric_casts() -> None: 
list_series = pl.Series([[3]], dtype=pl.List(pl.Int8())) - for will_be_lossy in [ - np.float32(3.1), - np.float64(3.1), - 50.9, - 300, - -300, - pl.lit(300, dtype=pl.Int16), - ]: + for will_be_lossy in [np.float32(3.1), np.float64(3.1), 50.9]: with pytest.raises(InvalidOperationError, match="cannot cast lossless"): list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] + for will_be_lossy in [300, -300, pl.lit(300, dtype=pl.Int16)]: + with pytest.raises(InvalidOperationError, match="conversion from"): + list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] + @pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) def test_float(float_dtype: pl.DataType) -> None: @@ -231,7 +229,7 @@ def test_float(float_dtype: pl.DataType) -> None: ), [[[5, 7]], [[]], [None]], ), - # TODO: nested lists, arrays, structs + # TODO: structs ], ) def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) -> None: @@ -252,3 +250,16 @@ def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) for value in values: assert_index_of_in_from_scalar(list_series, value) + +def test_array() -> None: + array_dtype = pl.Array(pl.Int64(), 2) + list_series = pl.Series( + [[[1, 2]], [[4, 5]], [[None, 3]], [None], None], + dtype=pl.List(array_dtype), + ) + values = [[1, 2], [4, 5], [None, 3], [5, 7], None] + for value in values: + assert_index_of_in_from_scalar(list_series, value) + + needles = pl.Series(values[:5], dtype=array_dtype) + assert_index_of_in_from_series(list_series, needles) From 93075a0250c74689e5c06b2c344e7ae2b0d3bbb8 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 6 Feb 2025 12:03:48 -0500 Subject: [PATCH 14/30] Minimal support for multiple chunks in the needle. 
--- .../src/chunked_array/list/index_of_in.rs | 1 + .../namespaces/list/test_index_of_in.py | 63 +++++++++++-------- 2 files changed, 39 insertions(+), 25 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index c915c35046a7..684faee5ad0a 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -23,6 +23,7 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult None: list_series = pl.Series([[3, 1], [2, 4], [5, 3, 1]]) assert_index_of_in_from_scalar(list_series, 1) @@ -145,6 +132,12 @@ def test_no_lossy_numeric_casts() -> None: list_series.list.index_of_in(will_be_lossy) # type: ignore[arg-type] +def test_multichunk_needles() -> None: + series = pl.Series([[1, 3], [3, 2], [4, 5, 3]]) + needles = pl.concat([pl.Series([3, 1]), pl.Series([3])]) + assert series.list.index_of_in(needles).to_list() == [1, None, 2] + + @pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) def test_float(float_dtype: pl.DataType) -> None: lists = [ @@ -229,7 +222,31 @@ def test_float(float_dtype: pl.DataType) -> None: ), [[[5, 7]], [[]], [None]], ), - # TODO: structs + ( + pl.Series( + [[[1, 2]], [[4, 5]], [[None, 3]], [None], None], + dtype=pl.List(pl.Array(pl.Int64(), 2)), + ), + [[5, 7]], + ), + ( + pl.Series( + [ + [{"a": 1, "b": 2}, None], + [{"a": 3, "b": 4}, {"a": None, "b": 2}], + None, + ], + dtype=pl.List(pl.Struct({"a": pl.Int64(), "b": pl.Int64()})), + ), + [{"a": 7, "b": None}, {"a": 6, "b": 4}], + ), + ( + pl.Series( + [["a", "c"], [None, "b"], ["b", "a", "a", "c"], None, [None]], + dtype=pl.List(pl.Enum(["c", "b", "a"])), + ), + [], + ), ], ) def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) -> None: @@ -251,15 +268,11 @@ def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) assert_index_of_in_from_scalar(list_series, value) -def 
test_array() -> None: - array_dtype = pl.Array(pl.Int64(), 2) - list_series = pl.Series( - [[[1, 2]], [[4, 5]], [[None, 3]], [None], None], - dtype=pl.List(array_dtype), +@pytest.mark.xfail(reason="Depends on Series.index_of supporting Categoricals") +def test_categorical() -> None: + # When this starts passing, convert to test_other_types entry above. + series = pl.Series( + [["a", "c"], [None, "b"], ["b", "a", "a", "c"], None, [None]], + dtype=pl.List(pl.Categorical), ) - values = [[1, 2], [4, 5], [None, 3], [5, 7], None] - for value in values: - assert_index_of_in_from_scalar(list_series, value) - - needles = pl.Series(values[:5], dtype=array_dtype) - assert_index_of_in_from_series(list_series, needles) + assert series.list.index_of_in("b").to_list() == [None, 1, 0, None, None] From 4a96f99787f2056caecef9fdd6a43550502170c1 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 6 Feb 2025 12:04:27 -0500 Subject: [PATCH 15/30] Drop debug prints. --- crates/polars-expr/src/expressions/cast.rs | 1 - crates/polars-expr/src/planner.rs | 1 - 2 files changed, 2 deletions(-) diff --git a/crates/polars-expr/src/expressions/cast.rs b/crates/polars-expr/src/expressions/cast.rs index c7c468d49758..828117603231 100644 --- a/crates/polars-expr/src/expressions/cast.rs +++ b/crates/polars-expr/src/expressions/cast.rs @@ -16,7 +16,6 @@ pub struct CastExpr { impl CastExpr { fn finish(&self, input: &Column) -> PolarsResult { - println!("CASTEXPR {input:?} to {}", self.dtype); input.cast_with_options(&self.dtype, self.options) } } diff --git a/crates/polars-expr/src/planner.rs b/crates/polars-expr/src/planner.rs index 6be15b28cd5e..8156000a5311 100644 --- a/crates/polars-expr/src/planner.rs +++ b/crates/polars-expr/src/planner.rs @@ -430,7 +430,6 @@ fn create_physical_expr_inner( dtype, options, } => { - println!("AExpr::Cast HAD DTYPE {dtype}"); let phys_expr = create_physical_expr_inner(*expr, ctxt, expr_arena, schema, state)?; Ok(Arc::new(CastExpr { input: phys_expr, 
From 46feae9d7b4cb2f602982abf5f18ffab3cf5b0f4 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 6 Feb 2025 14:24:16 -0500 Subject: [PATCH 16/30] Mismatched length. --- .../src/chunked_array/list/index_of_in.rs | 6 ++++++ .../namespaces/list/test_index_of_in.py | 16 +++++++++++++++- 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 684faee5ad0a..2b0cad5d4021 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -23,6 +23,12 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult None: assert series.list.index_of_in(needles).to_list() == [1, None, 2] +def test_mismatched_length() -> None: + """ + Mismatched lengths result in an error. + + Unfortunately a length 1 Series will be treated as a _scalar_, which seems + weird, but that's how e.g. list.contains() works so maybe that's + intentional. + """ + series = pl.Series([[1, 3], [3, 2], [4, 5, 3]]) + needles = pl.Series([3, 2]) + with pytest.raises(ComputeError, match="shapes don't match"): + series.list.index_of_in(pl.Series(needles)) + + @pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) def test_float(float_dtype: pl.DataType) -> None: lists = [ From 23230e0247826dcfce83d4578a4674a8b253f21a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Thu, 6 Feb 2025 15:25:42 -0500 Subject: [PATCH 17/30] Expand. 
--- .../polars-ops/src/chunked_array/list/index_of_in.rs | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 2b0cad5d4021..6f9bca53c12e 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -11,9 +11,13 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult instead of a Scalar - // which implies AnyValue<'static>? + // The clone() could perhaps be removed by refactoring + // index_of() to take a (scalar) Column, and then we could + // just pass through the Column as is. Still a bunch of + // duplicate work even in that case though, so possibly the + // real solution would be duplicating the code in + // index_of(), or refactoring so its guts are more + // available. index_of(subseries.as_ref(), needle.clone()) .unwrap() .map(|v| v.try_into().unwrap()), From 5d40c9c7131fe46285e0f7acfddd0b73dd3b6cc5 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 7 Feb 2025 12:32:56 -0500 Subject: [PATCH 18/30] Optimize the scalar case of list_index_of_in(). 
--- .../src/chunked_array/list/index_of_in.rs | 128 ++++++++++++------ crates/polars-ops/src/series/ops/index_of.rs | 42 +++--- .../namespaces/list/test_index_of_in.py | 11 ++ 3 files changed, 119 insertions(+), 62 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 6f9bca53c12e..049edb6f47f7 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -1,57 +1,95 @@ use super::*; -use crate::series::index_of; +use crate::series::{index_of, index_of_null}; pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { - let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); - let inner_dtype = ca.dtype().inner_dtype().unwrap(); + // Handle scalar case separately, since we can do some optimizations given + // the extra knowledge we have. if needles.len() == 1 { let needle = needles.get(0).unwrap(); - let needle_dtype = inner_dtype.clone(); - let needle = Scalar::new(needle_dtype, needle.into_static()); - ca.amortized_iter().for_each(|opt_series| { - if let Some(subseries) = opt_series { + return list_index_of_in_for_scalar(ca, needle); + } + + polars_ensure!( + ca.len() == needles.len(), + ComputeError: "shapes don't match: expected {} elements in 'index_of_in' comparison, got {}", + ca.len(), + needles.len() + ); + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + let needles = needles.rechunk(); + ca.amortized_iter() + // TODO iter() assumes a single chunk. could continue to use this + // and just rechunk(), or have needles also be a ChunkedArray, in + // which case we'd need to have to use one of the + // dispatch-on-dtype-and-cast-to-relevant-chunkedarray-type macros + // to duplicate the implementation code per dtype. 
+ .zip(needles.iter()) + .for_each(|(opt_series, needle)| match (opt_series, needle) { + (None, _) => builder.append_null(), + (Some(subseries), needle) => { + let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); builder.append_option( - // The clone() could perhaps be removed by refactoring - // index_of() to take a (scalar) Column, and then we could - // just pass through the Column as is. Still a bunch of - // duplicate work even in that case though, so possibly the - // real solution would be duplicating the code in - // index_of(), or refactoring so its guts are more - // available. - index_of(subseries.as_ref(), needle.clone()) + index_of(subseries.as_ref(), needle) .unwrap() .map(|v| v.try_into().unwrap()), ); - } else { - builder.append_null(); - } + }, }); - } else { - polars_ensure!( - ca.len() == needles.len(), - ComputeError: "shapes don't match: expected {} elements in 'index_of_in' comparison, got {}", - ca.len(), - needles.len() - ); - let needles = needles.rechunk(); - ca.amortized_iter() - // TODO iter() assumes a single chunk. could continue to use this - // and just rechunk(), or have needles also be a ChunkedArray, in - // which case we'd need to have to use one of the - // dispatch-on-dtype-and-cast-to-relevant-chunkedarray-type macros - // to duplicate the implementation code per dtype. - .zip(needles.iter()) - .for_each(|(opt_series, needle)| match (opt_series, needle) { - (None, _) => builder.append_null(), - (Some(subseries), needle) => { - let needle = Scalar::new(needles.dtype().clone(), needle.into_static()); - builder.append_option( - index_of(subseries.as_ref(), needle) - .unwrap() - .map(|v| v.try_into().unwrap()), - ); - }, - }); - } + + Ok(builder.finish().into()) +} + +macro_rules! 
process_series_for_numeric_value { + ($extractor:ident, $needle:ident) => {{ + use arrow::array::PrimitiveArray; + + use crate::series::index_of_value; + + let needle = $needle.extract::<$extractor>().unwrap(); + Box::new(move |subseries| { + index_of_value::<_, PrimitiveArray<$extractor>>(subseries.$extractor().unwrap(), needle) + }) + }}; +} + +fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult { + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); + let needle = needle.into_static(); + let inner_dtype = ca.dtype().inner_dtype().unwrap(); + let needle_dtype = needle.dtype(); + + let process_series: Box Option> = match needle_dtype { + DataType::Null => Box::new(|subseries| index_of_null(subseries)), + #[cfg(feature = "dtype-u8")] + DataType::UInt8 => process_series_for_numeric_value!(u8, needle), + #[cfg(feature = "dtype-u16")] + DataType::UInt16 => process_series_for_numeric_value!(u16, needle), + DataType::UInt32 => process_series_for_numeric_value!(u32, needle), + DataType::UInt64 => process_series_for_numeric_value!(u64, needle), + #[cfg(feature = "dtype-i8")] + DataType::Int8 => process_series_for_numeric_value!(i8, needle), + #[cfg(feature = "dtype-i16")] + DataType::Int16 => process_series_for_numeric_value!(i16, needle), + DataType::Int32 => process_series_for_numeric_value!(i32, needle), + DataType::Int64 => process_series_for_numeric_value!(i64, needle), + #[cfg(feature = "dtype-i128")] + DataType::Int128 => process_series_for_numeric_value!(i128, needle), + DataType::Float32 => process_series_for_numeric_value!(f32, needle), + DataType::Float64 => process_series_for_numeric_value!(f64, needle), + // Just use the general purpose index_of() function: + _ => Box::new(|subseries| { + let needle = Scalar::new(inner_dtype.clone(), needle.clone()); + index_of(subseries, needle).unwrap() + }), + }; + + ca.amortized_iter().for_each(|opt_series| { + if let Some(subseries) = opt_series { + builder + 
.append_option(process_series(subseries.as_ref()).map(|v| v.try_into().unwrap())); + } else { + builder.append_null(); + } + }); Ok(builder.finish().into()) } diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs index 6c2536e263ec..05255dc5d0e0 100644 --- a/crates/polars-ops/src/series/ops/index_of.rs +++ b/crates/polars-ops/src/series/ops/index_of.rs @@ -5,7 +5,10 @@ use polars_utils::total_ord::TotalEq; use row_encode::encode_rows_unordered; /// Find the index of the value, or ``None`` if it can't be found. -fn index_of_value<'a, DT, AR>(ca: &'a ChunkedArray
<DT>, value: AR::ValueT<'a>) -> Option<usize> +pub(crate) fn index_of_value<'a, DT, AR>( + ca: &'a ChunkedArray<DT>
, + value: AR::ValueT<'a>, +) -> Option where DT: PolarsDataType, AR: StaticArray, @@ -51,14 +54,31 @@ macro_rules! try_index_of_numeric_ca { ($ca:expr, $value:expr) => {{ let ca = $ca; let value = $value; - // extract() returns None if casting failed, so consider an extract() - // failure as not finding the value. Nulls should have been handled - // earlier. + // extract() returns None if casting failed, and by this point Nulls + // have been handled, and everything should have been cast to matching + // dtype otherwise. let value = value.value().extract().unwrap(); index_of_numeric_value(ca, value) }}; } +/// Find the index of nulls within a Series. +pub(crate) fn index_of_null(series: &Series) -> Option { + let mut index = 0; + for chunk in series.chunks() { + let length = chunk.len(); + if let Some(bitmap) = chunk.validity() { + let leading_ones = bitmap.leading_ones(); + if leading_ones < length { + return Some(index + leading_ones); + } + } else { + index += length; + } + } + return None; +} + /// Find the index of a given value (the first and only entry in `value_series`) /// within the series. 
pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> { @@ -80,19 +100,7 @@ pub fn index_of(series: &Series, needle: Scalar) -> PolarsResult> // Series is not null, and the value is null: if needle.is_null() { - let mut index = 0; - for chunk in series.chunks() { - let length = chunk.len(); - if let Some(bitmap) = chunk.validity() { - let leading_ones = bitmap.leading_ones(); - if leading_ones < length { - return Ok(Some(index + leading_ones)); - } - } else { - index += length; - } - } - return Ok(None); + return Ok(index_of_null(series)); } if series.dtype().is_primitive_numeric() { diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 857fa8b90410..1b033e058057 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -158,6 +158,8 @@ def test_float(float_dtype: pl.DataType) -> None: [1.5, np.nan, np.inf], [3.0, None, -np.inf], [0.0, -0.0, -np.nan], + None, + [None, None], ] lists_series = pl.Series(lists, dtype=pl.List(float_dtype)) @@ -290,3 +292,12 @@ def test_categorical() -> None: dtype=pl.List(pl.Categorical), ) assert series.list.index_of_in("b").to_list() == [None, 1, 0, None, None] + + +def test_nulls() -> None: + series = pl.Series([[None, None], None], dtype=pl.List(pl.Null)) + assert series.list.index_of_in(None).to_list() == [0, None] + + series = pl.Series([None, [None, None]], dtype=pl.List(pl.Int64)) + assert series.list.index_of_in(None).to_list() == [None, 0] + assert series.list.index_of_in(1).to_list() == [None, None] From ddbd9a436117e27954884484fad26329c2021174 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 7 Feb 2025 17:37:24 -0500 Subject: [PATCH 19/30] Optimize a bunch of code paths. 
--- .../src/chunked_array/list/index_of_in.rs | 65 +++++++++++++++++-- 1 file changed, 58 insertions(+), 7 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 049edb6f47f7..29d510df1c93 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -1,6 +1,48 @@ +use polars_core::match_arrow_dtype_apply_macro_ca; + use super::*; use crate::series::{index_of, index_of_null}; +macro_rules! to_anyvalue_iterator { + ($ca:expr) => {{ + use polars_core::prelude::AnyValue; + Box::new($ca.iter().map(AnyValue::from)) + }}; +} + +fn series_to_anyvalue_iter(series: &Series) -> Box + '_> { + let dtype = series.dtype(); + match dtype { + DataType::Date => { + return to_anyvalue_iterator!(series.date().unwrap()); + }, + DataType::Datetime(_, _) => { + return to_anyvalue_iterator!(series.datetime().unwrap()); + }, + DataType::Time => { + return to_anyvalue_iterator!(series.time().unwrap()); + }, + DataType::Duration(_) => { + return to_anyvalue_iterator!(series.duration().unwrap()); + }, + DataType::Binary => { + return to_anyvalue_iterator!(series.binary().unwrap()); + }, + DataType::Decimal(_, _) => { + return to_anyvalue_iterator!(series.decimal().unwrap()); + }, + _ => (), + }; + match_arrow_dtype_apply_macro_ca!( + series, + to_anyvalue_iterator, + to_anyvalue_iterator, + to_anyvalue_iterator + ) +} + +/// Given a needle, or needles, find the corresponding needle in each value of a +/// ListChunked. pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult { // Handle scalar case separately, since we can do some optimizations given // the extra knowledge we have. @@ -15,15 +57,24 @@ pub fn list_index_of_in(ca: &ListChunked, needles: &Series) -> PolarsResult::new(ca.name().clone(), ca.len()); - let needles = needles.rechunk(); ca.amortized_iter() - // TODO iter() assumes a single chunk. 
could continue to use this - // and just rechunk(), or have needles also be a ChunkedArray, in - // which case we'd need to have to use one of the - // dispatch-on-dtype-and-cast-to-relevant-chunkedarray-type macros - // to duplicate the implementation code per dtype. - .zip(needles.iter()) + .zip(needle_iter) .for_each(|(opt_series, needle)| match (opt_series, needle) { (None, _) => builder.append_null(), (Some(subseries), needle) => { From a2d3e76f2e91499900aad28f8a9a858516bd89c8 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 7 Feb 2025 17:37:32 -0500 Subject: [PATCH 20/30] Bit more testing. --- .../namespaces/list/test_index_of_in.py | 23 +++++++++++++------ 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 1b033e058057..181b2559e845 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -152,6 +152,14 @@ def test_mismatched_length() -> None: series.list.index_of_in(pl.Series(needles)) +def all_values(list_series: pl.Series) -> list: + values = [] + for subseries in list_series.to_list(): + if subseries is not None: + values.extend(subseries) + return values + + @pytest.mark.parametrize("float_dtype", FLOAT_DTYPES) def test_float(float_dtype: pl.DataType) -> None: lists = [ @@ -164,12 +172,17 @@ def test_float(float_dtype: pl.DataType) -> None: lists_series = pl.Series(lists, dtype=pl.List(float_dtype)) # Scalar - for value in sum(lists, []) + [3.5, np.float64(1.5), np.float32(3.0)]: + for value in all_values(lists_series) + [ + None, + 3.5, + np.float64(1.5), + np.float32(3.0), + ]: assert_index_of_in_from_scalar(lists_series, value) # Series assert_index_of_in_from_series( - lists_series, pl.Series([1.5, -np.inf, -np.nan], dtype=float_dtype) + lists_series, pl.Series([1.5, -np.inf, 
-np.nan, 3, None], dtype=float_dtype) ) @@ -275,11 +288,7 @@ def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) ) assert_index_of_in_from_series(list_series, needles_series) - values = [None] - for subseries in list_series.to_list(): - if subseries is not None: - values.extend(subseries) - values.extend(extra_values) + values = all_values(list_series) + extra_values + [None] for value in values: assert_index_of_in_from_scalar(list_series, value) From 119d4e92a1cfbfe75aaab5d1229685bd3985de8a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Fri, 7 Feb 2025 17:51:48 -0500 Subject: [PATCH 21/30] Lints and cleanups. --- .../src/chunked_array/list/index_of_in.rs | 3 ++- .../polars-ops/src/chunked_array/list/mod.rs | 8 +++---- crates/polars-ops/src/series/ops/index_of.rs | 2 +- .../src/plans/conversion/type_coercion/mod.rs | 23 ++++++++++++------- 4 files changed, 22 insertions(+), 14 deletions(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 29d510df1c93..76b3c9aefca3 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -103,6 +103,7 @@ macro_rules! 
process_series_for_numeric_value { }}; } +#[allow(clippy::type_complexity)] // For the Box fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult { let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); let needle = needle.into_static(); @@ -110,7 +111,7 @@ fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> Polars let needle_dtype = needle.dtype(); let process_series: Box Option> = match needle_dtype { - DataType::Null => Box::new(|subseries| index_of_null(subseries)), + DataType::Null => Box::new(index_of_null), #[cfg(feature = "dtype-u8")] DataType::UInt8 => process_series_for_numeric_value!(u8, needle), #[cfg(feature = "dtype-u16")] diff --git a/crates/polars-ops/src/chunked_array/list/mod.rs b/crates/polars-ops/src/chunked_array/list/mod.rs index f6b40ecf98e0..76f8b8414fc8 100644 --- a/crates/polars-ops/src/chunked_array/list/mod.rs +++ b/crates/polars-ops/src/chunked_array/list/mod.rs @@ -6,6 +6,8 @@ mod count; mod dispersion; #[cfg(feature = "hash")] pub(crate) mod hash; +#[cfg(feature = "list_index_of_in")] +mod index_of_in; mod min_max; mod namespace; #[cfg(feature = "list_sets")] @@ -13,20 +15,18 @@ mod sets; mod sum_mean; #[cfg(feature = "list_to_struct")] mod to_struct; -#[cfg(feature = "list_index_of_in")] -mod index_of_in; #[cfg(feature = "list_count")] pub use count::*; #[cfg(not(feature = "list_count"))] use count::*; +#[cfg(feature = "list_index_of_in")] +pub use index_of_in::*; pub use namespace::*; #[cfg(feature = "list_sets")] pub use sets::*; #[cfg(feature = "list_to_struct")] pub use to_struct::*; -#[cfg(feature = "list_index_of_in")] -pub use index_of_in::*; pub trait AsList { fn as_list(&self) -> &ListChunked; diff --git a/crates/polars-ops/src/series/ops/index_of.rs b/crates/polars-ops/src/series/ops/index_of.rs index 05255dc5d0e0..29d74f64f3b4 100644 --- a/crates/polars-ops/src/series/ops/index_of.rs +++ b/crates/polars-ops/src/series/ops/index_of.rs @@ -76,7 
+76,7 @@ pub(crate) fn index_of_null(series: &Series) -> Option { index += length; } } - return None; + None } /// Find the index of a given value (the first and only entry in `value_series`) diff --git a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs index b144e02d792d..0b660d146d70 100644 --- a/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs +++ b/crates/polars-plan/src/plans/conversion/type_coercion/mod.rs @@ -322,15 +322,22 @@ impl OptimizationRule for TypeCoercionRule { } for (i, (e, dtype)) in input.iter_mut().zip(dtypes).enumerate() { - let new_super_type = - if matches!(casting_rules, CastingRules::FirstArgInnerLossless) - && (i > 0) - { - // TODO get rid of unwrap(), will fail if first item is not a list/array - &super_type.inner_dtype().unwrap() + let new_super_type = if matches!( + casting_rules, + CastingRules::FirstArgInnerLossless + ) && (i > 0) + { + if let Some(inner_type) = super_type.inner_dtype() { + inner_type } else { - &super_type - }; + polars_bail!( + InvalidOperation: + "FirstArgInnerLossless only makes sense for types like list or array" + ); + } + } else { + &super_type + }; match super_type { #[cfg(feature = "dtype-categorical")] DataType::Categorical(_, _) if dtype.is_string() => { From c1dfb3d583316f6fe68eb4e0bfa552b1ecf72162 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 10 Feb 2025 17:29:24 -0500 Subject: [PATCH 22/30] Validate input types. --- crates/polars-ops/src/chunked_array/list/index_of_in.rs | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 76b3c9aefca3..0d82af4a68bb 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -105,6 +105,13 @@ macro_rules! 
process_series_for_numeric_value { #[allow(clippy::type_complexity)] // For the Box fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult { + polars_ensure!( + ca.dtype().inner_dtype().unwrap() == &needle.dtype() || needle.dtype().is_null(), + ComputeError: "dtypes did't match: series values have dtype {} and needle has dtype {}", + ca.dtype().inner_dtype().unwrap(), + needle.dtype() + ); + let mut builder = PrimitiveChunkedBuilder::::new(ca.name().clone(), ca.len()); let needle = needle.into_static(); let inner_dtype = ca.dtype().inner_dtype().unwrap(); From 86ed0677d06084c0c1ff0b29284572be0666260a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 10 Feb 2025 17:32:52 -0500 Subject: [PATCH 23/30] Documentation. --- py-polars/polars/expr/list.py | 27 +++++++++++++++++++++------ py-polars/polars/series/list.py | 22 ++++++++++++++++++++-- 2 files changed, 41 insertions(+), 8 deletions(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index 7b4d828f1406..f49535304527 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1059,20 +1059,35 @@ def count_matches(self, element: IntoExpr) -> Expr: element = parse_into_expression(element, str_as_lit=True) return wrap_expr(self._pyexpr.list_count_matches(element)) - def index_of_in(self, element: IntoExpr) -> Expr: + def index_of_in(self, needles: IntoExpr) -> Expr: """ - TODO + For each List, return the index of the first value equal to a needle. Parameters ---------- needles - TODO + The value(s) to search for. Examples -------- - TODO - """ - element = parse_into_expression(element, str_as_lit=True, list_as_series=False) + >>> df = pl.DataFrame({ + ... "lists": [[1, 2, 3], [], [None, 3], [5, 6, 7]], + ... "needles": [3, 0, 3, 7], + ... 
}) + >>> df.select(pl.col("lists").list.index_of_in(pl.col("needles"))) + shape: (4, 1) + ┌───────┐ + │ lists │ + │ --- │ + │ u32 │ + ╞═══════╡ + │ 2 │ + │ null │ + │ 1 │ + │ 2 │ + └───────┘ + """ + element = parse_into_expression(needles, str_as_lit=True, list_as_series=False) return wrap_expr(self._pyexpr.list_index_of_in(element)) def to_array(self, width: int) -> Expr: diff --git a/py-polars/polars/series/list.py b/py-polars/polars/series/list.py index 54bcd6763c5e..dfa8f855f60f 100644 --- a/py-polars/polars/series/list.py +++ b/py-polars/polars/series/list.py @@ -1055,7 +1055,25 @@ def set_symmetric_difference(self, other: Series) -> Series: ] """ # noqa: W505 - def index_of_in(self, element: IntoExpr) -> Series: + def index_of_in(self, needles: IntoExpr) -> Series: """ - TODO + For each list in the series, return the index of the first value equal to the corresponding needle. + + Parameters + ---------- + needles + The value(s) to search for. + + Examples + -------- + >>> a = pl.Series([[1, 2, 3], [], [None, 3], [5, 6, 7]]) + >>> a.list.index_of_in(pl.Series([3, 0, 3, 7])) + shape: (4,) + Series: '' [u32] + [ + 2 + null + 1 + 2 + ] """ # noqa: W505 From 1ab116c8058bf106ab39bc09fc9b147d4f0aac80 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 10 Feb 2025 18:25:24 -0500 Subject: [PATCH 24/30] Cover another edge case. --- .../polars-ops/src/chunked_array/list/index_of_in.rs | 2 +- .../operations/namespaces/list/test_index_of_in.py | 10 ++++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index 0d82af4a68bb..bc611f61cb55 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -107,7 +107,7 @@ macro_rules! 
process_series_for_numeric_value { fn list_index_of_in_for_scalar(ca: &ListChunked, needle: AnyValue<'_>) -> PolarsResult { polars_ensure!( ca.dtype().inner_dtype().unwrap() == &needle.dtype() || needle.dtype().is_null(), - ComputeError: "dtypes did't match: series values have dtype {} and needle has dtype {}", + ComputeError: "dtypes didn't match: series values have dtype {} and needle has dtype {}", ca.dtype().inner_dtype().unwrap(), needle.dtype() ); diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 181b2559e845..213c43e5c785 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -310,3 +310,13 @@ def test_nulls() -> None: series = pl.Series([None, [None, None]], dtype=pl.List(pl.Int64)) assert series.list.index_of_in(None).to_list() == [None, 0] assert series.list.index_of_in(1).to_list() == [None, None] + + +def test_wrong_type() -> None: + series = pl.Series([[1, 2, 3], [4, 5]]) + with pytest.raises( + ComputeError, + match=r"dtypes didn't match: series values have dtype i64 and needle has dtype list\[i64\]", + ): + # Searching for a list won't work: + series.list.index_of_in([1, 2]) From da11595e3e46fbbc08541353f1fbe641f2777d0c Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Mon, 10 Feb 2025 18:28:17 -0500 Subject: [PATCH 25/30] Lint. 
--- .../tests/unit/operations/namespaces/list/test_index_of_in.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index 213c43e5c785..a2a9d5c927fd 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -10,9 +10,8 @@ import pytest import polars as pl -from polars.exceptions import InvalidOperationError, ComputeError +from polars.exceptions import ComputeError, InvalidOperationError from polars.testing import assert_frame_equal, assert_series_equal - from tests.unit.conftest import FLOAT_DTYPES, INTEGER_DTYPES from tests.unit.operations.test_index_of import get_expected_index From 1df81f3aa519823ea9a27a9d78e293f8cb2a225a Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 11 Feb 2025 10:08:05 -0500 Subject: [PATCH 26/30] Type checks. 
--- .../namespaces/list/test_index_of_in.py | 32 +++++++++---------- 1 file changed, 15 insertions(+), 17 deletions(-) diff --git a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py index a2a9d5c927fd..3242ad0a44e7 100644 --- a/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py +++ b/py-polars/tests/unit/operations/namespaces/list/test_index_of_in.py @@ -16,13 +16,13 @@ from tests.unit.operations.test_index_of import get_expected_index if TYPE_CHECKING: - from polars._typing import IntoExpr, PythonLiteral + from polars._typing import PythonLiteral IdxType = pl.get_index_type() def assert_index_of_in_from_scalar( - list_series: pl.Series, value: PythonLiteral + list_series: pl.Series, value: PythonLiteral | None ) -> None: expected_indexes = [ None if sub_series is None else get_expected_index(sub_series, value) @@ -80,36 +80,34 @@ def test_index_of_in_from_series() -> None: assert_index_of_in_from_series(list_series, values) -def to_int(expr: pl.Expr) -> int: - return pl.select(expr).item() - - @pytest.mark.parametrize("lists_dtype", INTEGER_DTYPES) @pytest.mark.parametrize("values_dtype", INTEGER_DTYPES) -def test_integer(lists_dtype: pl.DataType, values_dtype: pl.DataType) -> None: +def test_integer(lists_dtype: pl.NumericType, values_dtype: pl.NumericType) -> None: + def to_int(expr: pl.Expr) -> int: + return pl.select(expr).item() # type: ignore[no-any-return] + lists = [ [51, 3], [None, 4], None, - [to_int(lists_dtype.max()), 3], # type: ignore[attr-defined] - [6, to_int(lists_dtype.min())], # type: ignore[attr-defined] + [to_int(lists_dtype.max()), 3], + [6, to_int(lists_dtype.min())], ] lists_series = pl.Series(lists, dtype=pl.List(lists_dtype)) chunked_series = pl.concat( [pl.Series([[100, 7]], dtype=pl.List(lists_dtype)), lists_series], rechunk=False ) - values = [ + values: list[None | PythonLiteral] = [ to_int(v) for v in [lists_dtype.max() - 1, 
lists_dtype.min() + 1] - ] # type: ignore[attr-defined] + ] for sublist in lists: if sublist is None: values.append(None) else: - values.extend(sublist) + values.extend(sublist) # type: ignore[arg-type] # Scalars: for s in [lists_series, chunked_series]: - value: IntoExpr for value in values: assert_index_of_in_from_scalar(s, value) @@ -151,7 +149,7 @@ def test_mismatched_length() -> None: series.list.index_of_in(pl.Series(needles)) -def all_values(list_series: pl.Series) -> list: +def all_values(list_series: pl.Series) -> list[object]: values = [] for subseries in list_series.to_list(): if subseries is not None: @@ -177,7 +175,7 @@ def test_float(float_dtype: pl.DataType) -> None: np.float64(1.5), np.float32(3.0), ]: - assert_index_of_in_from_scalar(lists_series, value) + assert_index_of_in_from_scalar(lists_series, value) # type: ignore[arg-type] # Series assert_index_of_in_from_series( @@ -283,13 +281,13 @@ def test_other_types(list_series: pl.Series, extra_values: list[PythonLiteral]) None if sublist is None else sublist[i % len(sublist)] for (i, sublist) in enumerate(list_series) ], - dtype=list_series.dtype.inner, + dtype=list_series.dtype.inner, # type: ignore[attr-defined] ) assert_index_of_in_from_series(list_series, needles_series) values = all_values(list_series) + extra_values + [None] for value in values: - assert_index_of_in_from_scalar(list_series, value) + assert_index_of_in_from_scalar(list_series, value) # type: ignore [arg-type] @pytest.mark.xfail(reason="Depends on Series.index_of supporting Categoricals") From 5c8ad6dbc56bd4253e188b198160f588b79ef45f Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 11 Feb 2025 10:08:10 -0500 Subject: [PATCH 27/30] Reformat. 
--- py-polars/polars/expr/list.py | 10 ++++++---- py-polars/tests/unit/operations/test_index_of.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/py-polars/polars/expr/list.py b/py-polars/polars/expr/list.py index f49535304527..84f9db6a8fb4 100644 --- a/py-polars/polars/expr/list.py +++ b/py-polars/polars/expr/list.py @@ -1070,10 +1070,12 @@ def index_of_in(self, needles: IntoExpr) -> Expr: Examples -------- - >>> df = pl.DataFrame({ - ... "lists": [[1, 2, 3], [], [None, 3], [5, 6, 7]], - ... "needles": [3, 0, 3, 7], - ... }) + >>> df = pl.DataFrame( + ... { + ... "lists": [[1, 2, 3], [], [None, 3], [5, 6, 7]], + ... "needles": [3, 0, 3, 7], + ... } + ... ) >>> df.select(pl.col("lists").list.index_of_in(pl.col("needles"))) shape: (4, 1) ┌───────┐ diff --git a/py-polars/tests/unit/operations/test_index_of.py b/py-polars/tests/unit/operations/test_index_of.py index 1450ccbfc1c6..b725f0dff468 100644 --- a/py-polars/tests/unit/operations/test_index_of.py +++ b/py-polars/tests/unit/operations/test_index_of.py @@ -41,7 +41,9 @@ def get_expected_index(series: pl.Series, value: IntoExpr) -> int | None: break else: try: - expected_index = to_python(series).index(to_python(value)) + expected_index = to_python(series).index( # type: ignore[attr-defined] + to_python(value) + ) except ValueError: expected_index = None if expected_index == -1: From 00887ce66ef03c1d85236c1f1cec83205c9eac10 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 11 Feb 2025 10:20:14 -0500 Subject: [PATCH 28/30] Unnecessary. 
--- crates/polars-plan/src/plans/expr_ir.rs | 4 ---- 1 file changed, 4 deletions(-) diff --git a/crates/polars-plan/src/plans/expr_ir.rs b/crates/polars-plan/src/plans/expr_ir.rs index 38359e5f535b..a8f062d74e30 100644 --- a/crates/polars-plan/src/plans/expr_ir.rs +++ b/crates/polars-plan/src/plans/expr_ir.rs @@ -262,10 +262,6 @@ impl ExprIR { } } - pub fn output_dtype(&self) -> &OnceLock { - &self.output_dtype - } - pub fn field( &self, schema: &Schema, From 36556d807e4a1d21d04f8261df2ad8b8088b3483 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 11 Feb 2025 11:01:40 -0500 Subject: [PATCH 29/30] Add dtype feature conditions. --- crates/polars-ops/src/chunked_array/list/index_of_in.rs | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/crates/polars-ops/src/chunked_array/list/index_of_in.rs b/crates/polars-ops/src/chunked_array/list/index_of_in.rs index bc611f61cb55..f0146bcd7778 100644 --- a/crates/polars-ops/src/chunked_array/list/index_of_in.rs +++ b/crates/polars-ops/src/chunked_array/list/index_of_in.rs @@ -13,21 +13,26 @@ macro_rules! 
to_anyvalue_iterator { fn series_to_anyvalue_iter(series: &Series) -> Box + '_> { let dtype = series.dtype(); match dtype { + #[cfg(feature = "dtype-date")] DataType::Date => { return to_anyvalue_iterator!(series.date().unwrap()); }, + #[cfg(feature = "dtype-datetime")] DataType::Datetime(_, _) => { return to_anyvalue_iterator!(series.datetime().unwrap()); }, + #[cfg(feature = "dtype-time")] DataType::Time => { return to_anyvalue_iterator!(series.time().unwrap()); }, + #[cfg(feature = "dtype-duration")] DataType::Duration(_) => { return to_anyvalue_iterator!(series.duration().unwrap()); }, DataType::Binary => { return to_anyvalue_iterator!(series.binary().unwrap()); }, + #[cfg(feature = "dtype-decimal")] DataType::Decimal(_, _) => { return to_anyvalue_iterator!(series.decimal().unwrap()); }, From d6fa1d4a588f50cc586e1abe114883bb715aa6f9 Mon Sep 17 00:00:00 2001 From: Itamar Turner-Trauring Date: Tue, 11 Feb 2025 11:29:59 -0500 Subject: [PATCH 30/30] Feature list_index_of_in requires index_of feature. --- crates/polars-ops/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/polars-ops/Cargo.toml b/crates/polars-ops/Cargo.toml index 804a9dac4dd9..f1513588dc41 100644 --- a/crates/polars-ops/Cargo.toml +++ b/crates/polars-ops/Cargo.toml @@ -146,4 +146,4 @@ abs = [] cov = [] gather = [] replace = ["is_in"] -list_index_of_in = [] +list_index_of_in = ["index_of"]