Skip to content

Commit

Permalink
Don't hold bare capsules
Browse files Browse the repository at this point in the history
  • Loading branch information
kylebarron committed Jul 23, 2024
1 parent ecf5334 commit b309a57
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 38 deletions.
21 changes: 2 additions & 19 deletions py-polars/tests/unit/interop/test_interop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from polars.exceptions import ComputeError, UnstableWarning
from polars.interchange.protocol import CompatLevel
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder


def test_arrow_list_roundtrip() -> None:
Expand Down Expand Up @@ -752,26 +753,8 @@ def test_compat_level(monkeypatch: pytest.MonkeyPatch) -> None:


def test_df_pycapsule_interface() -> None:
class PyCapsuleStreamHolder:
"""
Hold the Arrow C Stream pycapsule.
A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
This ensures that pyarrow is seeing _only_ the `__arrow_c_stream__` dunder, and
that nothing else (e.g. the dataframe or array interface) is actually being
used.
"""

capsule: object

def __init__(self, capsule: object) -> None:
self.capsule = capsule

def __arrow_c_stream__(self, requested_schema: object) -> object:
return self.capsule

df = pl.DataFrame({"a": [1, 2, 3], "b": ["a", "b", "c"]})
out = pa.table(PyCapsuleStreamHolder(df.__arrow_c_stream__(None)))
out = pa.table(PyCapsuleStreamHolder(df))
assert df.shape == out.shape
assert df.schema.names() == out.schema.names

Expand Down
21 changes: 2 additions & 19 deletions py-polars/tests/unit/series/test_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
ShapeError,
)
from polars.testing import assert_frame_equal, assert_series_equal
from tests.unit.utils.pycapsule_utils import PyCapsuleStreamHolder

if TYPE_CHECKING:
from zoneinfo import ZoneInfo
Expand Down Expand Up @@ -629,26 +630,8 @@ def test_arrow() -> None:


def test_pycapsule_interface() -> None:
class PyCapsuleSeriesHolder:
"""
Hold the Arrow C Stream pycapsule.
A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
This ensures that pyarrow is seeing _only_ the `__arrow_c_stream__` dunder, and
that nothing else (e.g. the dataframe or array interface) is actually being
used.
"""

capsule: object

def __init__(self, capsule: object):
self.capsule = capsule

def __arrow_c_stream__(self, requested_schema: object) -> object:
return self.capsule

a = pl.Series("a", [1, 2, 3, None])
out = pa.chunked_array(PyCapsuleSeriesHolder(a.__arrow_c_stream__(None)))
out = pa.chunked_array(PyCapsuleStreamHolder(a))
out_arr = out.combine_chunks()
assert out_arr == pa.array([1, 2, 3, None])

Expand Down
22 changes: 22 additions & 0 deletions py-polars/tests/unit/utils/pycapsule_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from typing import Any


class PyCapsuleStreamHolder:
    """
    Wrap an object so that only its Arrow C Stream interface is visible.

    The holder stores the wrapped object itself, not a capsule taken from it,
    and forwards every ``__arrow_c_stream__`` call to that object.

    A class that exposes _only_ the Arrow C Stream interface via Arrow PyCapsules.
    This ensures that pyarrow is seeing _only_ the `__arrow_c_stream__` dunder, and
    that nothing else (e.g. the dataframe or array interface) is actually being
    used.

    This is used by tests across multiple files.
    """

    # The wrapped object; it must provide an `__arrow_c_stream__` method.
    arrow_obj: Any

    def __init__(self, arrow_obj: object) -> None:
        self.arrow_obj = arrow_obj

    def __arrow_c_stream__(self, requested_schema: object = None) -> object:
        """
        Return whatever the wrapped object's `__arrow_c_stream__` returns.

        `requested_schema` is passed through unchanged. The wrapped object is
        called on every invocation; no result is cached on the holder.
        """
        return self.arrow_obj.__arrow_c_stream__(requested_schema)

0 comments on commit b309a57

Please sign in to comment.