Skip to content

Commit

Permalink
bring assert error strings more in-line with the newer exception form…
Browse files Browse the repository at this point in the history
…atting standard
  • Loading branch information
alexander-beedie committed Sep 26, 2023
1 parent 1cf9a0f commit 496ab05
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 30 deletions.
35 changes: 17 additions & 18 deletions py-polars/polars/testing/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def assert_frame_equal(
elif isinstance(left, DataFrame) and isinstance(right, DataFrame):
objs = "DataFrames"
else:
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))
raise_assert_detail("Inputs", "unexpected input types", type(left), type(right))

if left_not_right := [c for c in left.columns if c not in right.columns]:
raise AssertionError(
Expand All @@ -102,9 +102,10 @@ def assert_frame_equal(
if collect_input_frames:
if check_dtype: # check this _before_ we collect
left_schema, right_schema = left.schema, right.schema
assert (
left_schema == right_schema
), f"lazy schemas are not equal\n left: {left_schema}\nright: {right_schema}"
if left_schema != right_schema:
raise_assert_detail(
objs, "lazy schemas are not equal", left_schema, right_schema
)
left, right = left.collect(), right.collect() # type: ignore[union-attr]

if left.shape[0] != right.shape[0]: # type: ignore[union-attr]
Expand Down Expand Up @@ -258,13 +259,13 @@ def assert_series_equal(
isinstance(left, Series) # type: ignore[redundant-expr]
and isinstance(right, Series)
):
raise_assert_detail("Inputs", "Unexpected input types", type(left), type(right))
raise_assert_detail("Inputs", "unexpected input types", type(left), type(right))

if len(left) != len(right):
raise_assert_detail("Series", "Length mismatch", len(left), len(right))
raise_assert_detail("Series", "length mismatch", len(left), len(right))

if check_names and left.name != right.name:
raise_assert_detail("Series", "Name mismatch", left.name, right.name)
raise_assert_detail("Series", "name mismatch", left.name, right.name)

_assert_series_inner(
left,
Expand Down Expand Up @@ -355,7 +356,7 @@ def _assert_series_inner(
) -> None:
"""Compare Series dtype + values."""
if check_dtype and left.dtype != right.dtype:
raise_assert_detail("Series", "Dtype mismatch", left.dtype, right.dtype)
raise_assert_detail("Series", "dtype mismatch", left.dtype, right.dtype)

if left.null_count() != right.null_count():
raise_assert_detail(
Expand Down Expand Up @@ -406,7 +407,7 @@ def _assert_series_inner(
if check_exact:
raise_assert_detail(
"Series",
"Exact value mismatch",
"exact value mismatch",
left=list(left),
right=list(right),
)
Expand Down Expand Up @@ -442,7 +443,7 @@ def _assert_series_inner(
if mismatch:
raise_assert_detail(
"Series",
f"Value mismatch{nan_info}",
f"value mismatch{nan_info}",
left=list(left),
right=list(right),
)
Expand Down Expand Up @@ -477,10 +478,10 @@ def _assert_series_nested(
s2,
)
elif (s1 is None and s2 is not None) or (s2 is None and s1 is not None):
raise_assert_detail("Series", "Nested value mismatch", s1, s2)
raise_assert_detail("Series", "nested value mismatch", s1, s2)
elif len(s1) != len(s2):
raise_assert_detail(
"Series", "Nested list length mismatch", len(s1), len(s2)
"Series", "nested list length mismatch", len(s1), len(s2)
)

_assert_series_inner(
Expand All @@ -501,13 +502,13 @@ def _assert_series_nested(
if len(ls.columns) != len(rs.columns):
raise_assert_detail(
"Series",
"Nested struct fields mismatch",
"nested struct fields mismatch",
len(ls.columns),
len(rs.columns),
)
elif len(ls) != len(rs):
raise_assert_detail(
"Series", "Nested struct length mismatch", len(ls), len(rs)
"Series", "nested struct length mismatch", len(ls), len(rs)
)
for s1, s2 in zip(ls, rs):
_assert_series_inner(
Expand Down Expand Up @@ -539,10 +540,8 @@ def raise_assert_detail(

error_msg = textwrap.dedent(
f"""\
{obj} are different.
{detail}
[left]: {left}
{obj} are different ({detail})
[left]: {left}
[right]: {right}\
"""
)
Expand Down
28 changes: 16 additions & 12 deletions py-polars/tests/unit/test_testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@ def test_compare_series_value_mismatch() -> None:
srs2 = pl.Series([2, 3, 4])

assert_series_not_equal(srs1, srs2)
with pytest.raises(AssertionError, match="Series are different.\n\nValue mismatch"):
with pytest.raises(
AssertionError, match=r"Series are different \(value mismatch\)"
):
assert_series_equal(srs1, srs2)


Expand Down Expand Up @@ -62,9 +64,9 @@ def test_compare_series_nans_assert_equal() -> None:
(True, True),
):
if check_exact:
check_msg = "Exact value mismatch"
check_msg = "exact value mismatch"
else:
check_msg = f"Value mismatch.*nans_compare_equal={nans_equal}"
check_msg = f"value mismatch.*nans_compare_equal={nans_equal}"

with pytest.raises(AssertionError, match=check_msg):
assert_series_equal(
Expand Down Expand Up @@ -135,7 +137,7 @@ def test_compare_series_value_mismatch_string() -> None:

assert_series_not_equal(srs1, srs2)
with pytest.raises(
AssertionError, match="Series are different.\n\nExact value mismatch"
AssertionError, match=r"Series are different \(exact value mismatch\)"
):
assert_series_equal(srs1, srs2)

Expand All @@ -145,20 +147,22 @@ def test_compare_series_type_mismatch() -> None:
srs2 = pl.DataFrame({"col1": [2, 3, 4]})

with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
AssertionError, match=r"Inputs are different \(unexpected input types\)"
):
assert_series_equal(srs1, srs2) # type: ignore[arg-type]

srs3 = pl.Series([1.0, 2.0, 3.0])
assert_series_not_equal(srs1, srs3)
with pytest.raises(AssertionError, match="Series are different.\n\nDtype mismatch"):
with pytest.raises(
AssertionError, match=r"Series are different \(dtype mismatch\)"
):
assert_series_equal(srs1, srs3)


def test_compare_series_name_mismatch() -> None:
srs1 = pl.Series(values=[1, 2, 3], name="srs1")
srs2 = pl.Series(values=[1, 2, 3], name="srs2")
with pytest.raises(AssertionError, match="Series are different.\n\nName mismatch"):
with pytest.raises(AssertionError, match=r"Series are different \(name mismatch\)"):
assert_series_equal(srs1, srs2)


Expand All @@ -168,7 +172,7 @@ def test_compare_series_shape_mismatch() -> None:

assert_series_not_equal(srs1, srs2)
with pytest.raises(
AssertionError, match="Series are different.\n\nLength mismatch"
AssertionError, match=r"Series are different \(length mismatch\)"
):
assert_series_equal(srs1, srs2)

Expand All @@ -177,7 +181,7 @@ def test_compare_series_value_exact_mismatch() -> None:
srs1 = pl.Series([1.0, 2.0, 3.0])
srs2 = pl.Series([1.0, 2.0 + 1e-7, 3.0])
with pytest.raises(
AssertionError, match="Series are different.\n\nExact value mismatch"
AssertionError, match=r"Series are different \(exact value mismatch\)"
):
assert_series_equal(srs1, srs2, check_exact=True)

Expand Down Expand Up @@ -277,7 +281,7 @@ def test_assert_frame_equal_types() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
srs1 = pl.Series(values=[1, 2], name="a")
with pytest.raises(
AssertionError, match="Inputs are different.\n\nUnexpected input types"
AssertionError, match=r"Inputs are different \(unexpected input types\)"
):
assert_frame_equal(df1, srs1) # type: ignore[arg-type]

Expand All @@ -286,7 +290,7 @@ def test_assert_frame_equal_length_mismatch() -> None:
df1 = pl.DataFrame({"a": [1, 2]})
df2 = pl.DataFrame({"a": [1, 2, 3]})
with pytest.raises(
AssertionError, match="DataFrames are different.\n\nLength mismatch"
AssertionError, match=r"DataFrames are different \(length mismatch\)"
):
assert_frame_equal(df1, df2)

Expand Down Expand Up @@ -1027,7 +1031,7 @@ def test_assert_series_equal_categorical_vs_str() -> None:
s1 = pl.Series(["a", "b", "a"], dtype=pl.Categorical)
s2 = pl.Series(["a", "b", "a"], dtype=pl.Utf8)

with pytest.raises(AssertionError, match="Dtype mismatch"):
with pytest.raises(AssertionError, match="dtype mismatch"):
assert_series_equal(s1, s2, categorical_as_str=True)


Expand Down

0 comments on commit 496ab05

Please sign in to comment.