Skip to content

Commit

Permalink
test(python): Refactor serde tests, add hypothesis tests
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jun 26, 2024
1 parent df989de commit 1a296a0
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 65 deletions.
54 changes: 39 additions & 15 deletions py-polars/tests/unit/dataframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,32 @@
from typing import TYPE_CHECKING, Any

import pytest
from hypothesis import given

import polars as pl
from polars.exceptions import ComputeError
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

if TYPE_CHECKING:
from pathlib import Path


@given(
df=dataframes(
excluded_dtypes=[
pl.Null, # Not implemented yet
pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211
],
)
)
def test_df_serde_roundtrip(df: pl.DataFrame) -> None:
serialized = df.serialize()
result = pl.DataFrame.deserialize(io.StringIO(serialized))
assert_frame_equal(result, df, categorical_as_str=True)


def test_df_serialize() -> None:
df = pl.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}).sort("a")
result = df.serialize()
Expand All @@ -23,15 +40,15 @@ def test_df_serialize() -> None:


@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()])
def test_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None:
def test_df_serde_to_from_buffer(df: pl.DataFrame, buf: io.IOBase) -> None:
df.serialize(buf)
buf.seek(0)
read_df = pl.DataFrame.deserialize(buf)
assert_frame_equal(df, read_df, categorical_as_str=True)


@pytest.mark.write_disk()
def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
def test_df_serde_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

file_path = tmp_path / "small.json"
Expand All @@ -41,13 +58,6 @@ def test_to_from_file(df: pl.DataFrame, tmp_path: Path) -> None:
assert_frame_equal(df, out, categorical_as_str=True)


def test_write_json_to_string() -> None:
# Tests if it runs if no arg given
df = pl.DataFrame({"a": [1, 2, 3]})
expected_str = '{"columns":[{"name":"a","datatype":"Int64","bit_settings":"","values":[1,2,3]}]}'
assert df.serialize() == expected_str


