diff --git a/py-polars/polars/testing/asserts.py b/py-polars/polars/testing/asserts.py index d2e1b1119a87..ca47c18141a7 100644 --- a/py-polars/polars/testing/asserts.py +++ b/py-polars/polars/testing/asserts.py @@ -82,7 +82,7 @@ def assert_frame_equal( elif isinstance(left, DataFrame) and isinstance(right, DataFrame): objs = "DataFrames" else: - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) if left_not_right := [c for c in left.columns if c not in right.columns]: raise AssertionError( @@ -102,9 +102,10 @@ def assert_frame_equal( if collect_input_frames: if check_dtype: # check this _before_ we collect left_schema, right_schema = left.schema, right.schema - assert ( - left_schema == right_schema - ), f"lazy schemas are not equal\n left: {left_schema}\nright: {right_schema}" + if left_schema != right_schema: + raise_assert_detail( + objs, "lazy schemas are not equal", left_schema, right_schema + ) left, right = left.collect(), right.collect() # type: ignore[union-attr] if left.shape[0] != right.shape[0]: # type: ignore[union-attr] @@ -258,13 +259,13 @@ def assert_series_equal( isinstance(left, Series) # type: ignore[redundant-expr] and isinstance(right, Series) ): - raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right)) + raise_assert_detail("Inputs", "unexpected input types", type(left), type(right)) if len(left) != len(right): - raise_assert_detail("Series", "Length mismatch", len(left), len(right)) + raise_assert_detail("Series", "length mismatch", len(left), len(right)) if check_names and left.name != right.name: - raise_assert_detail("Series", "Name mismatch", left.name, right.name) + raise_assert_detail("Series", "name mismatch", left.name, right.name) _assert_series_inner( left, @@ -355,7 +356,7 @@ def _assert_series_inner( ) -> None: """Compare Series dtype + values.""" if check_dtype and left.dtype != right.dtype: - raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype) + raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype) if left.null_count() != right.null_count(): raise_assert_detail( @@ -406,7 +407,7 @@ def _assert_series_inner( if check_exact: raise_assert_detail( "Series", - "Exact value mismatch", + "exact value mismatch", left=list(left), right=list(right), ) @@ -442,7 +443,7 @@ def _assert_series_inner( if mismatch: raise_assert_detail( "Series", - f"Value mismatch{nan_info}", + f"value mismatch{nan_info}", left=list(left), right=list(right), ) @@ -477,10 +478,10 @@ def _assert_series_nested( s2, ) elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None): - raise_assert_detail("Series", "Nested value mismatch", s1, s2) + raise_assert_detail("Series", "nested value mismatch", s1, s2) elif len(s1) != len(s2): raise_assert_detail( - "Series", "Nested list length mismatch", len(s1), len(s2) + "Series", "nested list length mismatch", len(s1), len(s2) ) _assert_series_inner( @@ -501,13 +502,13 @@ def _assert_series_nested( if len(ls.columns) != len(rs.columns): raise_assert_detail( "Series", - "Nested struct fields mismatch", + "nested struct fields mismatch", len(ls.columns), len(rs.columns), ) elif len(ls) != len(rs): raise_assert_detail( - "Series", "Nested struct length mismatch", len(ls), len(rs) + "Series", "nested struct length mismatch", len(ls), len(rs) ) for s1, s2 in zip(ls, rs): _assert_series_inner( @@ -539,10 +540,8 @@ def raise_assert_detail( error_msg = textwrap.dedent( f"""\ - {obj} are different. - - {detail} - [left]: {left} + {obj} are different ({detail}) + [left]: {left} [right]: {right}\ """ ) diff --git a/py-polars/tests/unit/test_testing.py b/py-polars/tests/unit/test_testing.py index 9a22d9a5a7a5..aff4391882a6 100644 --- a/py-polars/tests/unit/test_testing.py +++ b/py-polars/tests/unit/test_testing.py @@ -22,7 +22,9 @@ def test_compare_series_value_mismatch() -> None: srs2 = pl.Series([2, 3, 4]) assert_series_not_equal(srs1, srs2) - with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(value mismatch\)" + ): assert_series_equal(srs1, srs2) @@ -62,9 +64,9 @@ def test_compare_series_nans_assert_equal() -> None: (True, True), ): if check_exact: - check_msg = "Exact value mismatch" + check_msg = "exact value mismatch" else: - check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}" + check_msg = f"value mismatch.*nans_compare_equal={nans_equal}" with pytest.raises(AssertionError, match=check_msg): assert_series_equal( @@ -135,7 +137,7 @@ def test_compare_series_value_mismatch_string() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2) @@ -145,20 +147,22 @@ def test_compare_series_type_mismatch() -> None: srs2 = pl.DataFrame({"col1": [2, 3, 4]}) with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_series_equal(srs1, srs2) # type: ignore[arg-type] srs3 = pl.Series([1.0, 2.0, 3.0]) assert_series_not_equal(srs1, srs3) - with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"): + with pytest.raises( + AssertionError, match=r"Series are different \(dtype mismatch\)" + ): assert_series_equal(srs1, srs3) def test_compare_series_name_mismatch() -> None: srs1 = pl.Series(values=[1, 2, 3], name="srs1") srs2 = pl.Series(values=[1, 2, 3], name="srs2") - with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"): + with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"): assert_series_equal(srs1, srs2) @@ -168,7 +172,7 @@ def test_compare_series_shape_mismatch() -> None: assert_series_not_equal(srs1, srs2) with pytest.raises( - AssertionError, match="Series are different.\n\nLength mismatch" + AssertionError, match=r"Series are different \(length mismatch\)" ): assert_series_equal(srs1, srs2) @@ -177,7 +181,7 @@ def test_compare_series_value_exact_mismatch() -> None: srs1 = pl.Series([1.0, 2.0, 3.0]) srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0]) with pytest.raises( - AssertionError, match="Series are different.\n\nExact value mismatch" + AssertionError, match=r"Series are different \(exact value mismatch\)" ): assert_series_equal(srs1, srs2, check_exact=True) @@ -277,7 +281,7 @@ def test_assert_frame_equal_types() -> None: df1 = pl.DataFrame({"a": [1, 2]}) srs1 = pl.Series(values=[1, 2], name="a") with pytest.raises( - AssertionError, match="Inputs are different.\n\nUnexpected input types" + AssertionError, match=r"Inputs are different \(unexpected input types\)" ): assert_frame_equal(df1, srs1) # type: ignore[arg-type] @@ -286,7 +290,7 @@ def test_assert_frame_equal_length_mismatch() -> None: df1 = pl.DataFrame({"a": [1, 2]}) df2 = pl.DataFrame({"a": [1, 2, 3]}) with pytest.raises( - AssertionError, match="DataFrames are different.\n\nLength mismatch" + AssertionError, match=r"DataFrames are different \(length mismatch\)" ): assert_frame_equal(df1, df2) @@ -1027,7 +1031,7 @@ def test_assert_series_equal_categorical_vs_str() -> None: s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical) s2 = pl.Series(["a", "b", "a"], dtype=pl.Utf8) - with pytest.raises(AssertionError, match="Dtype mismatch"): + with pytest.raises(AssertionError, match="dtype mismatch"): assert_series_equal(s1, s2, categorical_as_str=True)