From aedeaaa9fb7c6e4e1e13d2254ddf5830c24596f8 Mon Sep 17 00:00:00 2001
From: Peter Ke
Date: Fri, 4 Oct 2024 14:37:20 -0700
Subject: [PATCH] init

---
 python/deltalake/_internal.pyi  |  5 +++
 python/deltalake/table.py       |  7 ++++
 python/src/lib.rs               | 70 ++++++++++++++++++++++++++++++++-
 python/tests/test_table_read.py | 64 ++++++++++++++++++++++++++++++
 4 files changed, 144 insertions(+), 2 deletions(-)

diff --git a/python/deltalake/_internal.pyi b/python/deltalake/_internal.pyi
index 8329dddad9..24fae471be 100644
--- a/python/deltalake/_internal.pyi
+++ b/python/deltalake/_internal.pyi
@@ -221,6 +221,11 @@ class RawDeltaTable:
         starting_timestamp: Optional[str] = None,
         ending_timestamp: Optional[str] = None,
     ) -> pyarrow.RecordBatchReader: ...
+    def datafusion_read(
+        self,
+        predicate: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+    ) -> List[pyarrow.RecordBatch]: ...
 
 def rust_core_version() -> str: ...
 def write_new_deltalake(
diff --git a/python/deltalake/table.py b/python/deltalake/table.py
index 9150be697c..5e9fff43a7 100644
--- a/python/deltalake/table.py
+++ b/python/deltalake/table.py
@@ -1417,6 +1417,13 @@ def repair(
         )
         return json.loads(metrics)
 
+    def datafusion_read(
+        self,
+        predicate: Optional[str] = None,
+        columns: Optional[List[str]] = None,
+    ) -> List[pyarrow.RecordBatch]:
+        return self._table.datafusion_read(predicate, columns)
+
 
 class TableMerger:
     """API for various table `MERGE` commands."""
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 473f5ceea9..b52ed9d5a9 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -8,6 +8,7 @@ mod utils;
 use std::collections::{HashMap, HashSet};
 use std::future::IntoFuture;
 use std::str::FromStr;
+use std::sync::Arc;
 use std::time;
 use std::time::{SystemTime, UNIX_EPOCH};
 
@@ -17,12 +18,18 @@ use delta_kernel::expressions::Scalar;
 use delta_kernel::schema::StructField;
 use deltalake::arrow::compute::concat_batches;
 use deltalake::arrow::ffi_stream::{ArrowArrayStreamReader, FFI_ArrowArrayStream};
+use deltalake::arrow::pyarrow::ToPyArrow;
 use deltalake::arrow::record_batch::{RecordBatch, RecordBatchIterator};
 use deltalake::arrow::{self, datatypes::Schema as ArrowSchema};
 use deltalake::checkpoints::{cleanup_metadata, create_checkpoint};
+use deltalake::datafusion::datasource::provider_as_source;
+use deltalake::datafusion::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
 use deltalake::datafusion::physical_plan::ExecutionPlan;
-use deltalake::datafusion::prelude::SessionContext;
-use deltalake::delta_datafusion::DeltaDataChecker;
+use deltalake::datafusion::prelude::{DataFrame, SessionContext};
+use deltalake::delta_datafusion::{
+    DataFusionMixins, DeltaDataChecker, DeltaScanConfigBuilder, DeltaSessionConfig,
+    DeltaTableProvider,
+};
 use deltalake::errors::DeltaTableError;
 use deltalake::kernel::{
     scalars::ScalarExt, Action, Add, Invariant, LogicalFile, Remove, StructType,
@@ -1232,6 +1239,65 @@ impl RawDeltaTable {
         self._table.state = table.state;
         Ok(serde_json::to_string(&metrics).unwrap())
     }
+
+    #[pyo3(signature = (predicate = None, columns = None))]
+    pub fn datafusion_read(
+        &self,
+        py: Python,
+        predicate: Option<String>,
+        columns: Option<Vec<String>>,
+    ) -> PyResult<PyObject> {
+        let batches = py.allow_threads(|| -> PyResult<_> {
+            let snapshot = self._table.snapshot().map_err(PythonError::from)?;
+            let log_store = self._table.log_store();
+
+            let scan_config = DeltaScanConfigBuilder::default()
+                .with_parquet_pushdown(false)
+                .build(snapshot)
+                .map_err(PythonError::from)?;
+
+            let provider = Arc::new(
+                DeltaTableProvider::try_new(snapshot.clone(), log_store, scan_config)
+                    .map_err(PythonError::from)?,
+            );
+            let source = provider_as_source(provider);
+
+            let config = DeltaSessionConfig::default().into();
+            let session = SessionContext::new_with_config(config);
+            let state = session.state();
+
+            let maybe_filter = predicate
+                .map(|predicate| snapshot.parse_predicate_expression(predicate, &state))
+                .transpose()
+                .map_err(PythonError::from)?;
+
+            let filters = match &maybe_filter {
+                Some(filter) => vec![filter.clone()],
+                None => vec![],
+            };
+
+            let plan = LogicalPlanBuilder::scan_with_filters(UNNAMED_TABLE, source, None, filters)
+                .unwrap()
+                .build()
+                .unwrap();
+
+            let mut df = DataFrame::new(state, plan);
+
+            if let Some(filter) = maybe_filter {
+                df = df.filter(filter).unwrap();
+            }
+
+            if let Some(columns) = columns {
+                df = df
+                    .select_columns(&columns.iter().map(String::as_str).collect::<Vec<&str>>())
+                    .unwrap();
+            }
+
+            Ok(rt().block_on(async { df.collect().await }).unwrap())
+        })?;
+
+        batches.to_pyarrow(py)
+    }
 }
 
 fn set_post_commithook_properties(
diff --git a/python/tests/test_table_read.py b/python/tests/test_table_read.py
index 5ff07ed9e8..1339d4e146 100644
--- a/python/tests/test_table_read.py
+++ b/python/tests/test_table_read.py
@@ -946,3 +946,67 @@ def test_is_deltatable_with_storage_opts():
         "DELTA_DYNAMO_TABLE_NAME": "custom_table_name",
     }
     assert DeltaTable.is_deltatable(table_path, storage_options=storage_options)
+
+
+def test_datafusion_read_table():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["1", "2", "3", "4", "5", "6", "7"],
+        "year": ["2020", "2020", "2020", "2021", "2021", "2021", "2021"],
+        "month": ["1", "2", "2", "4", "12", "12", "12"],
+        "day": ["1", "3", "5", "5", "4", "20", "20"],
+    }
+    actual = pa.Table.from_batches(dt.datafusion_read()).sort_by("value").to_pydict()
+    assert expected == actual
+
+
+def test_datafusion_read_table_with_columns():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["1", "2", "3", "4", "5", "6", "7"],
+        "day": ["1", "3", "5", "5", "4", "20", "20"],
+    }
+    actual = (
+        pa.Table.from_batches(dt.datafusion_read(columns=["value", "day"]))
+        .sort_by("value")
+        .to_pydict()
+    )
+    assert expected == actual
+
+
+def test_datafusion_read_with_filter_on_partitioned_column():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["1", "2", "3"],
+        "year": ["2020", "2020", "2020"],
+        "month": ["1", "2", "2"],
+        "day": ["1", "3", "5"],
+    }
+    actual = (
+        pa.Table.from_batches(dt.datafusion_read(predicate="year = '2020'"))
+        .sort_by("value")
+        .to_pydict()
+    )
+    assert expected == actual
+
+
+def test_datafusion_read_with_filter_on_multiple_columns():
+    table_path = "../crates/test/tests/data/delta-0.8.0-partitioned"
+    dt = DeltaTable(table_path)
+    expected = {
+        "value": ["4", "5"],
+        "year": ["2021", "2021"],
+        "month": ["4", "12"],
+        "day": ["5", "4"],
+    }
+    actual = (
+        pa.Table.from_batches(
+            dt.datafusion_read(predicate="year = '2021' and value < '6'")
+        )
+        .sort_by("value")
+        .to_pydict()
+    )
+    assert expected == actual
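
Usage note (illustrative, not part of the patch): a minimal sketch of how the new
datafusion_read API is expected to be called from Python, following the patterns
in the tests above. The table path is the repo's test fixture; the predicate and
column names are examples only.

    import pyarrow as pa
    from deltalake import DeltaTable

    # Open a partitioned test table (any local Delta table path works).
    dt = DeltaTable("../crates/test/tests/data/delta-0.8.0-partitioned")

    # Both arguments are optional: predicate is a SQL-like filter string
    # parsed by DataFusion against the table schema, and columns projects
    # the result to the named columns.
    batches = dt.datafusion_read(predicate="year = '2020'", columns=["value", "day"])

    # datafusion_read returns a list of pyarrow.RecordBatch objects,
    # which can be assembled into a pyarrow.Table as the tests do.
    table = pa.Table.from_batches(batches)
    print(table.to_pydict())

Judging from the scan config, parquet pushdown is disabled, so row-level
filtering is performed by DataFusion's filter operator rather than inside the
parquet reader, while the filters passed to scan_with_filters still let the
DeltaTableProvider prune files.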