diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index d40e74b66bb1..0cd6b57b5bc4 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import Any +from typing import Any, NoReturn from polars import functions as F from polars.dataframe import DataFrame @@ -75,24 +75,20 @@ def assert_frame_equal( >>> assert_frame_equal(df1, df2) # doctest: +SKIP AssertionError: Values for column 'a' are different. """ - if isinstance(left, LazyFrame) and isinstance(right, LazyFrame): - left, right = left.collect(), right.collect() - obj = "LazyFrames" + collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame) + if collect_input_frames: + objs = "LazyFrames" elif isinstance(left, DataFrame) and isinstance(right, DataFrame): - obj = "DataFrames" + objs = "DataFrames" else: - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) - if left.shape[0] != right.shape[0]: # type: ignore[union-attr] - raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr] - - left_not_right = [c for c in left.columns if c not in right.columns] - if left_not_right: + if left_not_right := [c for c in left.columns if c not in right.columns]: raise AssertionError( f"columns {left_not_right!r} in left frame, but not in right" ) - right_not_left = [c for c in right.columns if c not in left.columns] - if right_not_left: + + if right_not_left := [c for c in right.columns if c not in left.columns]: raise AssertionError( f"columns {right_not_left!r} in right frame, but not in left" ) @@ -102,6 +98,18 @@ def assert_frame_equal( f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" ) + if collect_input_frames: + if check_dtype: # check this _before_ we collect + left_schema, right_schema = left.schema, right.schema + if left_schema != right_schema: + raise_assert_detail( + objs, "lazy schemas are not equal", left_schema, right_schema + ) + left, right = left.collect(), right.collect() # type: ignore[union-attr] + + if left.shape[0] != right.shape[0]: # type: ignore[union-attr] + raise_assert_detail(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] + if not check_row_order: try: left = left.sort(by=left.columns) @@ -250,13 +258,13 @@ def assert_series_equal( isinstance(left, Series) # type: ignore[redundant-expr] and isinstance(right, Series) ): - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) if len(left) != len(right): - raise_assert_detail("Series", "Length mismatch", len(left), len(right)) + raise_assert_detail("Series", "length mismatch", len(left), len(right)) if check_names and left.name != right.name: - raise_assert_detail("Series", "Name mismatch", left.name, right.name) + raise_assert_detail("Series", "name mismatch", left.name, right.name) _assert_series_inner( left, @@ -347,7 +355,7 @@ def _assert_series_inner( ) -> None: """Compare Series dtype + values.""" if check_dtype and left.dtype != right.dtype: - raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype) + raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype) if left.null_count() != right.null_count(): raise_assert_detail( @@ -398,7 +406,7 @@ def _assert_series_inner( if check_exact: raise_assert_detail( "Series", - "Exact value mismatch", + "exact value mismatch", left=list(left), right=list(right), ) @@ -434,7 +442,7 @@ def _assert_series_inner( if mismatch: raise_assert_detail( "Series", - f"Value mismatch{nan_info}", + f"value mismatch{nan_info}", left=list(left), right=list(right), ) @@ -469,10 +477,10 @@ def _assert_series_nested( s2, ) elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - raise_assert_detail("Series", "Nested value mismatch", s1, s2) + raise_assert_detail("Series", "nested value mismatch", s1, s2) elif len(s1) != len(s2): raise_assert_detail( - "Series", "Nested list length mismatch", len(s1), len(s2) + "Series", "nested list length mismatch", len(s1), len(s2) ) _assert_series_inner( @@ -493,13 +501,13 @@ def _assert_series_nested( if len(ls.columns) != len(rs.columns): raise_assert_detail( "Series", - "Nested struct fields mismatch", + "nested struct fields mismatch", len(ls.columns), len(rs.columns), ) elif len(ls) != len(rs): raise_assert_detail( - "Series", "Nested struct length mismatch", len(ls), len(rs) + "Series", "nested struct length mismatch", len(ls), len(rs) ) for s1, s2 in zip(ls, rs): _assert_series_inner( @@ -525,15 +533,13 @@ def raise_assert_detail( left: Any, right: Any, exc: AssertionError | None = None, -) -> None: +) -> NoReturn: """Raise a detailed assertion error.""" __tracebackhide__ = True error_msg = textwrap.dedent( f"""\ - {obj} are different. - - {detail} + {obj} are different ({detail}) [left]: {left} [right]: {right}\ """ diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 3aac664a8f88..9adbbd2620bd 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -69,8 +69,9 @@ def test_lazyframe_membership_operator() -> None: def test_apply() -> None: ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) - new = ldf.with_columns_seq(pl.col("a").map_batches(lambda s: s * 2).alias("foo")) - + new = ldf.with_columns_seq( + pl.col("a").map_batches(lambda s: s * 2, return_dtype=pl.Int64).alias("foo") + ) expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo")) assert_frame_equal(new, expected) assert_frame_equal(new.collect(), expected.collect()) diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 9a22d9a5a7a5..aff4391882a6 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -22,7 +22,9 @@ def test_compare_series_value_mismatch() -> None: srs2 = pl.Series([2, 3, 4]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(value mismatch\)" + ): assert_series_equal(srs1, srs2) @@ -62,9 +64,9 @@ def test_compare_series_nans_assert_equal() -> None: (True, True), ): if check_exact: - check_msg = "Exact value mismatch" + check_msg = "exact value mismatch" else: - check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}" + check_msg = f"value mismatch.*nans_compare_equal={nans_equal}" with pytest.raises(AssertionError, match=check_msg): assert_series_equal( @@ -135,7 +137,7 @@ def test_compare_series_value_mismatch_string() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2) @@ -145,20 +147,22 @@ def test_compare_series_type_mismatch() -> None: srs2 = pl.DataFrame({"col1": [2, 3, 4]}) with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] srs3 = pl.Series([1.0, 2.0, 3.0]) assert_series_not_equal(srs1, srs3) - with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(dtype mismatch\)" + ): assert_series_equal(srs1, srs3) def test_compare_series_name_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"): + with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"): assert_series_equal(srs1, srs2) @@ -168,7 +172,7 @@ def test_compare_series_shape_mismatch() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nLength mismatch" + AssertionError, match=r"Series are different \(length mismatch\)" ): assert_series_equal(srs1, srs2) @@ -177,7 +181,7 @@ def test_compare_series_value_exact_mismatch() -> None: srs1 = pl.Series([1.0, 2.0, 3.0]) srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2, check_exact=True) @@ -277,7 +281,7 @@ def test_assert_frame_equal_types() -> None: df1 = pl.DataFrame({"a": [1, 2]}) srs1 = pl.Series(values=[1, 2], name="a") with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_frame_equal(df1, srs1) # type: ignore[arg-type] @@ -286,7 +290,7 @@ def test_assert_frame_equal_length_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"a": [1, 2, 3]}) with pytest.raises( - AssertionError, match="DataFrames are different.\n\nLength mismatch" + AssertionError, match=r"DataFrames are different \(length mismatch\)" ): assert_frame_equal(df1, df2) @@ -1027,7 +1031,7 @@ def test_assert_series_equal_categorical_vs_str() -> None: s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) s2 = pl.Series(["a", "b", "a"], dtype=pl.Utf8) - with pytest.raises(AssertionError, match="Dtype mismatch"): + with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True)