Skip to content

Commit

Permalink
feat(pyarrow): support objects implementing `__arrow_c_stream__` in `ibis.memtable`
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Jul 22, 2024
1 parent c9ff617 commit 28aff8a
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 37 deletions.
23 changes: 1 addition & 22 deletions ibis/backends/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -923,28 +923,7 @@ def test_self_join_memory_table(backend, con, monkeypatch):
param(
lambda: pa.table({"a": ["a"], "b": [1]}).to_batches()[0],
"df_arrow_single_batch",
marks=[
pytest.mark.notimpl(
[
"bigquery",
"clickhouse",
"dask",
"duckdb",
"exasol",
"impala",
"mssql",
"mysql",
"oracle",
"pandas",
"postgres",
"pyspark",
"risingwave",
"snowflake",
"sqlite",
"trino",
]
)
],
marks=[pytest.mark.notimpl(["dask", "pandas"])],
id="pyarrow_single_batch",
),
param(
Expand Down
18 changes: 18 additions & 0 deletions ibis/backends/tests/test_generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1334,6 +1334,24 @@ def test_memtable_construct_from_pyarrow(backend, con, monkeypatch):
)


def test_memtable_construct_from_pyarrow_c_stream(backend, con):
    """`ibis.memtable` accepts any object exposing the Arrow PyCapsule
    stream protocol (``__arrow_c_stream__``), not just pyarrow types."""
    pa = pytest.importorskip("pyarrow")

    class StreamOnly:
        # Deliberately opaque wrapper: exposes nothing but the
        # ``__arrow_c_stream__`` protocol method, delegating to the
        # wrapped pyarrow table.
        def __init__(self, table):
            self._table = table

        def __arrow_c_stream__(self, *args, **kwargs):
            return self._table.__arrow_c_stream__(*args, **kwargs)

    expected = pa.table({"a": ["a", "b", "c"], "b": [1, 2, 3]})

    expr = ibis.memtable(StreamOnly(expected))

    result = con.to_pyarrow(expr.order_by("a"))
    assert result.equals(expected)


@pytest.mark.parametrize("lazy", [False, True])
def test_memtable_construct_from_polars(backend, con, lazy):
pl = pytest.importorskip("polars")
Expand Down
58 changes: 43 additions & 15 deletions ibis/expr/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,42 +464,55 @@ def memtable(

@lazy_singledispatch
def _memtable(
data: pd.DataFrame | Any,
data: Any,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
import pandas as pd

from ibis.formats.pandas import PandasDataFrameProxy
if hasattr(data, "__arrow_c_stream__"):
# Support objects exposing arrow's PyCapsule interface
import pyarrow as pa

if not isinstance(data, pd.DataFrame):
df = pd.DataFrame(data, columns=columns)
data = pa.table(data)
else:
df = data
import pandas as pd

data = pd.DataFrame(data, columns=columns)
return _memtable(data, columns=columns, schema=schema, name=name)


@_memtable.register("pandas.DataFrame")
def _memtable_from_pandas_dataframe(
data: pd.DataFrame,
*,
columns: Iterable[str] | None = None,
schema: SchemaLike | None = None,
name: str | None = None,
) -> Table:
from ibis.formats.pandas import PandasDataFrameProxy

if df.columns.inferred_type != "string":
cols = df.columns
if data.columns.inferred_type != "string":
cols = data.columns
newcols = getattr(
schema,
"names",
(f"col{i:d}" for i in builtins.range(len(cols))),
)
df = df.rename(columns=dict(zip(cols, newcols)))
data = data.rename(columns=dict(zip(cols, newcols)))

if columns is not None:
if (provided_col := len(columns)) != (exist_col := len(df.columns)):
if (provided_col := len(columns)) != (exist_col := len(data.columns)):
raise ValueError(
"Provided `columns` must have an entry for each column in `data`.\n"
f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
)

df = df.rename(columns=dict(zip(df.columns, columns)))
data = data.rename(columns=dict(zip(data.columns, columns)))

# verify that the DataFrame has no duplicate column names because ibis
# doesn't allow that
cols = df.columns
cols = data.columns
dupes = [name for name, count in Counter(cols).items() if count > 1]
if dupes:
raise IbisInputError(
Expand All @@ -508,8 +521,8 @@ def _memtable(

op = ops.InMemoryTable(
name=name if name is not None else util.gen_name("pandas_memtable"),
schema=sch.infer(df) if schema is None else schema,
data=PandasDataFrameProxy(df),
schema=sch.infer(data) if schema is None else schema,
data=PandasDataFrameProxy(data),
)
return op.to_expr()

Expand All @@ -534,6 +547,21 @@ def _memtable_from_pyarrow_table(
).to_expr()


@_memtable.register("pyarrow.RecordBatchReader")
def _memtable_from_pyarrow_RecordBatchReader(
    data: pa.RecordBatchReader,  # was mis-annotated as pa.Table (copy-paste from the Table overload)
    *,
    name: str | None = None,
    schema: SchemaLike | None = None,
    columns: Iterable[str] | None = None,
):
    """Reject `pyarrow.RecordBatchReader` inputs to `ibis.memtable`.

    A reader is a one-shot stream; materializing it would silently pull
    every batch into memory, so we require the caller to opt in
    explicitly via ``reader.read_all()``.

    Raises
    ------
    TypeError
        Always, pointing the caller at the explicit alternative.
    """
    raise TypeError(
        "Creating an `ibis.memtable` from a `pyarrow.RecordBatchReader` would "
        "load _all_ data into memory. If you want to do this, please do so "
        "explicitly like `ibis.memtable(reader.read_all())`"
    )


@_memtable.register("polars.LazyFrame")
def _memtable_from_polars_lazyframe(data: pl.LazyFrame, **kwargs):
    """Materialize the lazy frame, then defer to the eager DataFrame handler."""
    collected = data.collect()
    return _memtable_from_polars_dataframe(collected, **kwargs)
Expand Down

0 comments on commit 28aff8a

Please sign in to comment.