From e48c90f305b86745b610304263081f0d9d4f04e9 Mon Sep 17 00:00:00 2001
From: Jim Crist-Harif
Date: Mon, 22 Jul 2024 15:25:14 -0500
Subject: [PATCH] feat(pyarrow): support objects implementing `__arrow_c_stream__` in `ibis.memtable`

The fallback `_memtable` implementation now checks for an
`__arrow_c_stream__` attribute and, when present, converts the object with
`pyarrow.table` before re-dispatching. Anything else is still wrapped in a
`pandas.DataFrame` and re-dispatched.

The pandas-specific logic moves into a handler registered for
`pandas.DataFrame`. A handler registered for `pyarrow.RecordBatchReader`
raises `TypeError` instead of silently reading the whole stream into memory.

Tests: drop the `notimpl` marks on the single-record-batch memtable case and
add a test that builds a memtable from an opaque `__arrow_c_stream__` object.
---
 ibis/backends/tests/test_client.py  | 21 -----------
 ibis/backends/tests/test_generic.py | 18 +++++++++
 ibis/expr/api.py                    | 58 +++++++++++++++++++++--------
 3 files changed, 61 insertions(+), 36 deletions(-)

diff --git a/ibis/backends/tests/test_client.py b/ibis/backends/tests/test_client.py
index 3b1164426d60..19fd3933266a 100644
--- a/ibis/backends/tests/test_client.py
+++ b/ibis/backends/tests/test_client.py
@@ -927,27 +927,6 @@ def test_self_join_memory_table(backend, con, monkeypatch):
         param(
             lambda: pa.table({"a": ["a"], "b": [1]}).to_batches()[0],
             "df_arrow_single_batch",
-            marks=[
-                pytest.mark.notimpl(
-                    [
-                        "bigquery",
-                        "clickhouse",
-                        "duckdb",
-                        "exasol",
-                        "impala",
-                        "mssql",
-                        "mysql",
-                        "oracle",
-                        "postgres",
-                        "pyspark",
-                        "risingwave",
-                        "snowflake",
-                        "sqlite",
-                        "trino",
-                        "databricks",
-                    ]
-                )
-            ],
             id="pyarrow_single_batch",
         ),
         param(
diff --git a/ibis/backends/tests/test_generic.py b/ibis/backends/tests/test_generic.py
index 634f26e520f9..dfefe1212fea 100644
--- a/ibis/backends/tests/test_generic.py
+++ b/ibis/backends/tests/test_generic.py
@@ -1222,6 +1222,24 @@ def test_memtable_construct_from_pyarrow(backend, con, monkeypatch):
     )
 
 
+def test_memtable_construct_from_pyarrow_c_stream(backend, con):
+    pa = pytest.importorskip("pyarrow")
+
+    class Opaque:
+        def __init__(self, table):
+            self._table = table
+
+        def __arrow_c_stream__(self, *args, **kwargs):
+            return self._table.__arrow_c_stream__(*args, **kwargs)
+
+    table = pa.table({"a": list("abc"), "b": [1, 2, 3]})
+
+    t = ibis.memtable(Opaque(table))
+
+    res = con.to_pyarrow(t.order_by("a"))
+    assert res.equals(table)
+
+
 @pytest.mark.parametrize("lazy", [False, True])
 def test_memtable_construct_from_polars(backend, con, lazy):
     pl = pytest.importorskip("polars")
diff --git a/ibis/expr/api.py b/ibis/expr/api.py
index 63148d7fb1cb..d03c48d4ccb0 100644
--- a/ibis/expr/api.py
+++ b/ibis/expr/api.py
@@ -412,42 +412,55 @@ def memtable(
 
 @lazy_singledispatch
 def _memtable(
-    data: pd.DataFrame | Any,
+    data: Any,
     *,
     columns: Iterable[str] | None = None,
     schema: SchemaLike | None = None,
     name: str | None = None,
 ) -> Table:
-    import pandas as pd
-
-    from ibis.formats.pandas import PandasDataFrameProxy
+    if hasattr(data, "__arrow_c_stream__"):
+        # Support objects exposing arrow's PyCapsule interface
+        import pyarrow as pa
 
-    if not isinstance(data, pd.DataFrame):
-        df = pd.DataFrame(data, columns=columns)
+        data = pa.table(data)
     else:
-        df = data
+        import pandas as pd
+
+        data = pd.DataFrame(data, columns=columns)
+    return _memtable(data, columns=columns, schema=schema, name=name)
+
+
+@_memtable.register("pandas.DataFrame")
+def _memtable_from_pandas_dataframe(
+    data: pd.DataFrame,
+    *,
+    columns: Iterable[str] | None = None,
+    schema: SchemaLike | None = None,
+    name: str | None = None,
+) -> Table:
+    from ibis.formats.pandas import PandasDataFrameProxy
 
-    if df.columns.inferred_type != "string":
-        cols = df.columns
+    if data.columns.inferred_type != "string":
+        cols = data.columns
         newcols = getattr(
             schema,
             "names",
             (f"col{i:d}" for i in builtins.range(len(cols))),
         )
-        df = df.rename(columns=dict(zip(cols, newcols)))
+        data = data.rename(columns=dict(zip(cols, newcols)))
 
     if columns is not None:
-        if (provided_col := len(columns)) != (exist_col := len(df.columns)):
+        if (provided_col := len(columns)) != (exist_col := len(data.columns)):
             raise ValueError(
                 "Provided `columns` must have an entry for each column in `data`.\n"
                 f"`columns` has {provided_col} elements but `data` has {exist_col} columns."
             )
 
-        df = df.rename(columns=dict(zip(df.columns, columns)))
+        data = data.rename(columns=dict(zip(data.columns, columns)))
 
     # verify that the DataFrame has no duplicate column names because ibis
     # doesn't allow that
-    cols = df.columns
+    cols = data.columns
     dupes = [name for name, count in Counter(cols).items() if count > 1]
     if dupes:
         raise IbisInputError(
@@ -456,8 +469,8 @@ def _memtable(
 
     op = ops.InMemoryTable(
         name=name if name is not None else util.gen_name("pandas_memtable"),
-        schema=sch.infer(df) if schema is None else schema,
-        data=PandasDataFrameProxy(df),
+        schema=sch.infer(data) if schema is None else schema,
+        data=PandasDataFrameProxy(data),
     )
     return op.to_expr()
 
@@ -499,6 +512,21 @@ def _memtable_from_pyarrow_dataset(
     ).to_expr()
 
 
+@_memtable.register("pyarrow.RecordBatchReader")
+def _memtable_from_pyarrow_RecordBatchReader(
+    data: pa.RecordBatchReader,
+    *,
+    name: str | None = None,
+    schema: SchemaLike | None = None,
+    columns: Iterable[str] | None = None,
+):
+    raise TypeError(
+        "Creating an `ibis.memtable` from a `pyarrow.RecordBatchReader` would "
+        "load _all_ data into memory. If you want to do this, please do so "
+        "explicitly like `ibis.memtable(reader.read_all())`"
+    )
+
+
 @_memtable.register("polars.LazyFrame")
 def _memtable_from_polars_lazyframe(data: pl.LazyFrame, **kwargs):
     return _memtable_from_polars_dataframe(data.collect(), **kwargs)