From 080e7eb0de06d49959f2f2b05a1f446adf12fadb Mon Sep 17 00:00:00 2001 From: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Date: Fri, 2 Aug 2024 16:07:19 +0800 Subject: [PATCH] [SPARK-49000][SQL][FOLLOWUP] Improve code style and update comments ### What changes were proposed in this pull request? Fix `RewriteDistinctAggregates` rule to deal properly with aggregation on DISTINCT literals. Physical plan for `select count(distinct 1) from t`: ``` -- count(distinct 1) == Physical Plan == AdaptiveSparkPlan isFinalPlan=false +- HashAggregate(keys=[], functions=[count(distinct 1)], output=[count(DISTINCT 1)#2L]) +- HashAggregate(keys=[], functions=[partial_count(distinct 1)], output=[count#6L]) +- HashAggregate(keys=[], functions=[], output=[]) +- Exchange SinglePartition, ENSURE_REQUIREMENTS, [plan_id=20] +- HashAggregate(keys=[], functions=[], output=[]) +- FileScan parquet spark_catalog.default.t[] Batched: false, DataFilters: [], Format: Parquet, Location: InMemoryFileIndex(1 paths)[file:/Users/nikola.mandic/oss-spark/spark-warehouse/org.apache.spark.s..., PartitionFilters: [], PushedFilters: [], ReadSchema: struct<> ``` Problem is happening when `HashAggregate(keys=[], functions=[], output=[])` node yields one row to `partial_count` node, which then captures one row. This four-node structure is constructed by `AggUtils.planAggregateWithOneDistinct`. To fix the problem, we're adding `Expand` node which will force non-empty grouping expressions in `HashAggregateExec` nodes. This will in turn enable streaming zero rows to parent `partial_count` node, yielding correct final result. ### Why are the changes needed? Aggregation with DISTINCT literal gives wrong results. For example, when running on empty table `t`: `select count(distinct 1) from t` returns 1, while the correct result should be 0. For reference: `select count(1) from t` returns 0, which is the correct and expected result. ### Does this PR introduce _any_ user-facing change? Yes, this fixes a critical bug in Spark. ### How was this patch tested? New e2e SQL tests for aggregates with DISTINCT literals. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #47565 from uros-db/SPARK-49000-followup. Authored-by: Uros Bojanic <157381213+uros-db@users.noreply.github.com> Signed-off-by: Kent Yao --- .../optimizer/RewriteDistinctAggregates.scala | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala index e91493188873e..801bd2693af42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteDistinctAggregates.scala @@ -198,13 +198,15 @@ import org.apache.spark.util.collection.Utils */ object RewriteDistinctAggregates extends Rule[LogicalPlan] { private def mustRewrite( - aggregateExpressions: Seq[AggregateExpression], + distinctAggs: Seq[AggregateExpression], groupingExpressions: Seq[Expression]): Boolean = { - // If there are any AggregateExpressions with filter, we need to rewrite the query. - // Also, if there are no grouping expressions and all aggregate expressions are foldable, - // we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). - aggregateExpressions.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && - aggregateExpressions.exists(_.aggregateFunction.children.forall(_.foldable))) + // If there are any distinct AggregateExpressions with filter, we need to rewrite the query. + // Also, if there are no grouping expressions and all distinct aggregate expressions are + // foldable, we need to rewrite the query, e.g. SELECT COUNT(DISTINCT 1). Without this case, + // non-grouping aggregation queries with distinct aggregate expressions will be incorrectly + // handled by the aggregation strategy, causing wrong results when working with empty tables. + distinctAggs.exists(_.filter.isDefined) || (groupingExpressions.isEmpty && + distinctAggs.exists(_.aggregateFunction.children.forall(_.foldable))) } private def mayNeedtoRewrite(a: Aggregate): Boolean = { @@ -213,7 +215,6 @@ object RewriteDistinctAggregates extends Rule[LogicalPlan] { // We need at least two distinct aggregates or the single distinct aggregate group exists filter // clause for this rule because aggregation strategy can handle a single distinct aggregate // group without filter clause. - // This check can produce false-positives, e.g., SUM(DISTINCT a) & COUNT(DISTINCT a). distinctAggs.size > 1 || mustRewrite(distinctAggs, a.groupingExpressions) }