diff --git a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
index 2f1f5ab744..6dd1023520 100644
--- a/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
+++ b/spark/src/main/scala/org/apache/comet/rules/RewriteJoin.scala
@@ -20,7 +20,8 @@
 package org.apache.comet.rules
 
 import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, BuildSide, JoinSelectionHelper}
-import org.apache.spark.sql.catalyst.plans.{JoinType, LeftSemi}
+import org.apache.spark.sql.catalyst.plans.LeftSemi
+import org.apache.spark.sql.catalyst.plans.logical.Join
 import org.apache.spark.sql.execution.{SortExec, SparkPlan}
 import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec}
 
@@ -31,14 +32,27 @@ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoin
  */
 object RewriteJoin extends JoinSelectionHelper {
 
-  private def getBuildSide(joinType: JoinType): Option[BuildSide] = {
-    if (canBuildShuffledHashJoinRight(joinType)) {
-      Some(BuildRight)
-    } else if (canBuildShuffledHashJoinLeft(joinType)) {
-      Some(BuildLeft)
-    } else {
-      None
+  /**
+   * Selects the build side for a sort-merge join that is being rewritten. Returns `None` when
+   * neither side is buildable, the only buildable side when exactly one is, and otherwise the
+   * side picked by [[getOptimalBuildSide]] from the linked logical join (falling back to
+   * `BuildLeft` when there is no logical link or it is not a `Join`).
+   */
+  private def getSmjBuildSide(join: SortMergeJoinExec): Option[BuildSide] = {
+    val leftBuildable = canBuildShuffledHashJoinLeft(join.joinType)
+    val rightBuildable = canBuildShuffledHashJoinRight(join.joinType)
+    (leftBuildable, rightBuildable) match {
+      case (false, false) => None
+      case (false, true) => Some(BuildRight)
+      case (true, false) => Some(BuildLeft)
+      case (true, true) =>
+        // If smj has no logical link, or its logical link is not a join,
+        // then we always choose left as build side.
+        val side = join.logicalLink
+          .collect { case logicalJoin: Join => getOptimalBuildSide(logicalJoin) }
+          .getOrElse(BuildLeft)
+        Some(side)
     }
   }
 
   private def removeSort(plan: SparkPlan) = plan match {
@@ -48,7 +62,7 @@ object RewriteJoin extends JoinSelectionHelper {
 
   def rewrite(plan: SparkPlan): SparkPlan = plan match {
     case smj: SortMergeJoinExec =>
-      getBuildSide(smj.joinType) match {
+      getSmjBuildSide(smj) match {
         case Some(BuildRight) if smj.joinType == LeftSemi =>
           // TODO this was added as a workaround for TPC-DS q14 hanging and needs
           // further investigation
@@ -67,4 +81,20 @@ object RewriteJoin extends JoinSelectionHelper {
       }
     case _ => plan
   }
+
+  /**
+   * Picks the build side of a logical join from the statistics of its children. The side with
+   * the smaller `sizeInBytes` wins; when the sizes are equal and both row counts are known, the
+   * side with fewer rows wins. Ties resolve to `BuildRight`.
+   */
+  def getOptimalBuildSide(join: Join): BuildSide = {
+    val leftStats = join.left.stats
+    val rightStats = join.right.stats
+    val sameSize = leftStats.sizeInBytes == rightStats.sizeInBytes
+    val buildRight = (leftStats.rowCount, rightStats.rowCount) match {
+      case (Some(leftRows), Some(rightRows)) if sameSize => rightRows <= leftRows
+      case _ => rightStats.sizeInBytes <= leftStats.sizeInBytes
+    }
+    if (buildRight) BuildRight else BuildLeft
+  }
 }