Skip to content

Commit

Permalink
Update LazyFrame
Browse files Browse the repository at this point in the history
  • Loading branch information
stinodego committed Jun 26, 2024
1 parent 7cd939d commit 6f6675d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 17 deletions.
8 changes: 4 additions & 4 deletions py-polars/polars/lazyframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -717,18 +717,18 @@ def serialize(
msg = f"invalid serialization format: {format!r}"
raise ValueError(msg)

def serialize_to_bytes() -> str:
def serialize_to_bytes() -> bytes:
with BytesIO() as buf:
serializer(buf)
serialized = buf.getvalue()
return serialized

if file is None:
seralized = serialize_to_bytes()
serialized = serialize_to_bytes()
if format == "json":
return seralized.decode()
return serialized.decode()
else:
return seralized
return serialized
elif isinstance(file, StringIO):
serialized_str = serialize_to_bytes().decode()
file.write(serialized_str)
Expand Down
46 changes: 33 additions & 13 deletions py-polars/tests/unit/lazyframe/test_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@
from pathlib import Path


@given(lf=dataframes(lazy=True))
def test_lf_serde_roundtrip_binary(lf: pl.LazyFrame) -> None:
    """Round-trip generated LazyFrames through the binary serialization format."""
    serialized = lf.serialize(format="binary")
    # `format` is an argument of `deserialize`, not of `io.BytesIO`: the buffer
    # constructor only takes the initial bytes and rejects unknown keywords.
    result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="binary")
    assert_frame_equal(result, lf, categorical_as_str=True)


@given(
lf=dataframes(
lazy=True,
Expand All @@ -23,9 +30,9 @@
],
)
)
def test_lf_serde_roundtrip(lf: pl.LazyFrame) -> None:
serialized = lf.serialize()
result = pl.LazyFrame.deserialize(io.StringIO(serialized))
def test_lf_serde_roundtrip_json(lf: pl.LazyFrame) -> None:
serialized = lf.serialize(format="json")
result = pl.LazyFrame.deserialize(io.StringIO(serialized), format="json")
assert_frame_equal(result, lf, categorical_as_str=True)


Expand All @@ -37,17 +44,30 @@ def lf() -> pl.LazyFrame:

def test_lf_serde(lf: pl.LazyFrame) -> None:
serialized = lf.serialize()
assert isinstance(serialized, str)
result = pl.LazyFrame.deserialize(io.StringIO(serialized))
assert isinstance(serialized, bytes)
result = pl.LazyFrame.deserialize(io.BytesIO(serialized))
assert_frame_equal(result, lf)


def test_lf_serde_json_stringio(lf: pl.LazyFrame) -> None:
    """Check JSON output is a ``str`` that deserializes from a ``StringIO``."""
    json_text = lf.serialize(format="json")
    assert isinstance(json_text, str)

    text_buffer = io.StringIO(json_text)
    restored = pl.LazyFrame.deserialize(text_buffer, format="json")
    assert_frame_equal(restored, lf)


@pytest.mark.parametrize("buf", [io.BytesIO(), io.StringIO()])
def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, buf: io.IOBase) -> None:
lf.serialize(buf)
@pytest.mark.parametrize(
("format", "buf"),
[
("binary", io.BytesIO()),
("json", io.StringIO()),
("json", io.BytesIO()),
],
)
def test_lf_serde_to_from_buffer(lf: pl.LazyFrame, format: str, buf: io.IOBase) -> None:
lf.serialize(buf, format=format)
buf.seek(0)
result = pl.LazyFrame.deserialize(buf)
result = pl.LazyFrame.deserialize(buf, format=format)
assert_frame_equal(lf, result)


Expand All @@ -62,8 +82,8 @@ def test_lf_serde_to_from_file(lf: pl.LazyFrame, tmp_path: Path) -> None:
assert_frame_equal(lf, result)


def test_lazyframe_serde_json(lf: pl.LazyFrame) -> None:
serialized = lf.serialize(format="json")
assert isinstance(serialized, bytes)
result = pl.LazyFrame.deserialize(io.BytesIO(serialized), format="json")
def test_lf_serde_enum_data() -> None:
    """Check that an Enum-typed column survives a serialize/deserialize round trip."""
    enum_dtype = pl.Enum(["b", "a"])
    frame = pl.LazyFrame({"a": ["a", "b", "a"]}, schema={"a": enum_dtype})

    # No `format` is passed on either side, so both calls use their defaults.
    payload = io.BytesIO(frame.serialize())
    restored = pl.LazyFrame.deserialize(payload)
    assert_frame_equal(restored, frame)

0 comments on commit 6f6675d

Please sign in to comment.