diff --git a/Cargo.lock b/Cargo.lock index 626ab52988d31..dd55347169456 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3323,6 +3323,7 @@ dependencies = [ "recursive", "regex", "serde", + "serde_json", "smartstring", "strum_macros 0.26.4", "version_check", diff --git a/crates/polars-mem-engine/src/executors/scan/python_scan.rs b/crates/polars-mem-engine/src/executors/scan/python_scan.rs index bdc46900a3aa0..ace26f9f591da 100644 --- a/crates/polars-mem-engine/src/executors/scan/python_scan.rs +++ b/crates/polars-mem-engine/src/executors/scan/python_scan.rs @@ -2,6 +2,7 @@ use polars_core::error::to_compute_err; use polars_core::utils::accumulate_dataframes_vertical; use pyo3::exceptions::PyStopIteration; use pyo3::prelude::*; +use pyo3::types::PyBytes; use pyo3::{intern, PyTypeInfo}; use super::*; @@ -9,6 +10,7 @@ use super::*; pub(crate) struct PythonScanExec { pub(crate) options: PythonOptions, pub(crate) predicate: Option>, + pub(crate) predicate_serialized: Option>, } fn python_df_to_rust(py: Python, df: Bound) -> PolarsResult { @@ -51,22 +53,36 @@ impl Executor for PythonScanExec { let predicate = match &self.options.predicate { PythonPredicate::PyArrow(s) => s.into_py(py), - PythonPredicate::None => (None::<()>).into_py(py), - // Still todo, currently we apply the predicate on this side. + PythonPredicate::None => None::<()>.into_py(py), PythonPredicate::Polars(_) => { assert!(self.predicate.is_some(), "should be set"); - (None::<()>).into_py(py) + match &self.predicate_serialized { + None => None::<()>.into_py(py), + Some(buf) => PyBytes::new_bound(py, buf).to_object(py), + } }, }; - let generator = callable - .call1((python_scan_function, with_columns, predicate, n_rows)) + let batch_size = if self.options.is_pyarrow { + None + } else { + Some(100_000usize) + }; + + let generator_init = callable + .call1(( + python_scan_function, + with_columns, + predicate, + n_rows, + batch_size, + )) .map_err(to_compute_err)?; // This isn't a generator, but a `DataFrame`. - if generator.getattr(intern!(py, "_df")).is_ok() { - let df = python_df_to_rust(py, generator)?; + if generator_init.getattr(intern!(py, "_df")).is_ok() { + let df = python_df_to_rust(py, generator_init)?; return if let Some(pred) = &self.predicate { let mask = pred.evaluate(&df, state)?; df.filter(mask.bool()?) @@ -75,12 +91,22 @@ impl Executor for PythonScanExec { }; } + let generator = generator_init + .get_item(0) + .map_err(|_| polars_err!(ComputeError: "expected tuple got {}", generator_init))?; + let can_parse_predicate = generator_init + .get_item(1) + .map_err(|_| polars_err!(ComputeError: "expected tuple got {}", generator))?; + let can_parse_predicate = can_parse_predicate.extract::().map_err( + |_| polars_err!(ComputeError: "expected bool got {}", can_parse_predicate), + )?; + let mut chunks = vec![]; loop { match generator.call_method0(intern!(py, "__next__")) { Ok(out) => { let mut df = python_df_to_rust(py, out)?; - if let Some(pred) = &self.predicate { + if let (Some(pred), false) = (&self.predicate, can_parse_predicate) { let mask = pred.evaluate(&df, state)?; df = df.filter(mask.bool()?)?; } @@ -88,7 +114,7 @@ impl Executor for PythonScanExec { }, Err(err) if err.matches(py, PyStopIteration::type_object_bound(py)) => break, Err(err) => { - polars_bail!(ComputeError: "catched exception during execution of a Python source, exception: {}", err) + polars_bail!(ComputeError: "caught exception during execution of a Python source, exception: {}", err) }, } } diff --git a/crates/polars-mem-engine/src/planner/lp.rs b/crates/polars-mem-engine/src/planner/lp.rs index baf41101a369b..67c81b64142f9 100644 --- a/crates/polars-mem-engine/src/planner/lp.rs +++ b/crates/polars-mem-engine/src/planner/lp.rs @@ -159,12 +159,15 @@ fn create_physical_plan_impl( match logical_plan { #[cfg(feature = "python")] PythonScan { mut options } => { + let mut predicate_serialized = None; let predicate = if let PythonPredicate::Polars(e) = &options.predicate { // Convert to a pyarrow eval string. if options.is_pyarrow { - if let Some(eval_str) = - pyarrow::predicate_to_pa(e.node(), expr_arena, Default::default()) - { + if let Some(eval_str) = polars_plan::plans::python::pyarrow::predicate_to_pa( + e.node(), + expr_arena, + Default::default(), + ) { options.predicate = PythonPredicate::PyArrow(eval_str) } @@ -173,6 +176,10 @@ fn create_physical_plan_impl( } // Convert to physical expression for the case the reader cannot consume the predicate. else { + let dsl_expr = e.to_expr(expr_arena); + predicate_serialized = + polars_plan::plans::python::predicate::serialize(&dsl_expr)?; + let mut state = ExpressionConversionState::new(true, state.expr_depth); Some(create_physical_expr( e, @@ -185,7 +192,11 @@ fn create_physical_plan_impl( } else { None }; - Ok(Box::new(executors::PythonScanExec { options, predicate })) + Ok(Box::new(executors::PythonScanExec { + options, + predicate, + predicate_serialized, + })) }, Sink { payload, .. } => match payload { SinkType::Memory => { diff --git a/crates/polars-plan/Cargo.toml b/crates/polars-plan/Cargo.toml index a05b19720be64..3ad50ace7fd87 100644 --- a/crates/polars-plan/Cargo.toml +++ b/crates/polars-plan/Cargo.toml @@ -40,6 +40,7 @@ rayon = { workspace = true } recursive = { workspace = true } regex = { workspace = true, optional = true } serde = { workspace = true, features = ["rc"], optional = true } +serde_json = { workspace = true, optional = true } smartstring = { workspace = true } strum_macros = { workspace = true } diff --git a/crates/polars-plan/src/plans/mod.rs b/crates/polars-plan/src/plans/mod.rs index f470e9177d90c..8688521edeafd 100644 --- a/crates/polars-plan/src/plans/mod.rs +++ b/crates/polars-plan/src/plans/mod.rs @@ -29,7 +29,7 @@ mod lit; pub(crate) mod optimizer; pub(crate) mod options; #[cfg(feature = "python")] -pub mod pyarrow; +pub mod python; mod schema; pub mod visitor; diff --git a/crates/polars-plan/src/plans/optimizer/fused.rs b/crates/polars-plan/src/plans/optimizer/fused.rs index e6840f992fbc1..2584ff36dfe6d 100644 --- a/crates/polars-plan/src/plans/optimizer/fused.rs +++ b/crates/polars-plan/src/plans/optimizer/fused.rs @@ -65,6 +65,13 @@ impl OptimizationRule for FusedArithmetic { lp_arena: &Arena, lp_node: Node, ) -> PolarsResult> { + // We don't want to fuse arithmetic that we send to pyarrow. + #[cfg(feature = "python")] + if let IR::PythonScan { options } = lp_arena.get(lp_node) { + if options.is_pyarrow { + return Ok(None); + } + }; let expr = expr_arena.get(expr_node); use AExpr::*; diff --git a/crates/polars-plan/src/plans/python/mod.rs b/crates/polars-plan/src/plans/python/mod.rs new file mode 100644 index 0000000000000..e82c95f09e782 --- /dev/null +++ b/crates/polars-plan/src/plans/python/mod.rs @@ -0,0 +1,2 @@ +pub mod predicate; +pub mod pyarrow; diff --git a/crates/polars-plan/src/plans/python/predicate.rs b/crates/polars-plan/src/plans/python/predicate.rs new file mode 100644 index 0000000000000..2e4a21af2749c --- /dev/null +++ b/crates/polars-plan/src/plans/python/predicate.rs @@ -0,0 +1,69 @@ +use polars_core::error::polars_err; +use polars_core::prelude::PolarsResult; + +use crate::prelude::*; + +fn accept_as_io_predicate(e: &Expr) -> bool { + const LIMIT: usize = 1 << 16; + match e { + Expr::Literal(lv) => match lv { + LiteralValue::Binary(v) => v.len() <= LIMIT, + LiteralValue::String(v) => v.len() <= LIMIT, + LiteralValue::Series(s) => s.estimated_size() < LIMIT, + // Don't accept dynamic types + LiteralValue::Int(_) => false, + LiteralValue::Float(_) => false, + _ => true, + }, + Expr::Wildcard | Expr::Column(_) => true, + Expr::BinaryExpr { left, right, .. } => { + accept_as_io_predicate(left) && accept_as_io_predicate(right) + }, + Expr::Ternary { + truthy, + falsy, + predicate, + } => { + accept_as_io_predicate(truthy) + && accept_as_io_predicate(falsy) + && accept_as_io_predicate(predicate) + }, + Expr::Alias(_, _) => true, + Expr::Function { + function, input, .. + } => { + match function { + // we already checked if streaming, so we can all functions + FunctionExpr::Boolean(_) | FunctionExpr::BinaryExpr(_) | FunctionExpr::Coalesce => { + }, + #[cfg(feature = "log")] + FunctionExpr::Entropy { .. } + | FunctionExpr::Log { .. } + | FunctionExpr::Log1p { .. } + | FunctionExpr::Exp { .. } => {}, + #[cfg(feature = "abs")] + FunctionExpr::Abs => {}, + #[cfg(feature = "trigonometry")] + FunctionExpr::Atan2 => {}, + #[cfg(feature = "round_series")] + FunctionExpr::Clip { .. } => {}, + #[cfg(feature = "fused")] + FunctionExpr::Fused(_) => {}, + _ => return false, + } + input.iter().all(accept_as_io_predicate) + }, + _ => false, + } +} + +pub fn serialize(expr: &Expr) -> PolarsResult>> { + if !accept_as_io_predicate(expr) { + return Ok(None); + } + let mut buf = vec![]; + ciborium::into_writer(expr, &mut buf) + .map_err(|_| polars_err!(ComputeError: "could not serialize: {}", expr))?; + + Ok(Some(buf)) +} diff --git a/crates/polars-plan/src/plans/pyarrow.rs b/crates/polars-plan/src/plans/python/pyarrow.rs similarity index 88% rename from crates/polars-plan/src/plans/pyarrow.rs rename to crates/polars-plan/src/plans/python/pyarrow.rs index abf2c8e34a7ff..1232fcfde673f 100644 --- a/crates/polars-plan/src/plans/pyarrow.rs +++ b/crates/polars-plan/src/plans/python/pyarrow.rs @@ -38,7 +38,6 @@ pub fn predicate_to_pa( } }, AExpr::Column(name) => Some(format!("pa.compute.field('{}')", name.as_ref())), - AExpr::Alias(input, _) => predicate_to_pa(*input, expr_arena, args), AExpr::Literal(LiteralValue::Series(s)) => { if !args.allow_literal_series || s.is_empty() || s.len() > 100 { None @@ -115,33 +114,6 @@ pub fn predicate_to_pa( }, } }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::Not), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("~({input})")) - }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsNull), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("({input}).is_null()")) - }, - AExpr::Function { - function: FunctionExpr::Boolean(BooleanFunction::IsNotNull), - input, - .. - } => { - let input = input.first().unwrap().node(); - let input = predicate_to_pa(input, expr_arena, args)?; - Some(format!("~({input}).is_null()")) - }, #[cfg(feature = "is_in")] AExpr::Function { function: FunctionExpr::Boolean(BooleanFunction::IsIn), @@ -182,6 +154,23 @@ pub fn predicate_to_pa( )) } }, + AExpr::Function { + function, input, .. + } => { + let input = input.first().unwrap().node(); + let input = predicate_to_pa(input, expr_arena, args)?; + + match function { + FunctionExpr::Boolean(BooleanFunction::Not) => Some(format!("~({input})")), + FunctionExpr::Boolean(BooleanFunction::IsNull) => { + Some(format!("({input}).is_null()")) + }, + FunctionExpr::Boolean(BooleanFunction::IsNotNull) => { + Some(format!("~({input}).is_null()")) + }, + _ => None, + } + }, _ => None, } } diff --git a/py-polars/polars/expr/expr.py b/py-polars/polars/expr/expr.py index ec65f40af0d72..92baadbaddee3 100644 --- a/py-polars/polars/expr/expr.py +++ b/py-polars/polars/expr/expr.py @@ -341,7 +341,10 @@ def function(s: Series) -> Series: # pragma: no cover @classmethod def deserialize( - cls, source: str | Path | IOBase, *, format: SerializationFormat = "binary" + cls, + source: str | Path | IOBase | bytes, + *, + format: SerializationFormat = "binary", ) -> Expr: """ Read a serialized expression from a file. @@ -385,6 +388,8 @@ def deserialize( source = BytesIO(source.getvalue().encode()) elif isinstance(source, (str, Path)): source = normalize_filepath(source) + elif isinstance(source, bytes): + source = BytesIO(source) if format == "binary": deserializer = PyExpr.deserialize_binary diff --git a/py-polars/polars/io/plugin.py b/py-polars/polars/io/plugin.py deleted file mode 100644 index cd93d24dff0e4..0000000000000 --- a/py-polars/polars/io/plugin.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from typing import TYPE_CHECKING, Callable, Iterator - -import polars._reexport as pl -from polars._utils.unstable import unstable - -if TYPE_CHECKING: - from polars import DataFrame, Expr, LazyFrame - from polars._typing import SchemaDict - - -@unstable() -def register_io_source( - callable: Callable[ - [list[str] | None, Expr | None, int | None, int | None], Iterator[DataFrame] - ], - schema: SchemaDict, -) -> LazyFrame: - """ - Register your IO plugin and initialize a LazyFrame. - - Parameters - ---------- - callable - Function that accepts the following arguments: - `with_columns` - Columns that are projected. The reader must - project these columns if applied - predicate - Polars expression. The reader must filter - there rows accordingly. - n_rows: - Materialize only n rows from the source. - The reader can stop when `n_rows` are read. - batch_size - A hint of the ideal batch size the readers - generator must produce. - schema - Schema that the reader will produce before projection pushdown. - - """ - return pl.LazyFrame._scan_python_function( - schema=schema, scan_fn=callable, pyarrow=False - ) diff --git a/py-polars/polars/io/plugins.py b/py-polars/polars/io/plugins.py new file mode 100644 index 0000000000000..a2100d3e7e07f --- /dev/null +++ b/py-polars/polars/io/plugins.py @@ -0,0 +1,76 @@ +from __future__ import annotations + +import os +import sys +from typing import TYPE_CHECKING, Callable, Iterator + +import polars._reexport as pl +from polars._utils.unstable import unstable + +if TYPE_CHECKING: + from typing import Callable, Iterator + + from polars import DataFrame, Expr, LazyFrame + from polars._typing import SchemaDict + + +@unstable() +def register_io_source( + callable: Callable[ + [list[str] | None, Expr | None, int | None, int | None], Iterator[DataFrame] + ], + schema: SchemaDict, +) -> LazyFrame: + """ + Register your IO plugin and initialize a LazyFrame. + + Parameters + ---------- + callable + Function that accepts the following arguments: + with_columns + Columns that are projected. The reader must + project these columns if applied + predicate + Polars expression. The reader must filter + there rows accordingly. + n_rows: + Materialize only n rows from the source. + The reader can stop when `n_rows` are read. + batch_size + A hint of the ideal batch size the readers + generator must produce. + returns + A DataFrame batch and whether it was able to deserialize + and apply the predicate + schema + Schema that the reader will produce before projection pushdown. + + """ + + def wrap( + with_columns: list[str] | None, + predicate: bytes | None, + n_rows: int | None, + batch_size: int | None, + ) -> tuple[Iterator[DataFrame], bool]: + parsed_predicate_success = True + parsed_predicate = None + if predicate: + try: + parsed_predicate = pl.Expr.deserialize(predicate) + except Exception as e: + if os.environ.get("POLARS_VERBOSE"): + print( + f"failed parsing IO plugin expression\n\nfilter will be handled on Polars' side: {e}", + file=sys.stderr, + ) + parsed_predicate_success = False + + return callable( + with_columns, parsed_predicate, n_rows, batch_size + ), parsed_predicate_success + + return pl.LazyFrame._scan_python_function( + schema=schema, scan_fn=wrap, pyarrow=False + ) diff --git a/py-polars/tests/unit/io/test_plugins.py b/py-polars/tests/unit/io/test_plugins.py new file mode 100644 index 0000000000000..98c25edc3f4ab --- /dev/null +++ b/py-polars/tests/unit/io/test_plugins.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import polars as pl +from polars.io.plugins import register_io_source +from polars.testing import assert_frame_equal + +if TYPE_CHECKING: + from typing import Iterator + + +# A simple python source. But this can dispatch into a rust IO source as well. +def my_source( + with_columns: list[str] | None, + predicate: pl.Expr | None, + _n_rows: int | None, + _batch_size: int | None, +) -> Iterator[pl.DataFrame]: + for i in [1, 2, 3]: + df = pl.DataFrame({"a": [i], "b": [i]}) + + if predicate is not None: + df = df.filter(predicate) + + if with_columns is not None: + df = df.select(with_columns) + + yield df + + +def scan_my_source() -> pl.LazyFrame: + # schema inference logic + # TODO: make lazy via callable + schema = pl.Schema({"a": pl.Int64(), "b": pl.Int64()}) + + return register_io_source(my_source, schema=schema) + + +def test_my_source() -> None: + assert_frame_equal( + scan_my_source().collect(), pl.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}) + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).collect(), + pl.DataFrame({"a": [2, 3], "b": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().filter(pl.col("b") > 1).select("a").collect(), + pl.DataFrame({"a": [2, 3]}), + ) + assert_frame_equal( + scan_my_source().select("a").collect(), pl.DataFrame({"a": [1, 2, 3]}) + )