Skip to content

Commit

Permalink
apply predicates on rust side
Browse files Browse the repository at this point in the history
  • Loading branch information
ritchie46 committed Jul 30, 2024
1 parent 7fa72a6 commit 2e00fdd
Show file tree
Hide file tree
Showing 16 changed files with 123 additions and 72 deletions.
22 changes: 19 additions & 3 deletions crates/polars-mem-engine/src/executors/scan/python_scan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use super::*;

/// Executor for a Python scan node.
///
/// The planner builds it from the `PythonScan` plan node; executing it calls
/// into Python and converts the returned frame(s) to Rust `DataFrame`s.
pub(crate) struct PythonScanExec {
/// Scan options taken from the `PythonScan` plan node. Its `predicate`
/// field decides what predicate value is handed to the Python side.
pub(crate) options: PythonOptions,
/// Physical predicate evaluated on the Rust side against every `DataFrame`
/// coming back from Python; the resulting boolean mask filters the frame.
/// The planner sets it when `options.predicate` is
/// `PythonPredicate::Polars` and leaves it `None` otherwise.
pub(crate) predicate: Option<Arc<dyn PhysicalExpr>>,
}

fn python_df_to_rust(py: Python, df: Bound<PyAny>) -> PolarsResult<DataFrame> {
Expand Down Expand Up @@ -51,7 +52,12 @@ impl Executor for PythonScanExec {
let predicate = match &self.options.predicate {
PythonPredicate::PyArrow(s) => s.into_py(py),
PythonPredicate::None => (None::<()>).into_py(py),
PythonPredicate::Polars(_) => todo!(),
// Still todo, currently we apply the predicate on this side.
PythonPredicate::Polars(_) => {
assert!(self.predicate.is_some(), "should be set");

(None::<()>).into_py(py)
},
};

let generator = callable
Expand All @@ -60,14 +66,24 @@ impl Executor for PythonScanExec {

// This isn't a generator, but a `DataFrame`.
if generator.getattr(intern!(py, "_df")).is_ok() {
return python_df_to_rust(py, generator);
let df = python_df_to_rust(py, generator)?;
return if let Some(pred) = &self.predicate {
let mask = pred.evaluate(&df, state)?;
df.filter(mask.bool()?)
} else {
Ok(df)
};
}

let mut chunks = vec![];
loop {
match generator.call_method0(intern!(py, "__next__")) {
Ok(out) => {
let df = python_df_to_rust(py, out)?;
let mut df = python_df_to_rust(py, out)?;
if let Some(pred) = &self.predicate {
let mask = pred.evaluate(&df, state)?;
df = df.filter(mask.bool()?)?;
}
chunks.push(df)
},
Err(err) if err.matches(py, PyStopIteration::type_object_bound(py)) => break,
Expand Down
16 changes: 15 additions & 1 deletion crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,21 @@ fn create_physical_plan_impl(
let logical_plan = lp_arena.take(root);
match logical_plan {
#[cfg(feature = "python")]
PythonScan { options, .. } => Ok(Box::new(executors::PythonScanExec { options })),
PythonScan { options } => {
let predicate = if let PythonPredicate::Polars(e) = &options.predicate {
let mut state = ExpressionConversionState::new(true, state.expr_depth);
Some(create_physical_expr(
e,
Context::Default,
expr_arena,
Some(&options.schema),
&mut state,
)?)
} else {
None
};
Ok(Box::new(executors::PythonScanExec { options, predicate }))
},
Sink { payload, .. } => match payload {
SinkType::Memory => {
polars_bail!(InvalidOperation: "memory sink not supported in the standard engine")
Expand Down
5 changes: 1 addition & 4 deletions crates/polars-plan/src/plans/conversion/dsl_to_ir.rs
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,7 @@ pub fn to_alp_impl(
}
},
#[cfg(feature = "python")]
DslPlan::PythonScan { options } => IR::PythonScan {
options,
predicate: None,
},
DslPlan::PythonScan { options } => IR::PythonScan { options },
DslPlan::Union { inputs, args } => {
let mut inputs = inputs
.into_iter()
Expand Down
9 changes: 6 additions & 3 deletions crates/polars-plan/src/plans/ir/dot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,11 +153,14 @@ impl<'a> IRDotDisplay<'a> {
write_label(f, id, |f| write!(f, "FILTER BY {pred}"))?;
},
#[cfg(feature = "python")]
PythonScan { predicate, options } => {
let predicate = predicate.as_ref().map(|e| self.display_expr(e));
PythonScan { options } => {
let predicate = match &options.predicate {
PythonPredicate::Polars(e) => format!("{}", self.display_expr(e)),
PythonPredicate::PyArrow(s) => s.clone(),
PythonPredicate::None => "none".to_string(),
};
let with_columns = NumColumns(options.with_columns.as_ref().map(|s| s.as_ref()));
let total_columns = options.schema.len();
let predicate = OptionExprIRDisplay(predicate);

write_label(f, id, |f| {
write!(
Expand Down
8 changes: 6 additions & 2 deletions crates/polars-plan/src/plans/ir/format.rs
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,19 @@ impl<'a> IRDisplay<'a> {

match self.root() {
#[cfg(feature = "python")]
PythonScan { options, predicate } => {
PythonScan { options } => {
let total_columns = options.schema.len();
let n_columns = options
.with_columns
.as_ref()
.map(|s| s.len() as i64)
.unwrap_or(-1);

let predicate = predicate.as_ref().map(|p| self.display_expr(p));
let predicate = match &options.predicate {
PythonPredicate::Polars(e) => Some(self.display_expr(e)),
PythonPredicate::PyArrow(_) => None,
PythonPredicate::None => None,
};

write_scan(
f,
Expand Down
3 changes: 1 addition & 2 deletions crates/polars-plan/src/plans/ir/inputs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@ impl IR {

match self {
#[cfg(feature = "python")]
PythonScan { options, predicate } => PythonScan {
PythonScan { options } => PythonScan {
options: options.clone(),
predicate: predicate.clone(),
},
Union { options, .. } => Union {
inputs,
Expand Down
1 change: 0 additions & 1 deletion crates/polars-plan/src/plans/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ pub enum IR {
#[cfg(feature = "python")]
PythonScan {
options: PythonOptions,
predicate: Option<ExprIR>,
},
Slice {
input: Node,
Expand Down
4 changes: 2 additions & 2 deletions crates/polars-plan/src/plans/ir/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ impl IR {
use IR::*;
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options, .. } => &options.schema,
PythonScan { options } => &options.schema,
DataFrameScan { schema, .. } => schema,
Scan { file_info, .. } => &file_info.schema,
node => {
Expand All @@ -68,7 +68,7 @@ impl IR {
use IR::*;
let schema = match self {
#[cfg(feature = "python")]
PythonScan { options, .. } => options.output_schema.as_ref().unwrap_or(&options.schema),
PythonScan { options } => options.output_schema.as_ref().unwrap_or(&options.schema),
Union { inputs, .. } => return arena.get(inputs[0]).schema(arena),
HConcat { schema, .. } => schema,
Cache { input, .. } => return arena.get(*input).schema(arena),
Expand Down
53 changes: 19 additions & 34 deletions crates/polars-plan/src/plans/optimizer/predicate_pushdown/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -657,18 +657,15 @@ impl<'a> PredicatePushDown<'a> {
}
},
#[cfg(feature = "python")]
PythonScan {
mut options,
predicate,
} => {
if options.is_pyarrow {
let predicate = predicate_at_scan(acc_predicates, predicate, expr_arena);
PythonScan { mut options } => {
let predicate = predicate_at_scan(acc_predicates, None, expr_arena);

if let Some(predicate) = predicate.clone() {
// simplify expressions before we translate them to pyarrow
if options.is_pyarrow {
if let Some(predicate) = predicate {
// Simplify expressions before we translate them to pyarrow
options.predicate = PythonPredicate::Polars(predicate);
let lp = PythonScan {
options: options.clone(),
predicate: Some(predicate),
};
let lp_top = lp_arena.add(lp);
let stack_opt = StackOptimizer {};
Expand All @@ -680,11 +677,10 @@ impl<'a> PredicatePushDown<'a> {
lp_top,
)
.unwrap();
let PythonScan {
options: _,
predicate: Some(predicate),
} = lp_arena.take(lp_top)
else {
let PythonScan { mut options } = lp_arena.take(lp_top) else {
unreachable!()
};
let PythonPredicate::Polars(predicate) = &options.predicate else {
unreachable!()
};

Expand All @@ -693,34 +689,23 @@ impl<'a> PredicatePushDown<'a> {
expr_arena,
Default::default(),
) {
// We were able to create a pyarrow string, mutate the options
// We were able to create a pyarrow string, mutate the options.
Some(eval_str) => {
options.predicate = PythonPredicate::PyArrow(eval_str)
},
// we were not able to translate the predicate
// apply here
// We were not able to translate the predicate; it is applied on the Rust side in the scan.
None => {
let lp = PythonScan {
options,
predicate: None,
};
return Ok(self.optional_apply_predicate(
lp,
vec![predicate],
lp_arena,
expr_arena,
));
let lp = PythonScan { options };
return Ok(lp);
},
}
}
Ok(PythonScan { options, predicate })
Ok(PythonScan { options })
} else {
self.no_pushdown_restart_opt(
PythonScan { options, predicate },
acc_predicates,
lp_arena,
expr_arena,
)
if let Some(predicate) = predicate {
options.predicate = PythonPredicate::Polars(predicate);
}
Ok(PythonScan { options })
}
},
Invalid => unreachable!(),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -380,10 +380,7 @@ impl ProjectionPushDown {
Ok(lp)
},
#[cfg(feature = "python")]
PythonScan {
mut options,
predicate,
} => {
PythonScan { mut options } => {
options.with_columns = get_scan_columns(&acc_projections, expr_arena, None, None);

options.output_schema = if options.with_columns.is_none() {
Expand All @@ -396,7 +393,7 @@ impl ProjectionPushDown {
true,
)?))
};
Ok(PythonScan { options, predicate })
Ok(PythonScan { options })
},
Scan {
paths,
Expand Down
4 changes: 1 addition & 3 deletions crates/polars-plan/src/plans/optimizer/slice_pushdown_lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,13 @@ impl SlicePushDown {
#[cfg(feature = "python")]
(PythonScan {
mut options,
predicate,
},
// TODO! we currently skip slice pushdown if there is a predicate.
// we can modify the readers to only limit after predicates have been applied
Some(state)) if state.offset == 0 && predicate.is_none() => {
Some(state)) if state.offset == 0 && matches!(options.predicate, PythonPredicate::None) => {
options.n_rows = Some(state.len as usize);
let lp = PythonScan {
options,
predicate
};
Ok(lp)
}
Expand Down
8 changes: 3 additions & 5 deletions crates/polars-plan/src/plans/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ use polars_time::{DynamicGroupOptions, RollingGroupOptions};
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use crate::plans::ExprIR;
#[cfg(feature = "python")]
use crate::prelude::python_udf::PythonFunction;
#[cfg(feature = "python")]
use crate::prelude::Expr;

pub type FileCount = u32;

Expand Down Expand Up @@ -240,19 +239,18 @@ pub struct PythonOptions {
// Whether this is a pyarrow dataset source or a Polars source.
pub is_pyarrow: bool,
/// Optional predicate the reader must apply.
#[cfg_attr(feature = "serde", serde(skip))]
pub predicate: PythonPredicate,
/// A `head` call passed to the reader.
pub n_rows: Option<usize>,
}

#[derive(Clone, PartialEq, Eq, Debug, Default)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg(feature = "python")]
pub enum PythonPredicate {
// A pyarrow predicate python expression
// can be evaluated with python.eval
PyArrow(String),
Polars(Expr),
Polars(ExprIR),
#[default]
None,
}
Expand Down
45 changes: 45 additions & 0 deletions py-polars/polars/io/plugin.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Iterator

import polars._reexport as pl
from polars._utils.unstable import unstable

if TYPE_CHECKING:
from polars import DataFrame, Expr, LazyFrame
from polars._typing import SchemaDict


@unstable()
def register_io_source(
    callable: Callable[
        [list[str] | None, Expr | None, int | None, int | None], Iterator[DataFrame]
    ],
    schema: SchemaDict,
) -> LazyFrame:
    """
    Register your IO plugin and initialize a LazyFrame.

    Parameters
    ----------
    callable
        Function that accepts the following arguments:
            with_columns
                Columns that are projected. The reader must
                project these columns if applied.
            predicate
                Polars expression. The reader must filter
                its rows accordingly.
            n_rows
                Materialize only n rows from the source.
                The reader can stop when `n_rows` are read.
            batch_size
                A hint of the ideal batch size the reader's
                generator must produce.
    schema
        Schema that the reader will produce before projection pushdown.
    """
    # `pyarrow=False` is passed because this entry point is not the pyarrow
    # dataset scan; the scan function is a user-supplied callable.
    lazy_frame = pl.LazyFrame._scan_python_function(
        schema=schema, scan_fn=callable, pyarrow=False
    )
    return lazy_frame
4 changes: 2 additions & 2 deletions py-polars/polars/io/pyarrow_dataset/anonymous_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def _scan_pyarrow_dataset_impl(
common_params["batch_size"] = batch_size

if n_rows:
yield from_arrow(ds.head(n_rows, **common_params)) # type: ignore[return-value]
return from_arrow(ds.head(n_rows, **common_params)) # type: ignore[return-value]

yield from_arrow(ds.to_table(**common_params)) # type: ignore[return-value]
return from_arrow(ds.to_table(**common_params)) # type: ignore[return-value]
1 change: 0 additions & 1 deletion py-polars/src/lazyframe/visit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,6 @@ impl NodeTraverser {
predicate: Default::default(),
n_rows: None,
},
predicate: None,
};
lp_arena.replace(self.root, ir);
}
Expand Down
Loading

0 comments on commit 2e00fdd

Please sign in to comment.