From ed2a3f5d89fa7bad9fdb54eb5c5f935e2fb907df Mon Sep 17 00:00:00 2001
From: Jiaan Geng
Date: Sat, 5 Aug 2023 13:52:39 +0800
Subject: [PATCH] Update code

---
 .../optimizer/CombineJoinedAggregates.scala        | 13 +++++++++----
 .../apache/spark/sql/DataFrameAggregateSuite.scala | 12 ++++++++++++
 2 files changed, 21 insertions(+), 4 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
index f075cfb904d7d..b0e16e8a7f308 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
@@ -110,12 +110,17 @@ object CombineJoinedAggregates
       case (lf: Filter, rf: Filter) =>
         val mergedChildPlan = mergePlan(lf.child, rf.child)
         mergedChildPlan.map {
-          case (newChild, outputMap, _)
+          case (newChild, outputMap, filters)
               if isLikelySelective(lf.condition) && isLikelySelective(rf.condition) =>
-            val rightCondition = mapAttributes(rf.condition, outputMap)
-            val newCondition = Or(lf.condition, rightCondition)
+            val mappedRightCondition = mapAttributes(rf.condition, outputMap)
+            val (newLeftCondition, newRightCondition) = if (filters.length == 2) {
+              (And(lf.condition, filters.head), And(mappedRightCondition, filters.last))
+            } else {
+              (lf.condition, mappedRightCondition)
+            }
+            val newCondition = Or(newLeftCondition, newRightCondition)
 
-            (Filter(newCondition, newChild), outputMap, Seq(lf.condition, rightCondition))
+            (Filter(newCondition, newChild), outputMap, Seq(newLeftCondition, newRightCondition))
         }
       case (ll: LeafNode, rl: LeafNode) =>
         checkIdenticalPlans(rl, ll).map { outputMap =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 4b838bf47ecf8..209b95d52eff2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -26,6 +26,7 @@ import org.scalatest.matchers.must.Matchers.the
 import org.apache.spark.{SparkException, SparkThrowable}
 import org.apache.spark.sql.catalyst.expressions.EqualTo
 import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, Sum}
+import org.apache.spark.sql.catalyst.optimizer.PushDownPredicates
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -2315,6 +2316,17 @@ class DataFrameAggregateSuite extends QueryTest
       df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))).join(
         df.where($"date" === 20151125).agg(count($"temp").as("count_temp")))
     checkAnswer(join27, Row(84.0, 18.25, 5))
+
+    Seq(PushDownPredicates.ruleName, "").map { ruleName =>
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> ruleName) {
+        val subQuery1 = df.where($"date" === 20151123).as("tab1")
+        val subQuery2 = df.where($"date" === 20151124).as("tab2")
+        val join28 =
+          subQuery1.where($"tab1.minute" > 30).agg(sum($"tab1.temp").as("sum_temp")).join(
+            subQuery2.where($"tab2.minute" < 30).agg(avg($"tab2.temp").as("avg_temp")))
+        checkAnswer(join28, Row(84.0, 24.600000000000005))
+      }
+    }
   }
 }
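
Note (not part of the patch): a minimal sketch of the query shape the new join28 test
exercises and why the propagated `filters` matter. It assumes the suite's SparkSession
implicits (the $-interpolator) and its `df` with columns `date`, `minute`, `temp`; the
names `left`, `right`, and `joined` are illustrative only.

    import org.apache.spark.sql.functions.{avg, sum}

    // Each side stacks two selective Filters over the same scan. The test
    // excludes PushDownPredicates in one run so the two Filter nodes are
    // kept separate rather than collapsed into a single conjunction.
    val left = df.where($"date" === 20151123).where($"minute" > 30)
      .agg(sum($"temp").as("sum_temp"))
    val right = df.where($"date" === 20151124).where($"minute" < 30)
      .agg(avg($"temp").as("avg_temp"))
    val joined = left.join(right)

    // When CombineJoinedAggregates merges the two sides, the inner Filter
    // merge propagates Seq(date = 20151123, date = 20151124) upward. With
    // this patch the outer merge ANDs each side's own condition onto the
    // propagated one, so the merged scan filters on
    //   (minute > 30 AND date = 20151123) OR (minute < 30 AND date = 20151124)
    // and each aggregate keeps its side's full conjunction. The old code
    // dropped the propagated filters (the `_` in the match), so the left
    // aggregate's filter was just `minute > 30` and could wrongly pick up
    // 20151124 rows with minute > 30.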