diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
index a80c2eef955b9..3bdf177d0accd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregates.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.collection.mutable.ArrayBuffer
 
-import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, Or}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or}
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject}
@@ -32,84 +32,44 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, JOIN}
  * every [[Join]] are [[Aggregate]]s.
  *
  * Note: this rule doesn't following cases:
- * 1. The [[Aggregate]]s to be merged exists filter clause in aggregate expressions.
- * 2. One of the to be merged two [[Aggregate]]s with child [[Filter]] and the other one is not.
- * 3. The upstream node of these [[Aggregate]]s to be merged exists [[Join]].
+ * 1. One of the two [[Aggregate]]s to be merged has a child [[Filter]] and the other one does not.
+ * 2. There is a [[Join]] upstream of the [[Aggregate]]s to be merged.
  */
-object CombineJoinedAggregates extends Rule[LogicalPlan] {
+object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
 
   private def isSupportedJoinType(joinType: JoinType): Boolean =
     Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)
 
-  // Collect all the Aggregates from both side of single or nested Join.
-  private def collectAggregate(plan: LogicalPlan, aggregates: ArrayBuffer[Aggregate]): Boolean = {
-    var flag = true
-    if (plan.containsAnyPattern(JOIN, AGGREGATE)) {
-      plan match {
-        case Join(left: Aggregate, right: Aggregate, _, None, _)
-          if left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty &&
-            left.aggregateExpressions.forall(filterNotDefined) &&
-            right.aggregateExpressions.forall(filterNotDefined) =>
-          aggregates += left
-          aggregates += right
-        case Join(left @ Join(_, _, joinType, None, _), right: Aggregate, _, None, _)
-          if isSupportedJoinType(joinType) && right.groupingExpressions.isEmpty &&
-            right.aggregateExpressions.forall(filterNotDefined) =>
-          flag = collectAggregate(left, aggregates)
-          aggregates += right
-        case Join(left: Aggregate, right @ Join(_, _, joinType, None, _), _, None, _)
-          if isSupportedJoinType(joinType) && left.groupingExpressions.isEmpty &&
-            left.aggregateExpressions.forall(filterNotDefined) =>
-          aggregates += left
-          flag = collectAggregate(right, aggregates)
-        // The side of Join is neither Aggregate nor Join.
-        case _ => flag = false
-      }
-    }
-
-    flag
-  }
-
-  // TODO Support aggregate expression with filter clause.
-  private def filterNotDefined(ne: NamedExpression): Boolean = {
-    ne match {
-      case Alias(ae: AggregateExpression, _) => ae.filter.isEmpty
-      case ae: AggregateExpression => ae.filter.isEmpty
-    }
-  }
-
   // Merge the multiple Aggregates.
   private def mergePlan(
       left: LogicalPlan,
-      right: LogicalPlan): Option[(LogicalPlan, Map[Expression, Attribute], Seq[Expression])] = {
+      right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = {
     (left, right) match {
       case (la: Aggregate, ra: Aggregate) =>
         val mergedChildPlan = mergePlan(la.child, ra.child)
         mergedChildPlan.map { case (newChild, outputMap, filters) =>
-          val rightAggregateExprs = ra.aggregateExpressions.map { ne =>
-            ne.transform {
-              case attr: Attribute =>
-                outputMap.getOrElse(attr.canonicalized, attr)
-            }.asInstanceOf[NamedExpression]
-          }
+          val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap))
 
           val mergedAggregateExprs = if (filters.length == 2) {
-            la.aggregateExpressions.map { ne =>
-              ne.transform {
-                case ae @ AggregateExpression(_, _, _, None, _) =>
-                  ae.copy(filter = Some(filters.head))
-              }.asInstanceOf[NamedExpression]
-            } ++ rightAggregateExprs.map { ne =>
-              ne.transform {
-                case ae @ AggregateExpression(_, _, _, None, _) =>
-                  ae.copy(filter = Some(filters.last))
-              }.asInstanceOf[NamedExpression]
+            Seq(
+              (la.aggregateExpressions, filters.head),
+              (rightAggregateExprs, filters.last)
+            ).flatMap { case (aggregateExpressions, propagatedFilter) =>
+              aggregateExpressions.map { ne =>
+                ne.transform {
+                  case ae @ AggregateExpression(_, _, _, filterOpt, _) =>
+                    val newFilter = filterOpt.map { filter =>
+                      And(filter, propagatedFilter)
+                    }.orElse(Some(propagatedFilter))
+                    ae.copy(filter = newFilter)
+                }.asInstanceOf[NamedExpression]
+              }
             }
           } else {
            la.aggregateExpressions ++ rightAggregateExprs
          }
 
-          (Aggregate(Seq.empty, mergedAggregateExprs, newChild), Map.empty, Seq.empty)
+          (Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty)
         }
       case (lp: Project, rp: Project) =>
         val mergedInfo = mergePlan(lp.child, rp.child)
@@ -117,11 +77,8 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
         mergedInfo.map { case (newChild, outputMap, filters) =>
           val allFilterReferences = filters.flatMap(_.references)
 
-          val newOutputMap = (rp.projectList ++ allFilterReferences).map { ne =>
-            val mapped = ne.transform {
-              case attr: Attribute =>
-                outputMap.getOrElse(attr.canonicalized, attr)
-            }.asInstanceOf[NamedExpression]
+          val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne =>
+            val mapped = mapAttributes(ne, outputMap)
 
             val withoutAlias = mapped match {
               case Alias(child, _) => child
@@ -135,8 +92,8 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
               mergedProjectList += mapped
               mapped
             }.toAttribute
-            ne.toAttribute.canonicalized -> outputAttr
-          }.toMap
+            ne.toAttribute -> outputAttr
+          })
 
           (Project(mergedProjectList.toSeq, newChild), newOutputMap, filters)
         }
@@ -144,33 +101,18 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
       case (lf: Filter, rf: Filter) =>
         val mergedInfo = mergePlan(lf.child, rf.child)
         mergedInfo.map { case (newChild, outputMap, _) =>
-          val rightCondition = rf.condition transform {
-            case attr: Attribute =>
-              outputMap.getOrElse(attr.canonicalized, attr)
-          }
+          val rightCondition = mapAttributes(rf.condition, outputMap)
           val newCondition = Or(lf.condition, rightCondition)
 
           (Filter(newCondition, newChild), outputMap, Seq(lf.condition, rightCondition))
         }
       case (ll: LeafNode, rl: LeafNode) =>
-        if (ll.canonicalized == rl.canonicalized) {
-          val outputMap = rl.output.zip(ll.output).map { case (ra, la) =>
-            ra.canonicalized -> la
-          }.toMap
-
-          Some((ll, outputMap, Seq.empty))
-        } else {
-          None
+        checkIdenticalPlans(rl, ll).map { outputMap =>
+          (ll, outputMap, Seq.empty)
         }
       case (ls: SerializeFromObject, rs: SerializeFromObject) =>
-        if (ls.canonicalized == rs.canonicalized) {
-          val outputMap = rs.output.zip(ls.output).map { case (ra, la) =>
-            ra.canonicalized -> la
-          }.toMap
-
-          Some((ls, outputMap, Seq.empty))
-        } else {
-          None
+        checkIdenticalPlans(rs, ls).map { outputMap =>
+          (ls, outputMap, Seq.empty)
         }
       case _ => None
     }
@@ -179,21 +121,12 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = {
     if (!conf.combineJoinedAggregatesEnabled) return plan
 
-    plan.transformDownWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
-      case j @ Join(_, _, joinType, None, _) if isSupportedJoinType(joinType) =>
-        val aggregates = ArrayBuffer.empty[Aggregate]
-        if (collectAggregate(j, aggregates)) {
-          var finalAggregate: Option[LogicalPlan] = None
-          for ((aggregate, i) <- aggregates.tail.zipWithIndex
-            if i == 0 || finalAggregate.isDefined) {
-            val mergedAggregate = mergePlan(finalAggregate.getOrElse(aggregates.head), aggregate)
-            finalAggregate = mergedAggregate.map(_._1)
-          }
-
-          finalAggregate.getOrElse(j)
-        } else {
-          j
-        }
+    plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
+      case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _)
+        if isSupportedJoinType(joinType) &&
+          left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty =>
+        val mergedAggregate = mergePlan(left, right)
+        mergedAggregate.map(_._1).getOrElse(j)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
index 6184160829ba6..d48b3f1ed1301 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala
@@ -101,7 +101,7 @@ import org.apache.spark.sql.types.DataType
  *   :  +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
  *   +- *(1) Scan OneRowRelation[]
  */
-object MergeScalarSubqueries extends Rule[LogicalPlan] {
+object MergeScalarSubqueries extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
   def apply(plan: LogicalPlan): LogicalPlan = {
     plan match {
       // Subquery reuse needs to be enabled for this optimization.
@@ -212,17 +212,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
     }
   }
 
-  // If 2 plans are identical return the attribute mapping from the new to the cached version.
-  private def checkIdenticalPlans(
-      newPlan: LogicalPlan,
-      cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
-    if (newPlan.canonicalized == cachedPlan.canonicalized) {
-      Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
-    } else {
-      None
-    }
-  }
-
   // Recursively traverse down and try merging 2 plans. If merge is possible then return the merged
   // plan with the attribute mapping from the new to the merged version.
   // Please note that merging arbitrary plans can be complicated, the current version supports only
@@ -314,12 +303,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
       plan)
   }
 
-  private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
-    expr.transform {
-      case a: Attribute => outputMap.getOrElse(a, a)
-    }.asInstanceOf[T]
-  }
-
   // Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into
   // `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to
   // the merged version that can be propagated up during merging nodes.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala
new file mode 100644
index 0000000000000..48ae5b3556609
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesHelper.scala
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+
+/**
+ * The helper trait used to merge scalar subqueries.
+ */
+trait MergeScalarSubqueriesHelper {
+
+  // If 2 plans are identical, return the attribute mapping from the new to the cached version.
+  protected def checkIdenticalPlans(
+      left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = {
+    if (left.canonicalized == right.canonicalized) {
+      Some(AttributeMap(left.output.zip(right.output)))
+    } else {
+      None
+    }
+  }
+
+  protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
+    expr.transform {
+      case a: Attribute => outputMap.getOrElse(a, a)
+    }.asInstanceOf[T]
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala
index 5bdccaa165f6e..436bae6d44a7a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CombineJoinedAggregatesSuite.scala
@@ -35,7 +35,8 @@ class CombineJoinedAggregatesSuite extends PlanTest {
       CollapseProject,
       RemoveNoopOperators,
       PushDownPredicates,
-      CombineJoinedAggregates) :: Nil
+      CombineJoinedAggregates,
+      BooleanSimplification) :: Nil
   }
 
   private object WithoutOptimize extends RuleExecutor[LogicalPlan] {
@@ -43,7 +44,8 @@ class CombineJoinedAggregatesSuite extends PlanTest {
     Batch("Eliminate Join By Combine Aggregate", FixedPoint(10),
       CollapseProject,
       RemoveNoopOperators,
-      PushDownPredicates) :: Nil
+      PushDownPredicates,
+      BooleanSimplification) :: Nil
   }
 
   private val testRelation = LocalRelation.fromExternalRows(
@@ -161,32 +163,6 @@ class CombineJoinedAggregatesSuite extends PlanTest {
       WithoutOptimize.execute(originalQuery4.analyze))
   }
 
-  test("join side contains Aggregate and aggregate expressions exist Filter clause") {
-    val originalQuery1 =
-      testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join(
-        testRelation.where(a === 2).groupBy()(sum(b).as("sum_b")))
-
-    comparePlans(
-      Optimize.execute(originalQuery1.analyze),
-      WithoutOptimize.execute(originalQuery1.analyze))
-
-    val originalQuery2 =
-      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
-        testRelation.where(a === 2).groupBy()(sum(b, Some(c === 1)).as("sum_b")))
-
-    comparePlans(
-      Optimize.execute(originalQuery2.analyze),
-      WithoutOptimize.execute(originalQuery2.analyze))
-
-    val originalQuery3 =
-      testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join(
-        testRelation.where(a === 2).groupBy()(sum(b, Some(c === 1)).as("sum_b")))
-
-    comparePlans(
-      Optimize.execute(originalQuery3.analyze),
-      WithoutOptimize.execute(originalQuery3.analyze))
-  }
-
   test("join side contains Aggregate with group by clause") {
     val originalQuery1 =
       testRelation.where(a === 1).groupBy(c)(sum(b).as("sum_b")).join(
@@ -326,57 +302,128 @@ class CombineJoinedAggregatesSuite extends PlanTest {
       Optimize.execute(originalQuery1.analyze),
       WithoutOptimize.execute(correctAnswer1.analyze))
 
-//    val originalQuery2 =
-//      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
-//        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")).join(
-//          testRelation.where(a === 3).groupBy()(count(b).as("count_b"))))
-//
-//    val correctAnswer2 =
-//      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
-//        sum(b, Some(a === 1)).as("sum_b"),
-//        avg(b, Some(a === 2)).as("avg_b"),
-//        count(b, Some(a === 3)).as("count_b"))
-//
-//    comparePlans(
-//      Optimize.execute(originalQuery2.analyze),
-//      WithoutOptimize.execute(correctAnswer2.analyze))
-//
-//    val originalQuery3 =
-//      testRelation.where(a === 1).groupBy()(avg(a).as("avg_a"), sum(b).as("sum_b")).join(
-//        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"), sum(a).as("sum_a")).join(
-//          testRelation.where(a === 3).groupBy()(
-//            count(a).as("count_a"),
-//            count(b).as("count_b"),
-//            count(c).as("count_c"))))
-//
-//    val correctAnswer3 =
-//      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
-//        avg(a, Some(a === 1)).as("avg_a"),
-//        sum(b, Some(a === 1)).as("sum_b"),
-//        avg(b, Some(a === 2)).as("avg_b"),
-//        sum(a, Some(a === 2)).as("sum_a"),
-//        count(a, Some(a === 3)).as("count_a"),
-//        count(b, Some(a === 3)).as("count_b"),
-//        count(c, Some(a === 3)).as("count_c"))
-//
-//    comparePlans(
-//      Optimize.execute(originalQuery3.analyze),
-//      WithoutOptimize.execute(correctAnswer3.analyze))
-//
-//    val originalQuery4 =
-//      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
-//        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join(
-//        testRelation.where(a === 3).groupBy()(countDistinct(b).as("count_distinct_b")))
-//
-//    val correctAnswer4 =
-//      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
-//        sum(b, Some(a === 1)).as("sum_b"),
-//        avg(b, Some(a === 2)).as("avg_b"),
-//        countDistinctWithFilter(a === 3, b).as("count_distinct_b"))
-//
-//    comparePlans(
-//      Optimize.execute(originalQuery4.analyze),
-//      WithoutOptimize.execute(correctAnswer4.analyze))
+    val originalQuery2 =
+      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")).join(
+          testRelation.where(a === 3).groupBy()(count(b).as("count_b"))))
+
+    val correctAnswer2 =
+      testRelation.where(a === 1 || (a === 2 || a === 3)).groupBy()(
+        sum(b, Some(a === 1)).as("sum_b"),
+        avg(b, Some(a === 2)).as("avg_b"),
+        count(b, Some(a === 3)).as("count_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery2.analyze),
+      WithoutOptimize.execute(correctAnswer2.analyze))
+
+    val originalQuery3 =
+      testRelation.where(a === 1).groupBy()(avg(a).as("avg_a"), sum(b).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"), sum(a).as("sum_a")).join(
+          testRelation.where(a === 3).groupBy()(
+            count(a).as("count_a"),
+            count(b).as("count_b"),
+            count(c).as("count_c"))))
+
+    val correctAnswer3 =
+      testRelation.where(a === 1 || (a === 2 || a === 3)).groupBy()(
+        avg(a, Some(a === 1)).as("avg_a"),
+        sum(b, Some(a === 1)).as("sum_b"),
+        avg(b, Some(a === 2)).as("avg_b"),
+        sum(a, Some(a === 2)).as("sum_a"),
+        count(a, Some(a === 3)).as("count_a"),
+        count(b, Some(a === 3)).as("count_b"),
+        count(c, Some(a === 3)).as("count_c"))
+
+    comparePlans(
+      Optimize.execute(originalQuery3.analyze),
+      WithoutOptimize.execute(correctAnswer3.analyze))
+
+    val originalQuery4 =
+      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join(
+        testRelation.where(a === 3).groupBy()(countDistinct(b).as("count_distinct_b")))
+
+    val correctAnswer4 =
+      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
+        sum(b, Some(a === 1)).as("sum_b"),
+        avg(b, Some(a === 2)).as("avg_b"),
+        countDistinctWithFilter(a === 3, b).as("count_distinct_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery4.analyze),
+      WithoutOptimize.execute(correctAnswer4.analyze))
+  }
+
+  test("both join sides are Aggregates and aggregate expressions contain Filter clauses") {
+    val originalQuery1 =
+      testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b")))
+
+    val correctAnswer1 =
+      testRelation.where(a === 1 || a === 2).groupBy()(
+        sum(b, Some((c === 1) && (a === 1))).as("sum_b"),
+        avg(b, Some(a === 2)).as("avg_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery1.analyze),
+      WithoutOptimize.execute(correctAnswer1.analyze))
+
+    val originalQuery2 =
+      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b")))
+
+    val correctAnswer2 =
+      testRelation.where(a === 1 || a === 2).groupBy()(
+        sum(b, Some(a === 1)).as("sum_b"),
+        avg(b, Some((c === 1) && (a === 2))).as("avg_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery2.analyze),
+      WithoutOptimize.execute(correctAnswer2.analyze))
+
+    val originalQuery3 =
+      testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b")))
+
+    val correctAnswer3 =
+      testRelation.where(a === 1 || a === 2).groupBy()(
+        sum(b, Some((c === 1) && (a === 1))).as("sum_b"),
+        avg(b, Some((c === 1) && (a === 2))).as("avg_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery3.analyze),
+      WithoutOptimize.execute(correctAnswer3.analyze))
+
+    val originalQuery4 =
+      testRelation.where(a === 1).groupBy()(sum(b, Some(c === 1)).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b, Some(c === 1)).as("avg_b"))).join(
+        testRelation.where(a === 3).groupBy()(count(b, Some(c > 1)).as("count_b")))
+
+    val correctAnswer4 =
+      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
+        sum(b, Some(((c === 1) && (a === 1)) && ((a === 1) || (a === 2)))).as("sum_b"),
+        avg(b, Some(((c === 1) && (a === 2)) && ((a === 1) || (a === 2)))).as("avg_b"),
+        count(b, Some((c > 1) && (a === 3))).as("count_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery4.analyze),
+      WithoutOptimize.execute(correctAnswer4.analyze))
+
+    val originalQuery5 =
+      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
+        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join(
+        testRelation.where(a === 3).groupBy()(count(b, Some(c === 1)).as("count_b")))
+
+    val correctAnswer5 =
+      testRelation.where(a === 1 || a === 2 || a === 3).groupBy()(
+        sum(b, Some(a === 1)).as("sum_b"),
+        avg(b, Some(a === 2)).as("avg_b"),
+        count(b, Some((c === 1) && (a === 3))).as("count_b"))
+
+    comparePlans(
+      Optimize.execute(originalQuery5.analyze),
+      WithoutOptimize.execute(correctAnswer5.analyze))
   }
 
   test("upstream join could be optimized") {
@@ -434,31 +481,16 @@ class CombineJoinedAggregatesSuite extends PlanTest {
     val originalQuery4 =
       testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
         testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join(
-        testRelation.where(a === 3).groupBy()(count(b, Some(c === 1)).as("count_b")))
+        testRelation.where(a === 3).groupBy(c)(count(b).as("count_b")))
 
     val correctAnswer4 =
       testRelation.where(a === 1 || a === 2).groupBy()(
         sum(b, Some(a === 1)).as("sum_b"),
         avg(b, Some(a === 2)).as("avg_b")).join(
-        testRelation.where(a === 3).groupBy()(count(b, Some(c === 1)).as("count_b")))
+        testRelation.where(a === 3).groupBy(c)(count(b).as("count_b")))
 
     comparePlans(
       Optimize.execute(originalQuery4.analyze),
       WithoutOptimize.execute(correctAnswer4.analyze))
-
-    val originalQuery5 =
-      testRelation.where(a === 1).groupBy()(sum(b).as("sum_b")).join(
-        testRelation.where(a === 2).groupBy()(avg(b).as("avg_b"))).join(
-        testRelation.where(a === 3).groupBy(c)(count(b).as("count_b")))
-
-    val correctAnswer5 =
-      testRelation.where(a === 1 || a === 2).groupBy()(
-        sum(b, Some(a === 1)).as("sum_b"),
-        avg(b, Some(a === 2)).as("avg_b")).join(
-        testRelation.where(a === 3).groupBy(c)(count(b).as("count_b")))
-
-    comparePlans(
-      Optimize.execute(originalQuery5.analyze),
-      WithoutOptimize.execute(correctAnswer5.analyze))
   }
 }
diff --git a/sql/core/benchmarks/CombineJoinedAggregatesBenchmark-results.txt b/sql/core/benchmarks/CombineJoinedAggregatesBenchmark-results.txt
index 1b56081b9da7b..7ae40509c4ae0 100644
--- a/sql/core/benchmarks/CombineJoinedAggregatesBenchmark-results.txt
+++ b/sql/core/benchmarks/CombineJoinedAggregatesBenchmark-results.txt
@@ -6,21 +6,21 @@ Java HotSpot(TM) 64-Bit Server VM 1.8.0_311-b11 on Mac OS X 10.16
 Intel(R) Core(TM) i7-9750H CPU @ 2.60GHz
 Benchmark CombineJoinedAggregates:                     Best Time(ms)   Avg Time(ms)   Stdev(ms)   Rate(M/s)   Per Row(ns)   Relative
 -------------------------------------------------------------------------------------------------------------------------------------
-filter is not defined, CombineJoinedAggregates: false            753            832          93         27.8          35.9       1.0X
-filter is not defined, CombineJoinedAggregates: true             644            682          50         32.6          30.7       1.2X
-step is 1000000, CombineJoinedAggregates: false                  606            686          43         34.6          28.9       1.2X
-step is 1000000, CombineJoinedAggregates: true)                  755            882         158         27.8          36.0       1.0X
-step is 100000, CombineJoinedAggregates: false                   402            490          83         52.2          19.1       1.9X
-step is 100000, CombineJoinedAggregates: true)                   217            234          11         96.8          10.3       3.5X
-step is 10000, CombineJoinedAggregates: false                    337            372          32         62.3          16.0       2.2X
-step is 10000, CombineJoinedAggregates: true)                    157            176          14        133.5           7.5       4.8X
-step is 1000, CombineJoinedAggregates: false                     331            358          19         63.4          15.8       2.3X
-step is 1000, CombineJoinedAggregates: true)                     150            178          23        139.5           7.2       5.0X
-step is 100, CombineJoinedAggregates: false                      302            393          56         69.4          14.4       2.5X
-step is 100, CombineJoinedAggregates: true)                      148            163          13        141.6           7.1       5.1X
-step is 10, CombineJoinedAggregates: false                       338            380          26         62.0          16.1       2.2X
-step is 10, CombineJoinedAggregates: true)                       163            178          16        128.5           7.8       4.6X
-step is 1, CombineJoinedAggregates: false                        323            362          45         65.0          15.4       2.3X
-step is 1, CombineJoinedAggregates: true)                        132            156          25        158.4           6.3       5.7X
+filter is not defined, CombineJoinedAggregates: false            730            819          69         28.7          34.8       1.0X
+filter is not defined, CombineJoinedAggregates: true             618            632          14         33.9          29.5       1.2X
+step is 1000000, CombineJoinedAggregates: false                  572            590          20         36.7          27.3       1.3X
+step is 1000000, CombineJoinedAggregates: true)                  769            794          21         27.3          36.6       1.0X
+step is 100000, CombineJoinedAggregates: false                   350            370          26         59.9          16.7       2.1X
+step is 100000, CombineJoinedAggregates: true)                   231            241          10         90.7          11.0       3.2X
+step is 10000, CombineJoinedAggregates: false                    314            340          26         66.8          15.0       2.3X
+step is 10000, CombineJoinedAggregates: true)                    171            182           9        122.5           8.2       4.3X
+step is 1000, CombineJoinedAggregates: false                     303            337          32         69.3          14.4       2.4X
+step is 1000, CombineJoinedAggregates: true)                     162            171           9        129.4           7.7       4.5X
+step is 100, CombineJoinedAggregates: false                      300            316          27         70.0          14.3       2.4X
+step is 100, CombineJoinedAggregates: true)                      160            169           9        131.3           7.6       4.6X
+step is 10, CombineJoinedAggregates: false                       297            325          33         70.6          14.2       2.5X
+step is 10, CombineJoinedAggregates: true)                       170            203          36        123.5           8.1       4.3X
+step is 1, CombineJoinedAggregates: false                        328            352          17         64.0          15.6       2.2X
+step is 1, CombineJoinedAggregates: true)                        140            148           7        149.3           6.7       5.2X
 
 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 5d33f305a70d4..4b838bf47ecf8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -24,6 +24,8 @@ import scala.util.Random
 import org.scalatest.matchers.must.Matchers.the
 
 import org.apache.spark.{SparkException, SparkThrowable}
+import org.apache.spark.sql.catalyst.expressions.EqualTo
+import org.apache.spark.sql.catalyst.expressions.aggregate.{Average, Count, Sum}
 import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec, SortAggregateExec}
@@ -2234,7 +2236,7 @@ class DataFrameAggregateSuite extends QueryTest
       Seq.empty, "left_anti")
     checkAnswer(join16, Seq.empty)
 
-    // ReorderJoin push the join condition into upstream JOIN,
+    // ReorderJoin pushes the join condition of inner-like joins into the upstream join,
     // So EliminateJoinByCombineAggregate can't eliminate JOIN.
     val join17 =
       df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join(
@@ -2250,6 +2252,8 @@ class DataFrameAggregateSuite extends QueryTest
       $"left.sum_temp" === $"right.count_temp", "cross")
     checkAnswer(join18, Seq.empty)
 
+    // ReorderJoin can't push the join condition of a non-inner-like join into the upstream join,
+    // so EliminateJoinByCombineAggregate can still eliminate the JOIN.
     val join19 =
       df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join(
         df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).as("left").join(
@@ -2270,6 +2274,47 @@ class DataFrameAggregateSuite extends QueryTest
       df.where($"date" === 20151125).agg(count($"temp").as("count_temp")).as("right"),
       $"left.sum_temp" === $"right.count_temp", "full_outer")
     checkAnswer(join21, Seq(Row(84.0, 23.125, null), Row(null, null, 5)))
+
+    // both join sides are Aggregates and aggregate expressions contain Filter clauses
+    val sumWithFilter = Sum($"temp".expr).toAggregateExpression(
+      false, Some(EqualTo($"room_name".expr, lit("room1").expr)))
+    val join22 =
+      df.where($"date" === 20151123).agg(
+        new Column(sumWithFilter).as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp")))
+    checkAnswer(join22, Row(36.0, 23.125))
+
+    val avgWithFilter = Average($"temp".expr).toAggregateExpression(
+      false, Some(EqualTo($"room_name".expr, lit("room1").expr)))
+    val join23 =
+      df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp")))
+    checkAnswer(join23, Row(84.0, 18.25))
+
+    val join24 =
+      df.where($"date" === 20151123).agg(new Column(sumWithFilter).as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp")))
+    checkAnswer(join24, Row(36.0, 18.25))
+
+    val countWithFilter = Count($"temp".expr).toAggregateExpression(
+      false, Some(EqualTo($"room_name".expr, lit("room2").expr)))
+    val join25 =
+      df.where($"date" === 20151123).agg(new Column(sumWithFilter).as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))).join(
+        df.where($"date" === 20151125).agg(new Column(countWithFilter).as("count_temp")))
+    checkAnswer(join25, Row(36.0, 18.25, 3))
+
+    val join26 =
+      df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(avg($"temp").as("avg_temp"))).join(
+        df.where($"date" === 20151125).agg(new Column(countWithFilter).as("count_temp")))
+    checkAnswer(join26, Row(84.0, 23.125, 3))
+
+    val join27 =
+      df.where($"date" === 20151123).agg(sum($"temp").as("sum_temp")).join(
+        df.where($"date" === 20151124).agg(new Column(avgWithFilter).as("avg_temp"))).join(
+        df.where($"date" === 20151125).agg(count($"temp").as("count_temp")))
+    checkAnswer(join27, Row(84.0, 18.25, 5))
     }
   }
 }
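For reviewers who want to see the rewrite end to end, below is a minimal, self-contained sketch of what CombineJoinedAggregates does once this filter-clause support lands. It is not part of the patch: the table, data, and column names are invented, the SQL-level FILTER clause stands in for the AggregateExpression filters built explicitly in the tests above, and the rule is assumed to be enabled via its SQLConf flag (conf.combineJoinedAggregatesEnabled in the rule; check SQLConf for the exact key).

import org.apache.spark.sql.SparkSession

// Illustrative sketch only; not part of the patch.
object CombineJoinedAggregatesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
    import spark.implicits._

    Seq((1, 10, 1), (1, 20, 2), (2, 30, 1), (3, 40, 1))
      .toDF("a", "b", "c")
      .createOrReplaceTempView("t")

    // Two scalar (no GROUP BY) aggregates over the same relation, joined without a
    // condition. Before this patch the rule bailed out because sum(b) carries a
    // FILTER clause; with it, both sides merge into a single Aggregate, conceptually:
    //
    //   SELECT sum(b) FILTER (WHERE c = 1 AND a = 1) AS sum_b,
    //          avg(b) FILTER (WHERE a = 2)           AS avg_b
    //   FROM t WHERE a = 1 OR a = 2
    val q = spark.sql(
      """SELECT * FROM
        |  (SELECT sum(b) FILTER (WHERE c = 1) AS sum_b FROM t WHERE a = 1) l
        |JOIN
        |  (SELECT avg(b) AS avg_b FROM t WHERE a = 2) r
        |""".stripMargin)

    q.explain(true) // with the rule enabled, the optimized plan scans t only once
    q.show()        // sum_b = 10 (only the a = 1, c = 1 row), avg_b = 30.0
  }
}

Replacing two scans of t (one per join side) with a single filtered scan is also where the speedups recorded in CombineJoinedAggregatesBenchmark-results.txt above come from.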