Skip to content

Commit

Permalink
transform down -> transform up and reuse some functions from MergeSca…
Browse files Browse the repository at this point in the history
…larSubqueries
  • Loading branch information
beliefer committed Aug 1, 2023
1 parent 1573e62 commit 2ef930b
Show file tree
Hide file tree
Showing 6 changed files with 269 additions and 233 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, Expression, NamedExpression, Or}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeMap, Expression, NamedExpression, Or}
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, LeafNode, LogicalPlan, Project, SerializeFromObject}
Expand All @@ -32,96 +32,53 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{AGGREGATE, JOIN}
* every [[Join]] are [[Aggregate]]s.
*
* Note: this rule doesn't following cases:
* 1. The [[Aggregate]]s to be merged exists filter clause in aggregate expressions.
* 2. One of the to be merged two [[Aggregate]]s with child [[Filter]] and the other one is not.
* 3. The upstream node of these [[Aggregate]]s to be merged exists [[Join]].
* 1. One of the to be merged two [[Aggregate]]s with child [[Filter]] and the other one is not.
* 2. The upstream node of these [[Aggregate]]s to be merged exists [[Join]].
*/
object CombineJoinedAggregates extends Rule[LogicalPlan] {
object CombineJoinedAggregates extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {

private def isSupportedJoinType(joinType: JoinType): Boolean =
Seq(Inner, Cross, LeftOuter, RightOuter, FullOuter).contains(joinType)

// Collect all the Aggregates from both side of single or nested Join.
private def collectAggregate(plan: LogicalPlan, aggregates: ArrayBuffer[Aggregate]): Boolean = {
var flag = true
if (plan.containsAnyPattern(JOIN, AGGREGATE)) {
plan match {
case Join(left: Aggregate, right: Aggregate, _, None, _)
if left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty &&
left.aggregateExpressions.forall(filterNotDefined) &&
right.aggregateExpressions.forall(filterNotDefined) =>
aggregates += left
aggregates += right
case Join(left @ Join(_, _, joinType, None, _), right: Aggregate, _, None, _)
if isSupportedJoinType(joinType) && right.groupingExpressions.isEmpty &&
right.aggregateExpressions.forall(filterNotDefined) =>
flag = collectAggregate(left, aggregates)
aggregates += right
case Join(left: Aggregate, right @ Join(_, _, joinType, None, _), _, None, _)
if isSupportedJoinType(joinType) && left.groupingExpressions.isEmpty &&
left.aggregateExpressions.forall(filterNotDefined) =>
aggregates += left
flag = collectAggregate(right, aggregates)
// The side of Join is neither Aggregate nor Join.
case _ => flag = false
}
}

flag
}

// TODO Support aggregate expression with filter clause.
private def filterNotDefined(ne: NamedExpression): Boolean = {
ne match {
case Alias(ae: AggregateExpression, _) => ae.filter.isEmpty
case ae: AggregateExpression => ae.filter.isEmpty
}
}

// Merge the multiple Aggregates.
private def mergePlan(
left: LogicalPlan,
right: LogicalPlan): Option[(LogicalPlan, Map[Expression, Attribute], Seq[Expression])] = {
right: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute], Seq[Expression])] = {
(left, right) match {
case (la: Aggregate, ra: Aggregate) =>
val mergedChildPlan = mergePlan(la.child, ra.child)
mergedChildPlan.map { case (newChild, outputMap, filters) =>
val rightAggregateExprs = ra.aggregateExpressions.map { ne =>
ne.transform {
case attr: Attribute =>
outputMap.getOrElse(attr.canonicalized, attr)
}.asInstanceOf[NamedExpression]
}
val rightAggregateExprs = ra.aggregateExpressions.map(mapAttributes(_, outputMap))

val mergedAggregateExprs = if (filters.length == 2) {
la.aggregateExpressions.map { ne =>
ne.transform {
case ae @ AggregateExpression(_, _, _, None, _) =>
ae.copy(filter = Some(filters.head))
}.asInstanceOf[NamedExpression]
} ++ rightAggregateExprs.map { ne =>
ne.transform {
case ae @ AggregateExpression(_, _, _, None, _) =>
ae.copy(filter = Some(filters.last))
}.asInstanceOf[NamedExpression]
Seq(
(la.aggregateExpressions, filters.head),
(rightAggregateExprs, filters.last)
).flatMap { case (aggregateExpressions, propagatedFilter) =>
aggregateExpressions.map { ne =>
ne.transform {
case ae @ AggregateExpression(_, _, _, filterOpt, _) =>
val newFilter = filterOpt.map { filter =>
And(filter, propagatedFilter)
}.orElse(Some(propagatedFilter))
ae.copy(filter = newFilter)
}.asInstanceOf[NamedExpression]
}
}
} else {
la.aggregateExpressions ++ rightAggregateExprs
}

(Aggregate(Seq.empty, mergedAggregateExprs, newChild), Map.empty, Seq.empty)
(Aggregate(Seq.empty, mergedAggregateExprs, newChild), AttributeMap.empty, Seq.empty)
}
case (lp: Project, rp: Project) =>
val mergedInfo = mergePlan(lp.child, rp.child)
val mergedProjectList = ArrayBuffer[NamedExpression](lp.projectList: _*)

mergedInfo.map { case (newChild, outputMap, filters) =>
val allFilterReferences = filters.flatMap(_.references)
val newOutputMap = (rp.projectList ++ allFilterReferences).map { ne =>
val mapped = ne.transform {
case attr: Attribute =>
outputMap.getOrElse(attr.canonicalized, attr)
}.asInstanceOf[NamedExpression]
val newOutputMap = AttributeMap((rp.projectList ++ allFilterReferences).map { ne =>
val mapped = mapAttributes(ne, outputMap)

val withoutAlias = mapped match {
case Alias(child, _) => child
Expand All @@ -135,42 +92,27 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
mergedProjectList += mapped
mapped
}.toAttribute
ne.toAttribute.canonicalized -> outputAttr
}.toMap
ne.toAttribute -> outputAttr
})

(Project(mergedProjectList.toSeq, newChild), newOutputMap, filters)
}
// TODO support only one side contains Filter
case (lf: Filter, rf: Filter) =>
val mergedInfo = mergePlan(lf.child, rf.child)
mergedInfo.map { case (newChild, outputMap, _) =>
val rightCondition = rf.condition transform {
case attr: Attribute =>
outputMap.getOrElse(attr.canonicalized, attr)
}
val rightCondition = mapAttributes(rf.condition, outputMap)
val newCondition = Or(lf.condition, rightCondition)

(Filter(newCondition, newChild), outputMap, Seq(lf.condition, rightCondition))
}
case (ll: LeafNode, rl: LeafNode) =>
if (ll.canonicalized == rl.canonicalized) {
val outputMap = rl.output.zip(ll.output).map { case (ra, la) =>
ra.canonicalized -> la
}.toMap

Some((ll, outputMap, Seq.empty))
} else {
None
checkIdenticalPlans(rl, ll).map { outputMap =>
(ll, outputMap, Seq.empty)
}
case (ls: SerializeFromObject, rs: SerializeFromObject) =>
if (ls.canonicalized == rs.canonicalized) {
val outputMap = rs.output.zip(ls.output).map { case (ra, la) =>
ra.canonicalized -> la
}.toMap

Some((ls, outputMap, Seq.empty))
} else {
None
checkIdenticalPlans(rs, ls).map { outputMap =>
(ls, outputMap, Seq.empty)
}
case _ => None
}
Expand All @@ -179,21 +121,12 @@ object CombineJoinedAggregates extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = {
if (!conf.combineJoinedAggregatesEnabled) return plan

plan.transformDownWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
case j @ Join(_, _, joinType, None, _) if isSupportedJoinType(joinType) =>
val aggregates = ArrayBuffer.empty[Aggregate]
if (collectAggregate(j, aggregates)) {
var finalAggregate: Option[LogicalPlan] = None
for ((aggregate, i) <- aggregates.tail.zipWithIndex
if i == 0 || finalAggregate.isDefined) {
val mergedAggregate = mergePlan(finalAggregate.getOrElse(aggregates.head), aggregate)
finalAggregate = mergedAggregate.map(_._1)
}

finalAggregate.getOrElse(j)
} else {
j
}
plan.transformUpWithPruning(_.containsAnyPattern(JOIN, AGGREGATE), ruleId) {
case j @ Join(left: Aggregate, right: Aggregate, joinType, None, _)
if isSupportedJoinType(joinType) &&
left.groupingExpressions.isEmpty && right.groupingExpressions.isEmpty =>
val mergedAggregate = mergePlan(left, right)
mergedAggregate.map(_._1).getOrElse(j)
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ import org.apache.spark.sql.types.DataType
* : +- ReusedSubquery Subquery scalar-subquery#242, [id=#125]
* +- *(1) Scan OneRowRelation[]
*/
object MergeScalarSubqueries extends Rule[LogicalPlan] {
object MergeScalarSubqueries extends Rule[LogicalPlan] with MergeScalarSubqueriesHelper {
def apply(plan: LogicalPlan): LogicalPlan = {
plan match {
// Subquery reuse needs to be enabled for this optimization.
Expand Down Expand Up @@ -212,17 +212,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
}
}

// If 2 plans are identical return the attribute mapping from the new to the cached version.
private def checkIdenticalPlans(
newPlan: LogicalPlan,
cachedPlan: LogicalPlan): Option[AttributeMap[Attribute]] = {
if (newPlan.canonicalized == cachedPlan.canonicalized) {
Some(AttributeMap(newPlan.output.zip(cachedPlan.output)))
} else {
None
}
}

// Recursively traverse down and try merging 2 plans. If merge is possible then return the merged
// plan with the attribute mapping from the new to the merged version.
// Please note that merging arbitrary plans can be complicated, the current version supports only
Expand Down Expand Up @@ -314,12 +303,6 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] {
plan)
}

private def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]) = {
expr.transform {
case a: Attribute => outputMap.getOrElse(a, a)
}.asInstanceOf[T]
}

// Applies `outputMap` attribute mapping on attributes of `newExpressions` and merges them into
// `cachedExpressions`. Returns the merged expressions and the attribute mapping from the new to
// the merged version that can be propagated up during merging nodes.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

/**
* The helper class used to merge scalar subqueries.
*/
trait MergeScalarSubqueriesHelper {

// If 2 plans are identical return the attribute mapping from the new to the cached version.
protected def checkIdenticalPlans(
left: LogicalPlan, right: LogicalPlan): Option[AttributeMap[Attribute]] = {
if (left.canonicalized == right.canonicalized) {
Some(AttributeMap(left.output.zip(right.output)))
} else {
None
}
}

protected def mapAttributes[T <: Expression](expr: T, outputMap: AttributeMap[Attribute]): T = {
expr.transform {
case a: Attribute => outputMap.getOrElse(a, a)
}.asInstanceOf[T]
}
}
Loading

0 comments on commit 2ef930b

Please sign in to comment.