def test_write_json(df: pl.DataFrame) -> None:
# Text-based conversion loses time info
df = df.select(pl.all().exclude(["cat", "time"]))
Expand Down Expand Up @@ -100,7 +110,7 @@ def test_df_serde_enum() -> None:
),
],
)
def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
def test_df_serde_array(data: Any, dtype: pl.DataType) -> None:
df = pl.DataFrame({"foo": data}, schema={"foo": dtype})
buf = io.StringIO()
df.serialize(buf)
Expand Down Expand Up @@ -135,9 +145,7 @@ def test_write_read_json_array(data: Any, dtype: pl.DataType) -> None:
),
],
)
def test_write_read_json_array_logical_inner_type(
data: Any, dtype: pl.DataType
) -> None:
def test_df_serde_array_logical_inner_type(data: Any, dtype: pl.DataType) -> None:
df = pl.DataFrame({"foo": data}, schema={"foo": dtype})
buf = io.StringIO()
df.serialize(buf)
Expand All @@ -147,14 +155,30 @@ def test_write_read_json_array_logical_inner_type(
assert deserialized_df.to_dict(as_series=False) == df.to_dict(as_series=False)


def test_json_deserialize_empty_list_10458() -> None:
def test_df_serde_empty_list_10458() -> None:
schema = {"LIST_OF_STRINGS": pl.List(pl.String)}
serialized_schema = pl.DataFrame(schema=schema).serialize()
df = pl.DataFrame.deserialize(io.StringIO(serialized_schema))
assert df.schema == schema


def test_serde_validation() -> None:
@pytest.mark.xfail(reason="Bug: https://github.com/pola-rs/polars/issues/17211")
def test_df_serde_float_inf_nan() -> None:
df = pl.DataFrame({"a": [1.0, float("inf"), float("-inf"), float("nan")]})
ser = df.serialize()
result = pl.DataFrame.deserialize(io.StringIO(ser))
assert_frame_equal(result, df)


@pytest.mark.xfail(reason="Not implemented yet")
def test_df_serde_null() -> None:
df = pl.DataFrame({"a": [None, None]})
ser = df.serialize()
result = pl.DataFrame.deserialize(io.StringIO(ser))
assert_frame_equal(result, df)


def test_df_deserialize_validation() -> None:
f = io.StringIO(
"""
{
Expand Down
45 changes: 45 additions & 0 deletions py-polars/tests/unit/expr/test_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import io

import pytest

import polars as pl
from polars.exceptions import ComputeError


def test_expr_serialization_roundtrip() -> None:
expr = pl.col("foo").sum().over("bar")
json = expr.meta.serialize()
round_tripped = pl.Expr.deserialize(io.StringIO(json))
assert round_tripped.meta == expr


def test_expr_deserialize_file_not_found() -> None:
with pytest.raises(FileNotFoundError):
pl.Expr.deserialize("abcdef")


def test_expr_deserialize_invalid_json() -> None:
with pytest.raises(
ComputeError, match="could not deserialize input into an expression"
):
pl.Expr.deserialize(io.StringIO("abcdef"))


def test_expr_write_json_from_json_deprecated() -> None:
expr = pl.col("foo").sum().over("bar")

with pytest.deprecated_call():
json = expr.meta.write_json()

with pytest.deprecated_call():
round_tripped = pl.Expr.from_json(json)

assert round_tripped.meta == expr


def test_expression_json_13991() -> None:
expr = pl.col("foo").cast(pl.Decimal)
json = expr.meta.serialize()

round_tripped = pl.Expr.deserialize(io.StringIO(json))
assert round_tripped.meta == expr
63 changes: 63 additions & 0 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
from __future__ import annotations

import io
from typing import TYPE_CHECKING

import pytest
from hypothesis import given

import polars as pl
from polars.testing import assert_frame_equal
from polars.testing.parametric import dataframes

if TYPE_CHECKING:
from pathlib import Path


@given(
lf=dataframes(
lazy=True,
excluded_dtypes=[
pl.Null, # Not implemented yet
pl.Float32, # Bug, see: https://github.com/pola-rs/polars/issues/17211
pl.Float64, # Bug, see: https://github.com/pola-rs/polars/issues/17211
],
)
)
def test_lf_serde_roundtrip(lf: pl.LazyFrame) -> None:
serialized = lf.serialize()
result = pl.LazyFrame.deserialize(io.StringIO(serialized))
assert_frame_equal(result, lf, categorical_as_str=True)


@pytest.fixture()
def lf() -> pl.LazyFrame:
"""Sample LazyFrame for testing serialization/deserialization."""
return pl.LazyFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).select("a").sum()


def test_lf_serde(lf: pl.LazyFrame) -> None:
serialized = lf.serialize()
assert isinstance(serialized, str)
result = pl.LazyFrame.deserialize(io.StringIO(serialized))

assert_frame_equal(result, lf)


@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()])
def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, buf: io.IOBase) -> None:
lf.serialize(buf)
buf.seek(0)
result = pl.LazyFrame.deserialize(buf)
assert_frame_equal(lf, result)


@pytest.mark.write_disk()
def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
tmp_path.mkdir(exist_ok=True)

file_path = tmp_path / "small.json"
lf.serialize(file_path)
result = pl.LazyFrame.deserialize(file_path)

assert_frame_equal(lf, result)
51 changes: 1 addition & 50 deletions py-polars/tests/unit/test_serde.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
from __future__ import annotations

import io
import pickle
from datetime import datetime, timedelta

import pytest

import polars as pl
from polars import StringCache
from polars.exceptions import ComputeError, SchemaError
from polars.exceptions import SchemaError
from polars.testing import assert_frame_equal, assert_series_equal


Expand All @@ -24,15 +23,6 @@ def test_pickling_as_struct_11100() -> None:
assert str(pickle.loads(buf)) == str(e)


def test_lazyframe_serde() -> None:
lf = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]}).lazy().select(pl.col("a"))

json = lf.serialize()
result = pl.LazyFrame.deserialize(io.StringIO(json))

assert_series_equal(result.collect().to_series(), pl.Series("a", [1, 2, 3]))


def test_serde_time_unit() -> None:
values = [datetime(2022, 1, 1) + timedelta(days=1) for _ in range(3)]
s = pl.Series(values).cast(pl.Datetime("ns"))
Expand Down Expand Up @@ -195,45 +185,6 @@ def test_serde_array_dtype() -> None:
assert_series_equal(pickle.loads(pickle.dumps(nested_s)), nested_s)


def test_expr_serialization_roundtrip() -> None:
expr = pl.col("foo").sum().over("bar")
json = expr.meta.serialize()
round_tripped = pl.Expr.deserialize(io.StringIO(json))
assert round_tripped.meta == expr


def test_expr_deserialize_file_not_found() -> None:
with pytest.raises(FileNotFoundError):
pl.Expr.deserialize("abcdef")


def test_expr_deserialize_invalid_json() -> None:
with pytest.raises(
ComputeError, match="could not deserialize input into an expression"
):
pl.Expr.deserialize(io.StringIO("abcdef"))


def test_expr_write_json_from_json_deprecated() -> None:
expr = pl.col("foo").sum().over("bar")

with pytest.deprecated_call():
json = expr.meta.write_json()

with pytest.deprecated_call():
round_tripped = pl.Expr.from_json(json)

assert round_tripped.meta == expr


def test_expression_json_13991() -> None:
expr = pl.col("foo").cast(pl.Decimal)
json = expr.meta.serialize()

round_tripped = pl.Expr.deserialize(io.StringIO(json))
assert round_tripped.meta == expr


def test_serde_data_type_class() -> None:
dtype = pl.Datetime
serialized = pickle.dumps(dtype)
Expand Down

0 comments on commit 1a296a0

Please sign in to comment.