Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(pyarrow): support arrow PyCapsule interface in more places #9663

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,27 +927,6 @@ def test_self_join_memory_table(backend, con, monkeypatch):
param(
lambda: pa.table({"a": ["a"], "b": [1]}).to_batches()[0],
"df_arrow_single_batch",
marks=[
pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"duckdb",
"exasol",
"impala",
"mssql",
"mysql",
"oracle",
"postgres",
"pyspark",
"risingwave",
"snowflake",
"sqlite",
"trino",
"databricks",
]
)
],
id="pyarrow_single_batch",
),
param(
Expand Down
18 changes: 18 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1222,6 +1222,24 @@ def test_memtable_construct_from_pyarrow(backend, con, monkeypatch):
)


def test_memtable_construct_from_pyarrow_c_stream(backend, con):
    """An object exposing only ``__arrow_c_stream__`` can seed a memtable."""
    pa = pytest.importorskip("pyarrow")

    class StreamOnly:
        # Minimal wrapper: its sole Arrow-facing API is the PyCapsule
        # C-stream protocol, so ibis must go through that code path.
        def __init__(self, wrapped):
            self._wrapped = wrapped

        def __arrow_c_stream__(self, *args, **kwargs):
            return self._wrapped.__arrow_c_stream__(*args, **kwargs)

    expected = pa.table({"a": list("abc"), "b": [1, 2, 3]})

    expr = ibis.memtable(StreamOnly(expected))

    # Round-trip through the backend; ordering by "a" pins row order.
    result = con.to_pyarrow(expr.order_by("a"))
    assert result.equals(expected)


@pytest.mark.parametrize("lazy", [False, True])
def test_memtable_construct_from_polars(backend, con, lazy):
pl = pytest.importorskip("polars")
Expand Down
58 changes: 43 additions & 15 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,42 +412,55 @@ def memtable(

@lazy_singledispatch
def _memtable(
data: pd.DataFrame | Any,
data: Any,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
import pandas as pd

from ibis.formats.pandas import PandasDataFrameProxy
if hasattr(data, "__arrow_c_stream__"):
# Support objects exposing arrow's PyCapsule interface
import pyarrow as pa

if not isinstance(data, pd.DataFrame):
df = pd.DataFrame(data, columns=columns)
data = pa.table(data)
else:
df = data
import pandas as pd

data = pd.DataFrame(data, columns=columns)
return _memtable(data, columns=columns, schema=schema, name=name)


@_memtable.register("pandas.DataFrame")
def _memtable_from_pandas_dataframe(
data: pd.DataFrame,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
from ibis.formats.pandas import PandasDataFrameProxy

if df.columns.inferred_type != "string":
cols = df.columns
if data.columns.inferred_type != "string":
cols = data.columns
newcols = getattr(
schema,
"names",
(f"col{i:d}" for i in builtins.range(len(cols))),
)
df = df.rename(columns=dict(zip(cols, newcols)))
data = data.rename(columns=dict(zip(cols, newcols)))

if columns is not None:
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
if (provided_col := len(columns)) != (exist_col := len(data.columns)):
raise ValueError(
"Provided `columns` must have an entry for each column in `data`.\n"
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
)

df = df.rename(columns=dict(zip(df.columns, columns)))
data = data.rename(columns=dict(zip(data.columns, columns)))

# verify that the DataFrame has no duplicate column names because ibis
# doesn't allow that
cols = df.columns
cols = data.columns
dupes = [name for name, count in Counter(cols).items() if count > 1]
if dupes:
raise IbisInputError(
Expand All @@ -456,8 +469,8 @@ def _memtable(

op = ops.InMemoryTable(
name=name if name is not None else util.gen_name("pandas_memtable"),
schema=sch.infer(df) if schema is None else schema,
data=PandasDataFrameProxy(df),
schema=sch.infer(data) if schema is None else schema,
data=PandasDataFrameProxy(data),
)
return op.to_expr()

Expand Down Expand Up @@ -499,6 +512,21 @@ def _memtable_from_pyarrow_dataset(
).to_expr()


@_memtable.register("pyarrow.RecordBatchReader")
def _memtable_from_pyarrow_RecordBatchReader(
    data: pa.RecordBatchReader,
    *,
    name: str | None = None,
    schema: SchemaLike | None = None,
    columns: Iterable[str] | None = None,
):
    """Reject memtable construction from a streaming record-batch reader.

    A ``pyarrow.RecordBatchReader`` is a one-shot stream; building an
    in-memory table from it would silently drain the stream and load all of
    its data into memory. Callers must opt in explicitly with
    ``ibis.memtable(reader.read_all())``.

    Parameters
    ----------
    data
        The streaming reader the caller attempted to materialize.
    name, schema, columns
        Accepted for signature parity with the other ``_memtable``
        registrations; never used because this always raises.

    Raises
    ------
    TypeError
        Always, pointing the caller at the explicit alternative.
    """
    raise TypeError(
        "Creating an `ibis.memtable` from a `pyarrow.RecordBatchReader` would "
        "load _all_ data into memory. If you want to do this, please do so "
        "explicitly like `ibis.memtable(reader.read_all())`"
    )


@_memtable.register("polars.LazyFrame")
def _memtable_from_polars_lazyframe(data: pl.LazyFrame, **kwargs):
    # Materialize the lazy query plan with `.collect()`, then delegate to the
    # eager polars DataFrame registration for the actual memtable construction.
    return _memtable_from_polars_dataframe(data.collect(), **kwargs)
Expand Down
Loading