diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index ba927746bbf6a..fd5f9051719dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -37,6 +37,7 @@ trait LogicalPlanVisitor[T] { case p: Project => visitProject(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: RebalancePartitions => visitRebalancePartitions(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) @@ -77,6 +78,8 @@ trait LogicalPlanVisitor[T] { def visitRepartitionByExpr(p: RepartitionByExpression): T + def visitRebalancePartitions(p: RebalancePartitions): T + def visitSample(p: Sample): T def visitScriptTransform(p: ScriptTransformation): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 3f702724cca53..0f09022fb9c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -88,6 +88,8 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = fallback(p) + override def visitSample(p: Sample): Statistics = fallback(p) override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 73c1b9445f693..67a045fe5ec1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -132,6 +132,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = p.child.stats + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = p.child.stats + override def visitSample(p: Sample): Statistics = { val ratio = p.upperBound - p.lowerBound var sizeInBytes = EstimationUtils.ceil(BigDecimal(p.child.stats.sizeInBytes) * ratio) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 31e289e052586..bc61a76ecfc22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -259,12 +259,16 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = Statistics.DUMMY) } - test("SPARK-35203: Improve Repartition statistics estimation") { + test("Improve Repartition statistics estimation") { + // SPARK-35203 for repartition and repartitionByExpr + // SPARK-37949 for rebalance Seq( RepartitionByExpression(plan.output, plan, 10), RepartitionByExpression(Nil, plan, None), plan.repartition(2), - plan.coalesce(3)).foreach { rep => + plan.coalesce(3), + plan.rebalance(), + plan.rebalance(plan.output: _*)).foreach { rep => val expectedStats = Statistics(plan.size.get, Some(plan.rowCount), plan.attributeStats) checkStats( rep,