Skip to content

Commit

Permalink
fix(python): Handle non-Sequence iterables in filter (pola-rs#16254)
Browse files (browse the repository at this point in the history)
  • Loading branch information
dangotbanned authored and Wouittone committed Jun 22, 2024
1 parent 1e8cc75 commit eddc944
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 3 deletions.
9 changes: 7 additions & 2 deletions py-polars/polars/_utils/various.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,16 @@
from polars.dependencies import numpy as np

if TYPE_CHECKING:
from collections.abc import Reversible
from collections.abc import Iterator, Reversible

from polars import DataFrame
from polars.type_aliases import PolarsDataType, SizeUnit

if sys.version_info >= (3, 13):
from typing import TypeIs
else:
from typing_extensions import TypeIs

if sys.version_info >= (3, 10):
from typing import ParamSpec, TypeGuard
else:
Expand All @@ -56,7 +61,7 @@ def _process_null_values(
return null_values


def _is_generator(val: object) -> bool:
def _is_generator(val: object | Iterator[T]) -> TypeIs[Iterator[T]]:
return (
(isinstance(val, (Generator, Iterable)) and not isinstance(val, Sized))
or isinstance(val, MappingView)
Expand Down
5 changes: 4 additions & 1 deletion py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from polars._utils.unstable import issue_unstable_warning, unstable
from polars._utils.various import (
_in_notebook,
_is_generator,
is_bool_sequence,
is_sequence,
normalize_filepath,
Expand Down Expand Up @@ -2798,7 +2799,9 @@ def filter(
return self.clear() # type: ignore[return-value]
elif p is True:
continue # no-op; matches all rows
elif is_bool_sequence(p, include_series=True):
if _is_generator(p):
p = tuple(p)
if is_bool_sequence(p, include_series=True):
boolean_masks.append(pl.Series(p, dtype=Boolean))
elif (
(is_seq := is_sequence(p))
Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,28 @@ def test_filter_multiple_predicates() -> None:
assert ldf.filter(predicate="==").select("description").collect().item() == "eq"


@pytest.mark.parametrize(
    "predicate",
    [
        [pl.lit(True)],
        iter([pl.lit(True)]),
        [True, True, True],
        iter([True, True, True]),
        (p for p in (pl.col("c") < 9,)),
        (p for p in (pl.col("a") > 0, pl.col("b") > 0)),
    ],
)
def test_filter_seq_iterable_all_true(predicate: Any) -> None:
    """Check that an all-true predicate collection leaves the frame unchanged.

    The parametrized cases pass the predicates both as sized sequences
    (lists) and as one-shot iterables (``iter(...)`` and generator
    expressions), holding either expressions or plain booleans. Every
    predicate matches all three rows of the frame built below, so the
    filtered frame is expected to equal the unfiltered one.
    """
    # Every value satisfies the parametrized expressions:
    # "a" > 0, "b" > 0 and "c" < 9 hold for all rows.
    columns = {
        "a": [1, 1, 1],
        "b": [1, 1, 2],
        "c": [3, 1, 2],
    }
    frame = pl.LazyFrame(columns)
    filtered = frame.filter(predicate)
    assert_frame_equal(frame, filtered)


def test_apply_custom_function() -> None:
ldf = pl.LazyFrame(
{
Expand Down

0 comments on commit eddc944

Please sign in to comment.