Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

perf: Fix quadratic 'with_columns' behavior #19701

Merged
merged 5 commits into from
Nov 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions crates/polars-core/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1382,12 +1382,24 @@ impl DataFrame {
self
}

// Note: `schema` can be either the input schema or the output schema.
fn add_column_by_schema(&mut self, c: Column, schema: &Schema) -> PolarsResult<()> {
let name = c.name();
if let Some((idx, _, _)) = schema.get_full(name.as_str()) {
// schema is incorrect fallback to search
if self.columns.get(idx).map(|s| s.name()) != Some(name) {
self.add_column_by_search(c)?;
// Given schema is output_schema and we can push.
if idx == self.columns.len() {
if self.width() == 0 {
self.height = c.len();
}

self.columns.push(c);
}
// Schema is incorrect; fall back to search.
else {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note to self, add a debug assert here.

debug_assert!(false);
self.add_column_by_search(c)?;
}
} else {
self.replace_column(idx, c)?;
}
Expand All @@ -1401,6 +1413,7 @@ impl DataFrame {
Ok(())
}

// Note: `schema` can be either the input schema or the output schema.
pub fn _add_series(&mut self, series: Vec<Series>, schema: &Schema) -> PolarsResult<()> {
for (i, s) in series.into_iter().enumerate() {
// we need to branch here
Expand Down Expand Up @@ -1430,6 +1443,8 @@ impl DataFrame {
/// Add a new column to this [`DataFrame`] or replace an existing one.
/// Uses an existing schema to amortize lookups.
/// If the schema is incorrect, we will fall back to linear search.
///
/// Note: `schema` can be either the input schema or the output schema.
pub fn with_column_and_schema<C: IntoColumn>(
&mut self,
column: C,
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-mem-engine/src/executors/stack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ pub struct StackExec {
pub(crate) has_windows: bool,
pub(crate) exprs: Vec<Arc<dyn PhysicalExpr>>,
pub(crate) input_schema: SchemaRef,
pub(crate) output_schema: SchemaRef,
pub(crate) options: ProjectionOptions,
// Can run all operations elementwise
pub(crate) streamable: bool,
Expand All @@ -19,7 +20,7 @@ impl StackExec {
state: &ExecutionState,
mut df: DataFrame,
) -> PolarsResult<DataFrame> {
let schema = &*self.input_schema;
let schema = &*self.output_schema;

// Vertical and horizontal parallelism.
let df = if self.streamable
Expand Down
3 changes: 2 additions & 1 deletion crates/polars-mem-engine/src/planner/lp.rs
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,7 @@ fn create_physical_plan_impl(
HStack {
input,
exprs,
schema: _schema,
schema: output_schema,
options,
} => {
let input_schema = lp_arena.get(input).schema(lp_arena).into_owned();
Expand Down Expand Up @@ -659,6 +659,7 @@ fn create_physical_plan_impl(
has_windows: state.has_windows,
exprs: phys_exprs,
input_schema,
output_schema,
options,
streamable,
}))
Expand Down
21 changes: 21 additions & 0 deletions py-polars/tests/benchmark/test_with_columns.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import time

import pytest

import polars as pl


# TODO: this is slow in streaming
@pytest.mark.may_fail_auto_streaming
def test_with_columns_quadratic_19503() -> None:
    """Check that `with_columns` stays fast when adding many columns at once.

    Builds two single-row frames of 2000 columns each, stacks the second
    onto the first, and asserts that the call finishes within 0.2 seconds.
    The time bound is meant to catch quadratic behavior in the column count.
    """
    num_columns = 2000
    data1 = {f"col_{i}": [0] for i in range(num_columns)}
    df1 = pl.DataFrame(data1)

    data2 = {f"feature_{i}": [0] for i in range(num_columns)}
    df2 = pl.DataFrame(data2)

    # Use the monotonic, high-resolution performance counter: unlike the
    # wall clock it is not affected by system clock adjustments, so the
    # measured interval cannot go negative or jump spuriously.
    t0 = time.perf_counter()
    df1.with_columns(df2)
    t1 = time.perf_counter()
    assert t1 - t0 < 0.2
Loading