diff --git a/py-polars/polars/lazyframe/frame.py b/py-polars/polars/lazyframe/frame.py index d78f4e45a6f2..5f615400267b 100644 --- a/py-polars/polars/lazyframe/frame.py +++ b/py-polars/polars/lazyframe/frame.py @@ -717,18 +717,18 @@ def serialize( msg = f"invalid serialization format: {format!r}" raise ValueError(msg) - def serialize_to_bytes() -> str: + def serialize_to_bytes() -> bytes: with BytesIO() as buf: serializer(buf) serialized = buf.getvalue() return serialized if file is None: - seralized = serialize_to_bytes() + serialized = serialize_to_bytes() if format == "json": - return seralized.decode() + return serialized.decode() else: - return seralized + return serialized elif isinstance(file, StringIO): serialized_str = serialize_to_bytes().decode() file.write(serialized_str) diff --git a/py-polars/tests/unit/lazyframe/test_serde.py b/py-polars/tests/unit/lazyframe/test_serde.py index bc1cd27e2020..8ff74d672491 100644 --- a/py-polars/tests/unit/lazyframe/test_serde.py +++ b/py-polars/tests/unit/lazyframe/test_serde.py @@ -14,6 +14,13 @@ from pathlib import Path +@given(lf=dataframes(lazy=True)) +def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="binary") + result = pl.LazyFrame.deserialize(io.BytesIO(serialized, format="binary")) + assert_frame_equal(result, lf, categorical_as_str=True) + + @given( lf=dataframes( lazy=True, @@ -23,9 +30,9 @@ ], ) ) -def test_lf_serde_roundtrip(lf: pl.LazyFrame) -> None: - serialized = lf.serialize() - result = pl.LazyFrame.deserialize(io.StringIO(serialized)) +def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="json") + result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json") assert_frame_equal(result, lf, categorical_as_str=True) @@ -37,17 +44,30 @@ def lf() -> pl.LazyFrame: def test_lf_serde(lf: pl.LazyFrame) -> None: serialized = lf.serialize() - assert isinstance(serialized, str) - result = pl.LazyFrame.deserialize(io.StringIO(serialized)) + assert isinstance(serialized, bytes) + result = pl.LazyFrame.deserialize(io.BytesIO(serialized)) + assert_frame_equal(result, lf) + +def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None: + serialized = lf.serialize(format="json") + assert isinstance(serialized, str) + result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json") assert_frame_equal(result, lf) -@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()]) -def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, buf: io.IOBase) -> None: - lf.serialize(buf) +@pytest.mark.parametrize( + ("format", "buf"), + [ + ("binary", io.BytesIO()), + ("json", io.StringIO()), + ("json", io.BytesIO()), + ], +) +def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, format: str, buf: io.IOBase) -> None: + lf.serialize(buf, format=format) buf.seek(0) - result = pl.LazyFrame.deserialize(buf) + result = pl.LazyFrame.deserialize(buf, format=format) assert_frame_equal(lf, result) @@ -62,8 +82,8 @@ def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None: assert_frame_equal(lf, result) -def test_lazyframe_serde_json(lf: pl.LazyFrame) -> None: - serialized = lf.serialize(format="json") - assert isinstance(serialized, bytes) - result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="json") +def test_lf_serde_enum_data() -> None: + lf = pl.LazyFrame({"a": ["a", "b", "a"]}, schema={"a": pl.Enum(["b", "a"])}) + serialized = lf.serialize() + result = pl.LazyFrame.deserialize(io.BytesIO(serialized)) assert_frame_equal(result, lf)