Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

test(python): tighten assert_frame_equal for LazyFrames (don't collect until after the schema has been checked) #11331

Merged
Merged
60 changes: 33 additions & 27 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import textwrap
from typing import Any
from typing import Any, NoReturn

from polars import functions as F
from polars.dataframe import DataFrame
Expand Down Expand Up @@ -75,24 +75,20 @@ def assert_frame_equal(
>>> assert_frame_equal(df1, df2) # doctest: +SKIP
AssertionError: Values for column 'a' are different.
"""
if isinstance(left, LazyFrame) and isinstance(right, LazyFrame):
left, right = left.collect(), right.collect()
obj = "LazyFrames"
collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame)
if collect_input_frames:
objs = "LazyFrames"
elif isinstance(left, DataFrame) and isinstance(right, DataFrame):
obj = "DataFrames"
objs = "DataFrames"
else:
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))
raise_assert_detail("Inputs", "unexpected input types", type(left), type(right))

if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr]

left_not_right = [c for c in left.columns if c not in right.columns]
if left_not_right:
if left_not_right := [c for c in left.columns if c not in right.columns]:
alexander-beedie marked this conversation as resolved.
Show resolved Hide resolved
raise AssertionError(
f"columns {left_not_right!r} in left frame, but not in right"
)
right_not_left = [c for c in right.columns if c not in left.columns]
if right_not_left:

if right_not_left := [c for c in right.columns if c not in left.columns]:
raise AssertionError(
f"columns {right_not_left!r} in right frame, but not in left"
)
Expand All @@ -102,6 +98,18 @@ def assert_frame_equal(
f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}"
)

if collect_input_frames:
if check_dtype: # check this _before_ we collect
left_schema, right_schema = left.schema, right.schema
if left_schema != right_schema:
raise_assert_detail(
objs, "lazy schemas are not equal", left_schema, right_schema
)
left, right = left.collect(), right.collect() # type: ignore[union-attr]

if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
raise_assert_detail(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr]

if not check_row_order:
try:
left = left.sort(by=left.columns)
Expand Down Expand Up @@ -250,13 +258,13 @@ def assert_series_equal(
isinstance(left, Series) # type: ignore[redundant-expr]
and isinstance(right, Series)
):
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))
raise_assert_detail("Inputs", "unexpected input types", type(left), type(right))

if len(left) != len(right):
raise_assert_detail("Series", "Length mismatch", len(left), len(right))
raise_assert_detail("Series", "length mismatch", len(left), len(right))

if check_names and left.name != right.name:
raise_assert_detail("Series", "Name mismatch", left.name, right.name)
raise_assert_detail("Series", "name mismatch", left.name, right.name)

_assert_series_inner(
left,
Expand Down Expand Up @@ -347,7 +355,7 @@ def _assert_series_inner(
) -> None:
"""Compare Series dtype + values."""
if check_dtype and left.dtype != right.dtype:
raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype)
raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype)

if left.null_count() != right.null_count():
raise_assert_detail(
Expand Down Expand Up @@ -398,7 +406,7 @@ def _assert_series_inner(
if check_exact:
raise_assert_detail(
"Series",
"Exact value mismatch",
"exact value mismatch",
left=list(left),
right=list(right),
)
Expand Down Expand Up @@ -434,7 +442,7 @@ def _assert_series_inner(
if mismatch:
raise_assert_detail(
"Series",
f"Value mismatch{nan_info}",
f"value mismatch{nan_info}",
left=list(left),
right=list(right),
)
Expand Down Expand Up @@ -469,10 +477,10 @@ def _assert_series_nested(
s2,
)
elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None):
raise_assert_detail("Series", "Nested value mismatch", s1, s2)
raise_assert_detail("Series", "nested value mismatch", s1, s2)
elif len(s1) != len(s2):
raise_assert_detail(
"Series", "Nested list length mismatch", len(s1), len(s2)
"Series", "nested list length mismatch", len(s1), len(s2)
)

_assert_series_inner(
Expand All @@ -493,13 +501,13 @@ def _assert_series_nested(
if len(ls.columns) != len(rs.columns):
raise_assert_detail(
"Series",
"Nested struct fields mismatch",
"nested struct fields mismatch",
len(ls.columns),
len(rs.columns),
)
elif len(ls) != len(rs):
raise_assert_detail(
"Series", "Nested struct length mismatch", len(ls), len(rs)
"Series", "nested struct length mismatch", len(ls), len(rs)
)
for s1, s2 in zip(ls, rs):
_assert_series_inner(
Expand All @@ -525,15 +533,13 @@ def raise_assert_detail(
left: Any,
right: Any,
exc: AssertionError | None = None,
) -> None:
) -> NoReturn:
"""Raise a detailed assertion error."""
__tracebackhide__ = True

error_msg = textwrap.dedent(
f"""\
{obj} are different.

{detail}
{obj} are different ({detail})
[left]: {left}
[right]: {right}\
"""
Expand Down
5 changes: 3 additions & 2 deletions py-polars/tests/unit/test_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,9 @@ def test_lazyframe_membership_operator() -> None:

def test_apply() -> None:
ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]})
new = ldf.with_columns_seq(pl.col("a").map_batches(lambda s: s * 2).alias("foo"))

new = ldf.with_columns_seq(
pl.col("a").map_batches(lambda s: s * 2, return_dtype=pl.Int64).alias("foo")
)
expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo"))
assert_frame_equal(new, expected)
assert_frame_equal(new.collect(), expected.collect())
Expand Down
28 changes: 16 additions & 12 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def test_compare_series_value_mismatch() -> None:
srs2 = pl.Series([2, 3, 4])

assert_series_not_equal(srs1, srs2)
with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"):
with pytest.raises(
AssertionError, match=r"Series are different \(value mismatch\)"
):
assert_series_equal(srs1, srs2)


Expand Down Expand Up @@ -62,9 +64,9 @@ def test_compare_series_nans_assert_equal() -> None:
(True, True),
):
if check_exact:
check_msg = "Exact value mismatch"
check_msg = "exact value mismatch"
else:
check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}"
check_msg = f"value mismatch.*nans_compare_equal={nans_equal}"

with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
Expand Down Expand Up @@ -135,7 +137,7 @@ def test_compare_series_value_mismatch_string() -> None:

assert_series_not_equal(srs1, srs2)
with pytest.raises(
AssertionError, match="Series are different.\n\nExact value mismatch"
AssertionError, match=r"Series are different \(exact value mismatch\)"
):
assert_series_equal(srs1, srs2)

Expand All @@ -145,20 +147,22 @@ def test_compare_series_type_mismatch() -> None:
srs2 = pl.DataFrame({"col1": [2, 3, 4]})

with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
AssertionError, match=r"Inputs are different \(unexpected input types\)"
):
assert_series_equal(srs1, srs2) # type: ignore[arg-type]

srs3 = pl.Series([1.0, 2.0, 3.0])
assert_series_not_equal(srs1, srs3)
with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"):
with pytest.raises(
AssertionError, match=r"Series are different \(dtype mismatch\)"
):
assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"):
with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"):
assert_series_equal(srs1, srs2)


Expand All @@ -168,7 +172,7 @@ def test_compare_series_shape_mismatch() -> None:

assert_series_not_equal(srs1, srs2)
with pytest.raises(
AssertionError, match="Series are different.\n\nLength mismatch"
AssertionError, match=r"Series are different \(length mismatch\)"
):
assert_series_equal(srs1, srs2)

Expand All @@ -177,7 +181,7 @@ def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different.\n\nExact value mismatch"
AssertionError, match=r"Series are different \(exact value mismatch\)"
):
assert_series_equal(srs1, srs2, check_exact=True)

Expand Down Expand Up @@ -277,7 +281,7 @@ def test_assert_frame_equal_types() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
AssertionError, match=r"Inputs are different \(unexpected input types\)"
):
assert_frame_equal(df1, srs1) # type: ignore[arg-type]

Expand All @@ -286,7 +290,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(
AssertionError, match="DataFrames are different.\n\nLength mismatch"
AssertionError, match=r"DataFrames are different \(length mismatch\)"
):
assert_frame_equal(df1, df2)

Expand Down Expand Up @@ -1027,7 +1031,7 @@ def test_assert_series_equal_categorical_vs_str() -> None:
s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
s2 = pl.Series(["a", "b", "a"], dtype=pl.Utf8)

with pytest.raises(AssertionError, match="Dtype mismatch"):
with pytest.raises(AssertionError, match="dtype mismatch"):
assert_series_equal(s1, s2, categorical_as_str=True)


Expand Down