diff --git a/datafusion/core/tests/dataframe/dataframe_functions.rs b/datafusion/core/tests/dataframe/dataframe_functions.rs index 28c0740ca76b..fec3ab786fce 100644 --- a/datafusion/core/tests/dataframe/dataframe_functions.rs +++ b/datafusion/core/tests/dataframe/dataframe_functions.rs @@ -1145,9 +1145,9 @@ async fn test_count_wildcard() -> Result<()> { .build() .unwrap(); - let expected = "Sort: count(Int64(1)) ASC NULLS LAST [count(Int64(1)):Int64]\ - \n Projection: count(Int64(1)) [count(Int64(1)):Int64]\ - \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1))]] [b:UInt32, count(Int64(1)):Int64]\ + let expected = "Sort: count(*) ASC NULLS LAST [count(*):Int64]\ + \n Projection: count(*) [count(*):Int64]\ + \n Aggregate: groupBy=[[test.b]], aggr=[[count(Int64(1)) AS count(*)]] [b:UInt32, count(*):Int64]\ \n TableScan: test [a:UInt32, b:UInt32, c:UInt32]"; let formatted_plan = plan.display_indent_schema().to_string(); diff --git a/datafusion/core/tests/dataframe/mod.rs b/datafusion/core/tests/dataframe/mod.rs index b134ec54b13d..3010144224d2 100644 --- a/datafusion/core/tests/dataframe/mod.rs +++ b/datafusion/core/tests/dataframe/mod.rs @@ -32,8 +32,9 @@ use arrow::datatypes::{ }; use arrow::error::ArrowError; use arrow::util::pretty::pretty_format_batches; -use datafusion_expr::utils::COUNT_STAR_EXPANSION; -use datafusion_functions_aggregate::count::{count_all, count_udaf}; +use datafusion_functions_aggregate::count::{ + count_all, count_all_column, count_all_window, count_all_window_column, +}; use datafusion_functions_aggregate::expr_fn::{ array_agg, avg, count, count_distinct, max, median, min, sum, }; @@ -2455,7 +2456,7 @@ async fn test_count_wildcard_on_sort() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select b,count(1) from t1 group by b order by count(1)") + .sql("select b, count(*) from t1 group by b order by count(*)") .await? .explain(false, false)? .collect() @@ -2469,9 +2470,52 @@ async fn test_count_wildcard_on_sort() -> Result<()> { .explain(false, false)? 
.collect() .await?; - //make sure sql plan same with df plan + + let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+------------------------------------------------------------------------------------------------------------+\ + \n| logical_plan | Projection: t1.b, count(*) |\ + \n| | Sort: count(Int64(1)) AS count(*) AS count(*) ASC NULLS LAST |\ + \n| | Projection: t1.b, count(Int64(1)) AS count(*), count(Int64(1)) |\ + \n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1))]] |\ + \n| | TableScan: t1 projection=[b] |\ + \n| physical_plan | ProjectionExec: expr=[b@0 as b, count(*)@1 as count(*)] |\ + \n| | SortPreservingMergeExec: [count(Int64(1))@2 ASC NULLS LAST] |\ + \n| | SortExec: expr=[count(Int64(1))@2 ASC NULLS LAST], preserve_partitioning=[true] |\ + \n| | ProjectionExec: expr=[b@0 as b, count(Int64(1))@1 as count(*), count(Int64(1))@1 as count(Int64(1))] |\ + \n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(Int64(1))] |\ + \n| | CoalesceBatchesExec: target_batch_size=8192 |\ + \n| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 |\ + \n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\ + \n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(Int64(1))] |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+------------------------------------------------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + expected_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + + let expected_df_result = "+---------------+--------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+--------------------------------------------------------------------------------+\ +\n| logical_plan | Sort: count(*) ASC NULLS LAST |\ +\n| | Aggregate: groupBy=[[t1.b]], aggr=[[count(Int64(1)) AS count(*)]] |\ +\n| | TableScan: t1 projection=[b] |\ +\n| physical_plan | SortPreservingMergeExec: [count(*)@1 ASC NULLS LAST] |\ +\n| | SortExec: expr=[count(*)@1 ASC NULLS LAST], preserve_partitioning=[true] |\ +\n| | AggregateExec: mode=FinalPartitioned, gby=[b@0 as b], aggr=[count(*)] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([b@0], 4), input_partitions=4 |\ +\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\ +\n| | AggregateExec: mode=Partial, gby=[b@0 as b], aggr=[count(*)] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+--------------------------------------------------------------------------------+"; + + assert_eq!( + expected_df_result, pretty_format_batches(&df_results)?.to_string() ); Ok(()) @@ -2481,12 +2525,35 @@ async fn test_count_wildcard_on_sort() -> Result<()> { async fn test_count_wildcard_on_where_in() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("SELECT a,b FROM t1 WHERE a in (SELECT count(1) FROM t2)") + .sql("SELECT a, b FROM t1 WHERE a in (SELECT count(*) FROM t2)") .await? .explain(false, false)? 
.collect() .await?; + let expected_sql_result = "+---------------+------------------------------------------------------------------------------------------------------------------------+\
\n| plan_type | plan |\
\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\
\n| | TableScan: t1 projection=[a, b] |\
\n| | SubqueryAlias: __correlated_sq_1 |\
\n| | Projection: count(Int64(1)) AS count(*) |\
\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\
\n| | TableScan: t2 projection=[] |\
\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
\n| | ProjectionExec: expr=[4 as count(*)] |\
\n| | PlaceholderRowExec |\
\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | |\
\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + + assert_eq!( + expected_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here @@ -2509,9 +2576,26 @@ async fn test_count_wildcard_on_where_in() -> Result<()> { .collect() .await?; + let expected_df_result = "+---------------+------------------------------------------------------------------------------------------------------------------------+\
\n| plan_type | plan |\
\n+---------------+------------------------------------------------------------------------------------------------------------------------+\
\n| logical_plan | LeftSemi Join: CAST(t1.a AS Int64) = __correlated_sq_1.count(*) |\
\n| | TableScan: t1 projection=[a, b] |\
\n| | SubqueryAlias: __correlated_sq_1 |\
\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\
\n| | TableScan: t2 projection=[] |\
\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | HashJoinExec: mode=Partitioned, join_type=RightSemi, on=[(count(*)@0, CAST(t1.a AS Int64)@2)], projection=[a@0, b@1] |\
\n| | ProjectionExec: expr=[4 as count(*)] |\
\n| | PlaceholderRowExec |\
\n| | ProjectionExec: expr=[a@0 as a, b@1 as b, CAST(a@0 AS Int64) as CAST(t1.a AS Int64)] |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | |\
\n+---------------+------------------------------------------------------------------------------------------------------------------------+"; + // make sure the SQL plan is the same as the DataFrame plan assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + expected_df_result, pretty_format_batches(&df_results)?.to_string() ); @@ -2522,11 +2606,34 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(1) FROM t2)") + .sql("SELECT a, b FROM t1 WHERE EXISTS (SELECT count(*) FROM t2)") .await? .explain(false, false)?
.collect() .await?; + + let actual_sql_result = + "+---------------+---------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+---------------------------------------------------------+\ + \n| logical_plan | LeftSemi Join: |\ + \n| | TableScan: t1 projection=[a, b] |\ + \n| | SubqueryAlias: __correlated_sq_1 |\ + \n| | Projection: |\ + \n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\ + \n| | TableScan: t2 projection=[] |\ + \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\ + \n| | ProjectionExec: expr=[] |\ + \n| | PlaceholderRowExec |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+---------------------------------------------------------+"; + + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + let df_results = ctx .table("t1") .await? @@ -2545,9 +2652,24 @@ async fn test_count_wildcard_on_where_exist() -> Result<()> { .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------+\ + \n| plan_type | plan |\ + \n+---------------+---------------------------------------------------------------------+\ + \n| logical_plan | LeftSemi Join: |\ + \n| | TableScan: t1 projection=[a, b] |\ + \n| | SubqueryAlias: __correlated_sq_1 |\ + \n| | Projection: |\ + \n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\ + \n| | TableScan: t2 projection=[] |\ + \n| physical_plan | NestedLoopJoinExec: join_type=RightSemi |\ + \n| | ProjectionExec: expr=[] |\ + \n| | PlaceholderRowExec |\ + \n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ + \n| | |\ + \n+---------------+---------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); @@ -2559,34 +2681,62 @@ async fn test_count_wildcard_on_window() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select count(1) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") + .sql("select count(*) OVER(ORDER BY a DESC RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING) from t1") .await? .explain(false, false)? 
.collect() .await?; + + let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING AS count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING |\ +\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] |\ +\n| | TableScan: t1 projection=[a] |\ +\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(*) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] |\ +\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] |\ +\n| | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+"; + + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + let df_results = ctx .table("t1") .await? - .select(vec![Expr::WindowFunction(WindowFunction::new( - WindowFunctionDefinition::AggregateUDF(count_udaf()), - vec![Expr::Literal(COUNT_STAR_EXPANSION)], - )) - .order_by(vec![Sort::new(col("a"), false, true)]) - .window_frame(WindowFrame::new_bounds( - WindowFrameUnits::Range, - WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), - WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), - )) - .build() - .unwrap()])? 
+ .select(vec![count_all_window() + .order_by(vec![Sort::new(col("a"), false, true)]) + .window_frame(WindowFrame::new_bounds( + WindowFrameUnits::Range, + WindowFrameBound::Preceding(ScalarValue::UInt32(Some(6))), + WindowFrameBound::Following(ScalarValue::UInt32(Some(2))), + )) + .build() + .unwrap()])? .explain(false, false)? .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING |\ +\n| | WindowAggr: windowExpr=[[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING]] |\ +\n| | TableScan: t1 projection=[a] |\ +\n| physical_plan | ProjectionExec: expr=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING@1 as count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING] |\ +\n| | BoundedWindowAggExec: wdw=[count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING: Ok(Field { name: \"count(Int64(1)) ORDER BY [t1.a DESC NULLS FIRST] RANGE BETWEEN 6 PRECEDING AND 2 FOLLOWING\", data_type: Int64, nullable: false, dict_id: 0, dict_is_ordered: false, metadata: {} }), frame: WindowFrame { units: Range, start_bound: Preceding(UInt32(6)), end_bound: Following(UInt32(2)), is_causal: false }], mode=[Sorted] |\ +\n| | SortExec: expr=[a@0 DESC], preserve_partitioning=[false] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+"; + assert_eq!( - pretty_format_batches(&df_results)?.to_string(), - pretty_format_batches(&sql_results)?.to_string() + actual_df_result, + pretty_format_batches(&df_results)?.to_string() ); Ok(()) @@ -2598,12 +2748,28 @@ async fn test_count_wildcard_on_aggregate() -> Result<()> { register_alltypes_tiny_pages_parquet(&ctx).await?; let sql_results = ctx - .sql("select count(1) from t1") + .sql("select count(*) from t1") .await? .explain(false, false)? 
.collect() .await?; + let actual_sql_result = + "+---------------+-----------------------------------------------------+\
\n| plan_type | plan |\
\n+---------------+-----------------------------------------------------+\
\n| logical_plan | Projection: count(Int64(1)) AS count(*) |\
\n| | Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] |\
\n| | TableScan: t1 projection=[] |\
\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\
\n| | PlaceholderRowExec |\
\n| | |\
\n+---------------+-----------------------------------------------------+"; + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + // add `.select(vec![count_all()])?` to make sure we can analyze all nodes, not just the top node. let df_results = ctx .table("t1") @@ -2614,26 +2780,77 @@ .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------+\
\n| plan_type | plan |\
\n+---------------+---------------------------------------------------------------+\
\n| logical_plan | Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]] |\
\n| | TableScan: t1 projection=[] |\
\n| physical_plan | ProjectionExec: expr=[4 as count(*)] |\
\n| | PlaceholderRowExec |\
\n| | |\
\n+---------------+---------------------------------------------------------------+"; assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); Ok(()) } +#[tokio::test] +async fn test_count_wildcard_schema_name() { + assert_eq!(count_all().schema_name().to_string(), "count(*)"); + assert_eq!(count_all_column(), col("count(*)")); + assert_eq!( + count_all_window_column(), + col("count(Int64(1)) ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING") + ); +} + #[tokio::test] async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { let ctx = create_join_context()?; let sql_results = ctx - .sql("select a,b from t1 where (select count(1) from t2 where t1.a = t2.a)>0;") + .sql("select a,b from t1 where (select count(*) from t2 where t1.a = t2.a)>0;") .await? .explain(false, false)?
.collect() .await?; + let actual_sql_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| plan_type | plan |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\ +\n| logical_plan | Projection: t1.a, t1.b |\ +\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\ +\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\ +\n| | Left Join: t1.a = __scalar_sq_1.a |\ +\n| | TableScan: t1 projection=[a, b] |\ +\n| | SubqueryAlias: __scalar_sq_1 |\ +\n| | Projection: count(Int64(1)) AS count(*), t2.a, Boolean(true) AS __always_true |\ +\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1))]] |\ +\n| | TableScan: t2 projection=[a] |\ +\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | ProjectionExec: expr=[count(Int64(1))@1 as count(*), a@0 as a, true as __always_true] |\ +\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(Int64(1))] |\ +\n| | CoalesceBatchesExec: target_batch_size=8192 |\ +\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 |\ +\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\ +\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(Int64(1))] |\ +\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\ +\n| | |\ +\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; + assert_eq!( + actual_sql_result, + pretty_format_batches(&sql_results)?.to_string() + ); + // In the same SessionContext, AliasGenerator will increase subquery_alias id by 1 // https://github.com/apache/datafusion/blame/cf45eb9020092943b96653d70fafb143cc362e19/datafusion/optimizer/src/alias.rs#L40-L43 // for compare difference between sql and df logical plan, we need to create a new SessionContext here @@ -2647,7 +2864,7 @@ async fn test_count_wildcard_on_where_scalar_subquery() -> Result<()> { .await? .filter(out_ref_col(DataType::UInt32, "t1.a").eq(col("t2.a")))? .aggregate(vec![], vec![count_all()])? - .select(vec![col(count_all().to_string())])? + .select(vec![count_all_column()])? 
.into_unoptimized_plan(), )) .gt(lit(ScalarValue::UInt8(Some(0)))), @@ -2657,9 +2874,36 @@ .collect() .await?; - //make sure sql plan same with df plan + let actual_df_result = "+---------------+---------------------------------------------------------------------------------------------------------------------------+\
\n| plan_type | plan |\
\n+---------------+---------------------------------------------------------------------------------------------------------------------------+\
\n| logical_plan | Projection: t1.a, t1.b |\
\n| | Filter: CASE WHEN __scalar_sq_1.__always_true IS NULL THEN Int64(0) ELSE __scalar_sq_1.count(*) END > Int64(0) |\
\n| | Projection: t1.a, t1.b, __scalar_sq_1.count(*), __scalar_sq_1.__always_true |\
\n| | Left Join: t1.a = __scalar_sq_1.a |\
\n| | TableScan: t1 projection=[a, b] |\
\n| | SubqueryAlias: __scalar_sq_1 |\
\n| | Projection: count(*), t2.a, Boolean(true) AS __always_true |\
\n| | Aggregate: groupBy=[[t2.a]], aggr=[[count(Int64(1)) AS count(*)]] |\
\n| | TableScan: t2 projection=[a] |\
\n| physical_plan | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | FilterExec: CASE WHEN __always_true@3 IS NULL THEN 0 ELSE count(*)@2 END > 0, projection=[a@0, b@1] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | HashJoinExec: mode=Partitioned, join_type=Left, on=[(a@0, a@1)], projection=[a@0, b@1, count(*)@2, __always_true@4] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=1 |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | ProjectionExec: expr=[count(*)@1 as count(*), a@0 as a, true as __always_true] |\
\n| | AggregateExec: mode=FinalPartitioned, gby=[a@0 as a], aggr=[count(*)] |\
\n| | CoalesceBatchesExec: target_batch_size=8192 |\
\n| | RepartitionExec: partitioning=Hash([a@0], 4), input_partitions=4 |\
\n| | RepartitionExec: partitioning=RoundRobinBatch(4), input_partitions=1 |\
\n| | AggregateExec: mode=Partial, gby=[a@0 as a], aggr=[count(*)] |\
\n| | DataSourceExec: partitions=1, partition_sizes=[1] |\
\n| | |\
\n+---------------+---------------------------------------------------------------------------------------------------------------------------+"; assert_eq!( - pretty_format_batches(&sql_results)?.to_string(), + actual_df_result, pretty_format_batches(&df_results)?.to_string() ); @@ -4228,7 +4472,8 @@ fn create_join_context() -> Result<SessionContext> { ], )?; - let ctx = SessionContext::new(); + let config = SessionConfig::new().with_target_partitions(4); + let ctx = SessionContext::new_with_config(config); ctx.register_batch("t1", batch1)?; ctx.register_batch("t2", batch2)?; diff --git a/datafusion/functions-aggregate/src/count.rs b/datafusion/functions-aggregate/src/count.rs index a3339f0fceb9..5afe2b0584eb 100644 --- a/datafusion/functions-aggregate/src/count.rs +++ b/datafusion/functions-aggregate/src/count.rs @@ -17,6 +17,7 @@ use ahash::RandomState; use datafusion_common::stats::Precision; +use datafusion_expr::expr::WindowFunction; use datafusion_functions_aggregate_common::aggregate::count_distinct::BytesViewDistinctCountAccumulator; use datafusion_macros::user_doc; use datafusion_physical_expr::expressions; @@ -47,11 +48,13 @@ use datafusion_common::{ downcast_value, internal_err, not_impl_err, Result, ScalarValue, }; use datafusion_expr::function::StateFieldsArgs; +use datafusion_expr::{ + col, Expr,
ReversedUDAF, StatisticsArgs, TypeSignature, WindowFunctionDefinition, +}; use datafusion_expr::{ function::AccumulatorArgs, utils::format_state_name, Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, SetMonotonicity, Signature, Volatility, }; -use datafusion_expr::{Expr, ReversedUDAF, StatisticsArgs, TypeSignature}; use datafusion_functions_aggregate_common::aggregate::count_distinct::{ BytesDistinctCountAccumulator, FloatDistinctCountAccumulator, PrimitiveDistinctCountAccumulator, @@ -80,8 +83,34 @@ pub fn count_distinct(expr: Expr) -> Expr { } /// Creates aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` +/// The result expression is aliased to `count(*)` for backward compatibility pub fn count_all() -> Expr { - count(Expr::Literal(COUNT_STAR_EXPANSION)) + count(Expr::Literal(COUNT_STAR_EXPANSION)).alias("count(*)") +} + +/// Creates a window aggregation to count all rows, equivalent to `COUNT(*)`, `COUNT()`, `COUNT(1)` +pub fn count_all_window() -> Expr { + Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + vec![Expr::Literal(COUNT_STAR_EXPANSION)], + )) +} + +/// Returns an `Expr::Column` referring to the count wildcard window function by its schema name. +/// Useful in the DataFrame API when you need a column reference to the result of `count_all_window` +pub fn count_all_window_column() -> Expr { + col(Expr::WindowFunction(WindowFunction::new( + WindowFunctionDefinition::AggregateUDF(count_udaf()), + vec![Expr::Literal(COUNT_STAR_EXPANSION)], + )) + .schema_name() + .to_string()) +} + +/// Returns an `Expr::Column` referring to the count wildcard aggregate by its schema name. +/// Useful in the DataFrame API when you need a column reference to the result of `count_all` +pub fn count_all_column() -> Expr { + col(count_all().schema_name().to_string()) } #[user_doc( diff --git a/datafusion/sqllogictest/test_files/subquery.slt b/datafusion/sqllogictest/test_files/subquery.slt index 94c9eaf810fb..207bb72fd549 100644 --- a/datafusion/sqllogictest/test_files/subquery.slt +++ b/datafusion/sqllogictest/test_files/subquery.slt @@ -1393,3 +1393,41 @@ item1 1970-01-01T00:00:03 75 statement ok drop table source_table; + +statement count 0 +drop table t1; + +statement count 0 +drop table t2; + +statement count 0 +drop table t3; + +# test count wildcard +statement count 0 +create table t1(a int) as values (1); + +statement count 0 +create table t2(b int) as values (1); + +query I +SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2) +---- +1 + +query TT +explain SELECT a FROM t1 WHERE EXISTS (SELECT count(*) FROM t2) +---- +logical_plan +01)LeftSemi Join: +02)--TableScan: t1 projection=[a] +03)--SubqueryAlias: __correlated_sq_1 +04)----Projection: +05)------Aggregate: groupBy=[[]], aggr=[[count(Int64(1))]] +06)--------TableScan: t2 projection=[] + +statement count 0 +drop table t1; + +statement count 0 +drop table t2;
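
For illustration, a minimal sketch of how the new helpers in `count.rs` are meant to compose in the DataFrame API, mirroring the `aggregate(...)` / `select(...)` pattern used in `test_count_wildcard_on_where_scalar_subquery` above. This sketch is not part of the patch; the table name "t" and the `count_rows` wrapper are assumptions made for the example:

use datafusion::error::Result;
use datafusion::prelude::{DataFrame, SessionContext};
use datafusion_functions_aggregate::count::{count_all, count_all_column};

/// Count all rows of a table registered as "t" (assumed to exist).
/// `count_all()` builds `count(Int64(1))` aliased to `count(*)`, and
/// `count_all_column()` resolves to `col("count(*)")`, so the projection
/// below matches the aggregate's output column by construction.
async fn count_rows(ctx: &SessionContext) -> Result<DataFrame> {
    ctx.table("t")
        .await?
        .aggregate(vec![], vec![count_all()])?
        .select(vec![count_all_column()])
}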