From 65de008345731458fc56036362b3aa5da621b0b0 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 13:32:34 +0000 Subject: [PATCH 1/6] tests(python): tighten `assert_frame_equal` for LazyFrames (don't collect until after the schema has been checked) --- py-polars/polars/testing/asserts.py | 31 +++++++++++++++++------------ 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index d40e74b66bb1..37786b5aa564 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -1,7 +1,7 @@ from __future__ import annotations import textwrap -from typing import Any +from typing import Any, NoReturn from polars import functions as F from polars.dataframe import DataFrame @@ -75,24 +75,21 @@ def assert_frame_equal( >>> assert_frame_equal(df1, df2) # doctest: +SKIP AssertionError: Values for column 'a' are different. """ - if isinstance(left, LazyFrame) and isinstance(right, LazyFrame): - left, right = left.collect(), right.collect() - obj = "LazyFrames" + if collect_input_frames := ( + isinstance(left, LazyFrame) and isinstance(right, LazyFrame) + ): + objs = "LazyFrames" elif isinstance(left, DataFrame) and isinstance(right, DataFrame): - obj = "DataFrames" + objs = "DataFrames" else: raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) - if left.shape[0] != right.shape[0]: # type: ignore[union-attr] - raise_assert_detail(obj, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr] - - left_not_right = [c for c in left.columns if c not in right.columns] - if left_not_right: + if left_not_right := [c for c in left.columns if c not in right.columns]: raise AssertionError( f"columns {left_not_right!r} in left frame, but not in right" ) - right_not_left = [c for c in right.columns if c not in left.columns] - if right_not_left: + + if right_not_left := [c for c in right.columns if c not in left.columns]: raise AssertionError( f"columns {right_not_left!r} in right frame, but not in left" ) @@ -102,6 +99,14 @@ def assert_frame_equal( f"columns are not in the same order:\n{left.columns!r}\n{right.columns!r}" ) + if collect_input_frames: + if check_dtype: # check this _before_ we collect + assert left.schema == right.schema, "schema dtypes are not equal" + left, right = left.collect(), right.collect() # type: ignore[union-attr] + + if left.shape[0] != right.shape[0]: # type: ignore[union-attr] + raise_assert_detail(objs, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr] + if not check_row_order: try: left = left.sort(by=left.columns) @@ -525,7 +530,7 @@ def raise_assert_detail( left: Any, right: Any, exc: AssertionError | None = None, -) -> None: +) -> NoReturn: """Raise a detailed assertion error.""" __tracebackhide__ = True From 80a6b521925891e34a0b7e02a893f5681761926f Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 13:50:36 +0000 Subject: [PATCH 2/6] more detailed err msg --- py-polars/polars/testing/asserts.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 37786b5aa564..ea1b6c3b967d 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -101,11 +101,14 @@ def assert_frame_equal( if collect_input_frames: if check_dtype: # check this _before_ we collect - assert left.schema == right.schema, "schema dtypes are not equal" + left_schema, right_schema = left.schema, right.schema + assert ( + left_schema == right_schema + ), f"lazy schemas are not equal\nleft: {left_schema}\nright: {right_schema}" left, right = left.collect(), right.collect() # type: ignore[union-attr] if left.shape[0] != right.shape[0]: # type: ignore[union-attr] - raise_assert_detail(objs, "Length mismatch", left.shape, right.shape) # type: ignore[union-attr] + raise_assert_detail(objs, "length mismatch", left.shape, right.shape) # type: ignore[union-attr] if not check_row_order: try: From 1cf9a0f8591a512ce3636d6854dfc435a92756f3 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 13:54:23 +0000 Subject: [PATCH 3/6] update test (new check found something) --- py-polars/polars/testing/asserts.py | 2 +- py-polars/tests/unit/test_lazy.py | 5 +++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index ea1b6c3b967d..d2e1b1119a87 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -104,7 +104,7 @@ def assert_frame_equal( left_schema, right_schema = left.schema, right.schema assert ( left_schema == right_schema - ), f"lazy schemas are not equal\nleft: {left_schema}\nright: {right_schema}" + ), f"lazy schemas are not equal\n left: {left_schema}\nright: {right_schema}" left, right = left.collect(), right.collect() # type: ignore[union-attr] if left.shape[0] != right.shape[0]: # type: ignore[union-attr] diff --git a/py-polars/tests/unit/test_lazy.py b/py-polars/tests/unit/test_lazy.py index 3aac664a8f88..9adbbd2620bd 100644 --- a/py-polars/tests/unit/test_lazy.py +++ b/py-polars/tests/unit/test_lazy.py @@ -69,8 +69,9 @@ def test_lazyframe_membership_operator() -> None: def test_apply() -> None: ldf = pl.LazyFrame({"a": [1, 2, 3], "b": [1.0, 2.0, 3.0]}) - new = ldf.with_columns_seq(pl.col("a").map_batches(lambda s: s * 2).alias("foo")) - + new = ldf.with_columns_seq( + pl.col("a").map_batches(lambda s: s * 2, return_dtype=pl.Int64).alias("foo") + ) expected = ldf.clone().with_columns((pl.col("a") * 2).alias("foo")) assert_frame_equal(new, expected) assert_frame_equal(new.collect(), expected.collect()) From 496ab057b248ea42fa3544c19dc9d2891c4c1938 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 14:17:58 +0000 Subject: [PATCH 4/6] bring assert error strings more in-line with the newer exception formatting standard --- py-polars/polars/testing/asserts.py | 35 ++++++++++++++-------------- py-polars/tests/unit/test_testing.py | 28 ++++++++++++---------- 2 files changed, 33 insertions(+), 30 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index d2e1b1119a87..ca47c18141a7 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -82,7 +82,7 @@ def assert_frame_equal( elif isinstance(left, DataFrame) and isinstance(right, DataFrame): objs = "DataFrames" else: - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) if left_not_right := [c for c in left.columns if c not in right.columns]: raise AssertionError( @@ -102,9 +102,10 @@ def assert_frame_equal( if collect_input_frames: if check_dtype: # check this _before_ we collect left_schema, right_schema = left.schema, right.schema - assert ( - left_schema == right_schema - ), f"lazy schemas are not equal\n left: {left_schema}\nright: {right_schema}" + if left_schema != right_schema: + raise_assert_detail( + objs, "lazy schemas are not equal", left_schema, right_schema + ) left, right = left.collect(), right.collect() # type: ignore[union-attr] if left.shape[0] != right.shape[0]: # type: ignore[union-attr] @@ -258,13 +259,13 @@ def assert_series_equal( isinstance(left, Series) # type: ignore[redundant-expr] and isinstance(right, Series) ): - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) if len(left) != len(right): - raise_assert_detail("Series", "Length mismatch", len(left), len(right)) + raise_assert_detail("Series", "length mismatch", len(left), len(right)) if check_names and left.name != right.name: - raise_assert_detail("Series", "Name mismatch", left.name, right.name) + raise_assert_detail("Series", "name mismatch", left.name, right.name) _assert_series_inner( left, @@ -355,7 +356,7 @@ def _assert_series_inner( ) -> None: """Compare Series dtype + values.""" if check_dtype and left.dtype != right.dtype: - raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype) + raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype) if left.null_count() != right.null_count(): raise_assert_detail( @@ -406,7 +407,7 @@ def _assert_series_inner( if check_exact: raise_assert_detail( "Series", - "Exact value mismatch", + "exact value mismatch", left=list(left), right=list(right), ) @@ -442,7 +443,7 @@ def _assert_series_inner( if mismatch: raise_assert_detail( "Series", - f"Value mismatch{nan_info}", + f"value mismatch{nan_info}", left=list(left), right=list(right), ) @@ -477,10 +478,10 @@ def _assert_series_nested( s2, ) elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - raise_assert_detail("Series", "Nested value mismatch", s1, s2) + raise_assert_detail("Series", "nested value mismatch", s1, s2) elif len(s1) != len(s2): raise_assert_detail( - "Series", "Nested list length mismatch", len(s1), len(s2) + "Series", "nested list length mismatch", len(s1), len(s2) ) _assert_series_inner( @@ -501,13 +502,13 @@ def _assert_series_nested( if len(ls.columns) != len(rs.columns): raise_assert_detail( "Series", - "Nested struct fields mismatch", + "nested struct fields mismatch", len(ls.columns), len(rs.columns), ) elif len(ls) != len(rs): raise_assert_detail( - "Series", "Nested struct length mismatch", len(ls), len(rs) + "Series", "nested struct length mismatch", len(ls), len(rs) ) for s1, s2 in zip(ls, rs): _assert_series_inner( @@ -539,10 +540,8 @@ def raise_assert_detail( error_msg = textwrap.dedent( f"""\ - {obj} are different. - - {detail} - [left]: {left} + {obj} are different ({detail}) + [left]: {left} [right]: {right}\ """ ) diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 9a22d9a5a7a5..aff4391882a6 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -22,7 +22,9 @@ def test_compare_series_value_mismatch() -> None: srs2 = pl.Series([2, 3, 4]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(value mismatch\)" + ): assert_series_equal(srs1, srs2) @@ -62,9 +64,9 @@ def test_compare_series_nans_assert_equal() -> None: (True, True), ): if check_exact: - check_msg = "Exact value mismatch" + check_msg = "exact value mismatch" else: - check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}" + check_msg = f"value mismatch.*nans_compare_equal={nans_equal}" with pytest.raises(AssertionError, match=check_msg): assert_series_equal( @@ -135,7 +137,7 @@ def test_compare_series_value_mismatch_string() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2) @@ -145,20 +147,22 @@ def test_compare_series_type_mismatch() -> None: srs2 = pl.DataFrame({"col1": [2, 3, 4]}) with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] srs3 = pl.Series([1.0, 2.0, 3.0]) assert_series_not_equal(srs1, srs3) - with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(dtype mismatch\)" + ): assert_series_equal(srs1, srs3) def test_compare_series_name_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"): + with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"): assert_series_equal(srs1, srs2) @@ -168,7 +172,7 @@ def test_compare_series_shape_mismatch() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nLength mismatch" + AssertionError, match=r"Series are different \(length mismatch\)" ): assert_series_equal(srs1, srs2) @@ -177,7 +181,7 @@ def test_compare_series_value_exact_mismatch() -> None: srs1 = pl.Series([1.0, 2.0, 3.0]) srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2, check_exact=True) @@ -277,7 +281,7 @@ def test_assert_frame_equal_types() -> None: df1 = pl.DataFrame({"a": [1, 2]}) srs1 = pl.Series(values=[1, 2], name="a") with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_frame_equal(df1, srs1) # type: ignore[arg-type] @@ -286,7 +290,7 @@ def test_assert_frame_equal_length_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"a": [1, 2, 3]}) with pytest.raises( - AssertionError, match="DataFrames are different.\n\nLength mismatch" + AssertionError, match=r"DataFrames are different \(length mismatch\)" ): assert_frame_equal(df1, df2) @@ -1027,7 +1031,7 @@ def test_assert_series_equal_categorical_vs_str() -> None: s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) s2 = pl.Series(["a", "b", "a"], dtype=pl.Utf8) - with pytest.raises(AssertionError, match="Dtype mismatch"): + with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True) From 5306e8ea0eff9c178242988c6e255f65158d0578 Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 23:14:21 +0400 Subject: [PATCH 5/6] Update py-polars/polars/testing/asserts.py Co-authored-by: Stijn de Gooijer --- py-polars/polars/testing/asserts.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index ca47c18141a7..6442aa64af78 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -75,9 +75,8 @@ def assert_frame_equal( >>> assert_frame_equal(df1, df2) # doctest: +SKIP AssertionError: Values for column 'a' are different. """ - if collect_input_frames := ( - isinstance(left, LazyFrame) and isinstance(right, LazyFrame) - ): + collect_input_frames = isinstance(left, LazyFrame) and isinstance(right, LazyFrame) + if collect_input_frames: objs = "LazyFrames" elif isinstance(left, DataFrame) and isinstance(right, DataFrame): objs = "DataFrames" From 2157c4a23a72c61f34bada40b640c72a91f7699a Mon Sep 17 00:00:00 2001 From: Alexander Beedie Date: Tue, 26 Sep 2023 23:37:15 +0400 Subject: [PATCH 6/6] Update py-polars/polars/testing/asserts.py Co-authored-by: Stijn de Gooijer --- py-polars/polars/testing/asserts.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index 6442aa64af78..0cd6b57b5bc4 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -540,7 +540,7 @@ def raise_assert_detail( error_msg = textwrap.dedent( f"""\ {obj} are different ({detail}) - [left]: {left} + [left]: {left} [right]: {right}\ """ )