diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 94229e25436c5..21765f2c9b6fe 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -1645,6 +1645,7 @@ mod tests { use arrow_array::UInt64Array; use datafusion_expr::expr::InList; use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_functions_nested::expr_fn::{array_has, make_array}; use datafusion_physical_expr::expressions as phys_expr; use datafusion_physical_expr::planner::logical2physical; @@ -3469,70 +3470,64 @@ mod tests { /// This handles an arbitrary case of a column that doesn't exist in the schema /// by renaming it to yet another column that doesn't exist in the schema /// (the transformation is arbitrary, the point is that it can do whatever it wants) - fn handle(&self, expr: &Arc) -> Arc { - if let Some(expr) = expr.as_any().downcast_ref::() { - let left = expr.left(); - let right = expr.right(); - if let Some(column) = left.as_any().downcast_ref::() { - if column.name() == "b" && right.as_any().downcast_ref::().is_some() { - let new_column = Arc::new(phys_expr::Column::new("c", column.index())) as _; - return Arc::new(phys_expr::BinaryExpr::new( - new_column, - *expr.op(), - right.clone(), - )); - } - } - } - - Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc + fn handle(&self, _expr: &Arc) -> Arc { + Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(42)))) } } let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - - let expr = Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::Column::new("b", 1)), - Operator::Eq, - Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), - )) as Arc; - - let expected_expr = Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::Column::new("c", 1)), - Operator::Eq, - Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), - )) as Arc; - - let handler = Arc::new(CustomUnhandledHook {}) as _; - let actual_expr = rewrite_predicate_to_statistics_predicate( - &expr, - &schema, - &handler, + let schema_with_b = Schema::new( + vec![Field::new("a", DataType::Int32, true), Field::new("b", DataType::Int32, true)], ); - assert_eq!(actual_expr.to_string(), expected_expr.to_string()); - - // but other cases do end up as `true` - - let expr = Arc::new(phys_expr::BinaryExpr::new( - Arc::new(phys_expr::Column::new("d", 1)), - Operator::Eq, - Arc::new(phys_expr::Literal::new(ScalarValue::Int32(Some(12)))), - )) as Arc; - - let expected_expr = - Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))) - as Arc; + let transform_expr = |expr| { + let expr = logical2physical(&expr, &schema_with_b); + rewrite_predicate_to_statistics_predicate( + &expr, + &schema, + Some(Arc::new(CustomUnhandledHook {})), + ) + }; - let handler = Arc::new(CustomUnhandledHook {}) as _; - let actual_expr = rewrite_predicate_to_statistics_predicate( - &expr, + // transform an arbitrary valid expression that we know is handled + let known_expression = col("a").eq(lit(ScalarValue::Int32(Some(12)))); + let known_expression_transformed = rewrite_predicate_to_statistics_predicate( + &logical2physical(&known_expression, &schema), &schema, - &handler, + None, ); - assert_eq!(actual_expr.to_string(), expected_expr.to_string()); + // an expression referencing an unknown column (that is not in the schema) gets passed to the hook + let input = col("b").eq(lit(ScalarValue::Int32(Some(12)))); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown column + let input = known_expression.clone().and(input.clone()); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // an unknown expression gets passed to the hook + let input = array_has(make_array(vec![lit(1)]), col("a")); + let expected = logical2physical(&lit(42), &schema); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); + + // more complex case with unknown expression + let input = known_expression.and(input); + let expected = phys_expr::BinaryExpr::new( + known_expression_transformed.clone(), + Operator::And, + logical2physical(&lit(42), &schema), + ); + let transformed = transform_expr(input.clone()); + assert_eq!(transformed.to_string(), expected.to_string()); } #[test]