From 1f496fbea688c7082bad7e6280c8a949fbfd31b7 Mon Sep 17 00:00:00 2001 From: ulysses-you Date: Tue, 18 Jan 2022 16:22:03 +0800 Subject: [PATCH] [SPARK-37949][SQL] Improve Rebalance statistics estimation ### What changes were proposed in this pull request? Match `RebalancePartitions` in `SizeInBytesOnlyStatsPlanVisitor` and `BasicStatsPlanVisitor`. ### Why are the changes needed? The defualt statistics estimation only consider the size in bytes, which may lost the row rount and columns statistics. The `RebalancePartitions` actually does not change the statistics of plan, so we can use the statistics of its child for more accurate. ### Does this PR introduce _any_ user-facing change? no, only affect the statistics of plan ### How was this patch tested? Unify the test in `BasicStatsEstimationSuite` Closes #35235 from ulysses-you/SPARK-37949. Authored-by: ulysses-you Signed-off-by: Wenchen Fan --- .../sql/catalyst/plans/logical/LogicalPlanVisitor.scala | 3 +++ .../logical/statsEstimation/BasicStatsPlanVisitor.scala | 2 ++ .../statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala | 2 ++ .../statsEstimation/BasicStatsEstimationSuite.scala | 8 ++++++-- 4 files changed, 13 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala index ba927746bbf6a..fd5f9051719dc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlanVisitor.scala @@ -37,6 +37,7 @@ trait LogicalPlanVisitor[T] { case p: Project => visitProject(p) case p: Repartition => visitRepartition(p) case p: RepartitionByExpression => visitRepartitionByExpr(p) + case p: RebalancePartitions => visitRebalancePartitions(p) case p: Sample => visitSample(p) case p: ScriptTransformation => visitScriptTransform(p) case p: Union => visitUnion(p) @@ -77,6 +78,8 @@ trait LogicalPlanVisitor[T] { def visitRepartitionByExpr(p: RepartitionByExpression): T + def visitRebalancePartitions(p: RebalancePartitions): T + def visitSample(p: Sample): T def visitScriptTransform(p: ScriptTransformation): T diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala index 3f702724cca53..0f09022fb9c2f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/BasicStatsPlanVisitor.scala @@ -88,6 +88,8 @@ object BasicStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = fallback(p) + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = fallback(p) + override def visitSample(p: Sample): Statistics = fallback(p) override def visitScriptTransform(p: ScriptTransformation): Statistics = default(p) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala index 73c1b9445f693..67a045fe5ec1a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/SizeInBytesOnlyStatsPlanVisitor.scala @@ -132,6 +132,8 @@ object SizeInBytesOnlyStatsPlanVisitor extends LogicalPlanVisitor[Statistics] { override def visitRepartitionByExpr(p: RepartitionByExpression): Statistics = p.child.stats + override def visitRebalancePartitions(p: RebalancePartitions): Statistics = p.child.stats + override def visitSample(p: Sample): Statistics = { val ratio = p.upperBound - p.lowerBound var sizeInBytes = EstimationUtils.ceil(BigDecimal(p.child.stats.sizeInBytes) * ratio) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala index 31e289e052586..bc61a76ecfc22 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -259,12 +259,16 @@ class BasicStatsEstimationSuite extends PlanTest with StatsEstimationTestBase { expectedStatsCboOff = Statistics.DUMMY) } - test("SPARK-35203: Improve Repartition statistics estimation") { + test("Improve Repartition statistics estimation") { + // SPARK-35203 for repartition and repartitionByExpr + // SPARK-37949 for rebalance Seq( RepartitionByExpression(plan.output, plan, 10), RepartitionByExpression(Nil, plan, None), plan.repartition(2), - plan.coalesce(3)).foreach { rep => + plan.coalesce(3), + plan.rebalance(), + plan.rebalance(plan.output: _*)).foreach { rep => val expectedStats = Statistics(plan.size.get, Some(plan.rowCount), plan.attributeStats) checkStats( rep,