Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test SQL join #16542

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 7 additions & 4 deletions crates/polars-lazy/src/frame/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1068,8 +1068,13 @@ impl LazyFrame {

/// Creates the Cartesian product from both frames, preserving the order of the left keys.
#[cfg(feature = "cross_join")]
pub fn cross_join(self, other: LazyFrame) -> LazyFrame {
self.join(other, vec![], vec![], JoinArgs::new(JoinType::Cross))
pub fn cross_join(self, other: LazyFrame, suffix: Option<String>) -> LazyFrame {
self.join(
other,
vec![],
vec![],
JoinArgs::new(JoinType::Cross).with_suffix(suffix),
)
}

/// Left outer join this query with another lazy query.
Expand Down Expand Up @@ -1220,9 +1225,7 @@ impl LazyFrame {
if let Some(suffix) = args.suffix {
builder = builder.suffix(suffix);
}

// Note: args.slice is set by the optimizer

builder.finish()
}

Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/cse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,7 @@ fn test_cse_columns_projections() -> PolarsResult<()> {
]?
.lazy();

let left = left.cross_join(right.clone().select([col("A")]));
let left = left.cross_join(right.clone().select([col("A")]), None);
let q = left.join(
right.rename(["B"], ["C"]),
[col("A"), col("C")],
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/projection_queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn test_cross_join_pd() -> PolarsResult<()> {
"price" => [5, 4]
]?;

let q = food.lazy().cross_join(drink.lazy()).select([
let q = food.lazy().cross_join(drink.lazy(), None).select([
col("name").alias("food"),
col("name_right").alias("beverage"),
(col("price") + col("price_right")).alias("total"),
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-lazy/src/tests/queries.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1171,7 +1171,7 @@ fn test_cross_join() -> PolarsResult<()> {
"b" => [None, Some(12)]
]?;

let out = df1.lazy().cross_join(df2.lazy()).collect()?;
let out = df1.lazy().cross_join(df2.lazy(), None).collect()?;
assert_eq!(out.shape(), (6, 4));
Ok(())
}
Expand Down
14 changes: 9 additions & 5 deletions crates/polars-lazy/src/tests/streaming.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ fn test_streaming_union_order() -> PolarsResult<()> {
fn test_streaming_union_join() -> PolarsResult<()> {
let q = get_csv_glob();
let q = q.select([col("sugars_g"), col("calories")]);
let q = q.clone().cross_join(q);
let q = q.clone().cross_join(q, None);

assert_streaming_with_default(q, true, true);
Ok(())
Expand Down Expand Up @@ -166,18 +166,22 @@ fn test_streaming_cross_join() -> PolarsResult<()> {
"a" => [1 ,2, 3]
]?;
let q = df.lazy();
let out = q.clone().cross_join(q).with_streaming(true).collect()?;
let out = q
.clone()
.cross_join(q, None)
.with_streaming(true)
.collect()?;
assert_eq!(out.shape(), (9, 2));

let q = get_parquet_file().with_projection_pushdown(false);
let q1 = q
.clone()
.select([col("calories")])
.cross_join(q.clone())
.cross_join(q.clone(), None)
.filter(col("calories").gt(col("calories_right")));
let q2 = q1
.select([all().name().suffix("_second")])
.cross_join(q)
.cross_join(q, None)
.filter(col("calories_right_second").lt(col("calories")))
.select([
col("calories"),
Expand Down Expand Up @@ -266,7 +270,7 @@ fn test_streaming_slice() -> PolarsResult<()> {
]?
.lazy();

let q = lf_a.clone().cross_join(lf_a).slice(10, 20);
let q = lf_a.clone().cross_join(lf_a, None).slice(10, 20);
let a = q.with_streaming(true).collect().unwrap();
assert_eq!(a.shape(), (20, 2));

Expand Down
5 changes: 5 additions & 0 deletions crates/polars-ops/src/frame/join/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,11 @@ impl JoinArgs {
self
}

pub fn with_suffix(mut self, suffix: Option<String>) -> Self {
self.suffix = suffix;
self
}

pub fn suffix(&self) -> &str {
self.suffix.as_deref().unwrap_or("_right")
}
Expand Down
2 changes: 1 addition & 1 deletion crates/polars-plan/src/dsl/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ pub enum Excluded {
impl Expr {
/// Get Field result of the expression. The schema is the input data.
pub fn to_field(&self, schema: &Schema, ctxt: Context) -> PolarsResult<Field> {
// this is not called much and th expression depth is typically shallow
// this is not called much and the expression depth is typically shallow
let mut arena = Arena::with_capacity(5);
self.to_field_amortized(schema, ctxt, &mut arena)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,7 @@ pub(super) fn process_join(
already_added_local_to_local_projected.insert(local_name);
}
// In full outer joins both columns remain. So `add_local=true` also for the right table
let add_local = matches!(options.args.how, JoinType::Full)
&& !options.args.coalesce.coalesce(&options.args.how);
let add_local = !options.args.coalesce.coalesce(&options.args.how);
for e in &right_on {
// In case of full outer joins we also add the columns.
// But before we do that we must check if the column wasn't already added by the lhs.
Expand Down Expand Up @@ -442,7 +441,7 @@ fn resolve_join_suffixes(
.iter()
.map(|proj| {
let name = column_node_to_name(*proj, expr_arena);
if name.contains(suffix) && schema_after_join.get(&name).is_none() {
if name.ends_with(suffix) && schema_after_join.get(&name).is_none() {
let downstream_name = &name.as_ref()[..name.len() - suffix.len()];
let col = AExpr::Column(ColumnName::from(downstream_name));
let node = expr_arena.add(col);
Expand Down
1 change: 1 addition & 0 deletions crates/polars-sql/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ arrow = { workspace = true }
polars-core = { workspace = true }
polars-error = { workspace = true }
polars-lazy = { workspace = true, features = ["abs", "binary_encoding", "concat_str", "cross_join", "cum_agg", "dtype-date", "dtype-decimal", "is_in", "list_eval", "log", "meta", "regex", "round_series", "sign", "string_reverse", "strings", "timezones", "trigonometry"] }
polars-ops = { workspace = true }
polars-plan = { workspace = true }

hex = { workspace = true }
Expand Down
139 changes: 104 additions & 35 deletions crates/polars-sql/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,19 @@ use std::cell::RefCell;
use polars_core::prelude::*;
use polars_error::to_compute_err;
use polars_lazy::prelude::*;
use polars_ops::frame::JoinCoalesce;
use polars_plan::prelude::*;
use sqlparser::ast::{
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinOperator,
ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr, SetOperator,
SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
Distinct, ExcludeSelectItem, Expr as SQLExpr, FunctionArg, GroupByExpr, JoinConstraint,
JoinOperator, ObjectName, ObjectType, Offset, OrderByExpr, Query, Select, SelectItem, SetExpr,
SetOperator, SetQuantifier, Statement, TableAlias, TableFactor, TableWithJoins, UnaryOperator,
Value as SQLValue, WildcardAdditionalOptions,
};
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserOptions};

use crate::function_registry::{DefaultFunctionRegistry, FunctionRegistry};
use crate::sql_expr::{parse_sql_expr, process_join};
use crate::sql_expr::{parse_sql_expr, process_join_constraint};
use crate::table_functions::PolarsTableFunctions;

/// The SQLContext is the main entry point for executing SQL queries.
Expand All @@ -23,7 +24,8 @@ pub struct SQLContext {
pub(crate) table_map: PlHashMap<String, LazyFrame>,
pub(crate) function_registry: Arc<dyn FunctionRegistry>,
cte_map: RefCell<PlHashMap<String, LazyFrame>>,
aliases: RefCell<PlHashMap<String, String>>,
table_aliases: RefCell<PlHashMap<String, String>>,
joined_aliases: RefCell<PlHashMap<String, PlHashMap<String, String>>>,
}

impl Default for SQLContext {
Expand All @@ -32,7 +34,8 @@ impl Default for SQLContext {
function_registry: Arc::new(DefaultFunctionRegistry {}),
table_map: Default::default(),
cte_map: Default::default(),
aliases: Default::default(),
table_aliases: Default::default(),
joined_aliases: Default::default(),
}
}
}
Expand Down Expand Up @@ -110,11 +113,16 @@ impl SQLContext {
.map_err(to_compute_err)?
.parse_statements()
.map_err(to_compute_err)?;
polars_ensure!(ast.len() == 1, ComputeError: "One and only one statement at a time please");

polars_ensure!(ast.len() == 1, ComputeError: "One (and only one) statement at a time please");

let res = self.execute_statement(ast.first().unwrap());
// Every execution should clear the CTE map.

// Every execution should clear the statement-level maps.
self.cte_map.borrow_mut().clear();
self.aliases.borrow_mut().clear();
self.table_aliases.borrow_mut().clear();
self.joined_aliases.borrow_mut().clear();

res
}

Expand All @@ -137,22 +145,6 @@ impl SQLContext {
}

impl SQLContext {
fn register_cte(&mut self, name: &str, lf: LazyFrame) {
self.cte_map.borrow_mut().insert(name.to_owned(), lf);
}

pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
let table_name = self.table_map.get(name).cloned();
table_name
.or_else(|| self.cte_map.borrow().get(name).cloned())
.or_else(|| {
self.aliases
.borrow()
.get(name)
.and_then(|alias| self.table_map.get(alias).cloned())
})
}

pub(crate) fn execute_statement(&mut self, stmt: &Statement) -> PolarsResult<LazyFrame> {
let ast = stmt;
Ok(match ast {
Expand Down Expand Up @@ -183,6 +175,31 @@ impl SQLContext {
self.process_limit_offset(lf, &query.limit, &query.offset)
}

pub(super) fn get_table_from_current_scope(&self, name: &str) -> Option<LazyFrame> {
let table_name = self.table_map.get(name).cloned();
table_name
.or_else(|| self.cte_map.borrow().get(name).cloned())
.or_else(|| {
self.table_aliases
.borrow()
.get(name)
.and_then(|alias| self.table_map.get(alias).cloned())
})
}

pub(super) fn resolve_name(&self, tbl_name: &str, column_name: &str) -> String {
if self.joined_aliases.borrow().contains_key(tbl_name) {
self.joined_aliases
.borrow()
.get(tbl_name)
.and_then(|aliases| aliases.get(column_name))
.cloned()
.unwrap_or_else(|| column_name.to_string())
} else {
column_name.to_string()
}
}

fn process_set_expr(&mut self, expr: &SetExpr, query: &Query) -> PolarsResult<LazyFrame> {
match expr {
SetExpr::Select(select_stmt) => self.execute_select(select_stmt, query),
Expand Down Expand Up @@ -296,6 +313,10 @@ impl SQLContext {
}
}

fn register_cte(&mut self, name: &str, lf: LazyFrame) {
self.cte_map.borrow_mut().insert(name.to_owned(), lf);
}

fn register_ctes(&mut self, query: &Query) -> PolarsResult<()> {
if let Some(with) = &query.with {
if with.recursive {
Expand All @@ -316,40 +337,63 @@ impl SQLContext {
if !tbl_expr.joins.is_empty() {
for tbl in &tbl_expr.joins {
let (r_name, rf) = self.get_table(&tbl.relation)?;
let left_schema = lf.schema()?;
let right_schema = rf.schema()?;

lf = match &tbl.join_operator {
JoinOperator::CrossJoin => lf.cross_join(rf),
JoinOperator::FullOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Full)?
},
JoinOperator::Inner(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Inner)?
},
JoinOperator::LeftOuter(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Left)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftAnti(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::LeftSemi(constraint) => {
process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)?
self.process_join(lf, rf, constraint, &l_name, &r_name, JoinType::Semi)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightAnti(constraint) => {
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)?
self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Anti)?
},
#[cfg(feature = "semi_anti_join")]
JoinOperator::RightSemi(constraint) => {
process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)?
self.process_join(rf, lf, constraint, &l_name, &r_name, JoinType::Semi)?
},
JoinOperator::CrossJoin => lf.cross_join(rf, Some(format!(":{}", r_name))),
join_type => {
polars_bail!(
InvalidOperation:
"join type '{:?}' not yet supported by polars-sql", join_type
);
},
}
};

// track join-aliased columns so we can resolve them later
let joined_schema = lf.schema()?;
self.joined_aliases.borrow_mut().insert(
r_name.to_string(),
right_schema
.iter_names()
.filter_map(|name| {
// col exists in both tables and is aliased in the joined result
let aliased_name = format!("{}:{}", name, r_name);
if left_schema.contains(name)
&& joined_schema.contains(aliased_name.as_str())
{
Some((name.to_string(), aliased_name))
} else {
None
}
})
.collect::<PlHashMap<String, String>>(),
);
}
};
Ok(lf)
Expand Down Expand Up @@ -578,6 +622,31 @@ impl SQLContext {
Ok(lf)
}

pub(super) fn process_join(
&self,
left_tbl: LazyFrame,
right_tbl: LazyFrame,
constraint: &JoinConstraint,
tbl_name: &str,
join_tbl_name: &str,
join_type: JoinType,
) -> PolarsResult<LazyFrame> {
let (left_on, right_on) = process_join_constraint(constraint, tbl_name, join_tbl_name)?;

let joined_tbl = left_tbl
.clone()
.join_builder()
.with(right_tbl.clone())
.left_on(left_on)
.right_on(right_on)
.how(join_type)
.suffix(format!(":{}", join_tbl_name))
.coalesce(JoinCoalesce::KeepColumns)
.finish();

Ok(joined_tbl)
}

fn process_subqueries(&self, lf: LazyFrame, exprs: Vec<&mut Expr>) -> LazyFrame {
let mut contexts = vec![];
for expr in exprs {
Expand Down Expand Up @@ -644,7 +713,7 @@ impl SQLContext {
if let Some(lf) = self.get_table_from_current_scope(tbl_name) {
match alias {
Some(alias) => {
self.aliases
self.table_aliases
.borrow_mut()
.insert(alias.name.value.clone(), tbl_name.to_string());
Ok((alias.to_string(), lf))
Expand Down
Loading
Loading