From 196845aacf65c8b1e89414a141e01033f57c437d Mon Sep 17 00:00:00 2001
From: Fei Wang
Date: Thu, 25 Jan 2024 11:08:32 -0800
Subject: [PATCH] use optimized plan

---
 .../spark/sql/kyuubi/SparkDatasetHelper.scala    | 16 ++++++++--------
 .../sql/kyuubi/SparkDatasetHelperSuite.scala     |  8 ++++----
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
index a1f303a2623..d4cf25ec728 100644
--- a/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
+++ b/externals/kyuubi-spark-sql-engine/src/main/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelper.scala
@@ -25,8 +25,9 @@ import org.apache.spark.network.util.{ByteUnit, JavaUtils}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.plans.logical.{GlobalLimit, LogicalPlan}
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils
-import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, SparkPlan, SQLExecution, TakeOrderedAndProjectExec}
+import org.apache.spark.sql.execution.{CollectLimitExec, HiveResult, LocalTableScanExec, QueryExecution, SparkPlan, SQLExecution}
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec
 import org.apache.spark.sql.execution.arrow.KyuubiArrowConverters
 import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics}
@@ -294,18 +295,17 @@ object SparkDatasetHelper extends Logging {
     SQLMetrics.postDriverMetricUpdates(sc, executionId, metrics.values.toSeq)
   }
 
-  private[kyuubi] def planLimit(plan: SparkPlan): Option[Int] = plan match {
-    case tp: TakeOrderedAndProjectExec => Option(tp.limit)
-    case c: CollectLimitExec => Option(c.limit)
-    case ap: AdaptiveSparkPlanExec => planLimit(ap.inputPlan)
-    case _ => None
-  }
+  private[kyuubi] def optimizedPlanLimit(queryExecution: QueryExecution): Option[Long] =
+    queryExecution.optimizedPlan match {
+      case globalLimit: GlobalLimit => globalLimit.maxRows
+      case _ => None
+    }
 
   def shouldSaveResultToFs(resultMaxRows: Int, minSize: Long, result: DataFrame): Boolean = {
     if (isCommandExec(result.queryExecution.executedPlan.nodeName)) {
       return false
     }
-    val finalLimit = planLimit(result.queryExecution.sparkPlan) match {
+    val finalLimit = optimizedPlanLimit(result.queryExecution) match {
       case Some(limit) if resultMaxRows > 0 => math.min(limit, resultMaxRows)
       case Some(limit) => limit
       case None => resultMaxRows
diff --git a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
index 8e51484b176..8ac00e60262 100644
--- a/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
+++ b/externals/kyuubi-spark-sql-engine/src/test/scala/org/apache/spark/sql/kyuubi/SparkDatasetHelperSuite.scala
@@ -33,13 +33,13 @@ class SparkDatasetHelperSuite extends WithSparkSQLEngine {
         " SELECT * FROM VALUES(1),(2),(3),(4) AS t(id)")
 
       val topKStatement = s"SELECT * FROM(SELECT * FROM tv ORDER BY id LIMIT ${topKThreshold - 1})"
-      assert(SparkDatasetHelper.planLimit(
-        spark.sql(topKStatement).queryExecution.sparkPlan) === Option(topKThreshold - 1))
+      assert(SparkDatasetHelper.optimizedPlanLimit(
+        spark.sql(topKStatement).queryExecution) === Option(topKThreshold - 1))
 
       val collectLimitStatement =
         s"SELECT * FROM (SELECT * FROM tv ORDER BY id LIMIT $topKThreshold)"
-      assert(SparkDatasetHelper.planLimit(
-        spark.sql(collectLimitStatement).queryExecution.sparkPlan) === Option(topKThreshold))
+      assert(SparkDatasetHelper.optimizedPlanLimit(
+        spark.sql(collectLimitStatement).queryExecution) === Option(topKThreshold))
     }
   }
 }
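
For reference, a minimal standalone sketch of the new lookup, runnable against a plain local-mode SparkSession outside Kyuubi (the object name, the main method, and the sample query are illustrative and not part of the patch): a GlobalLimit node at the root of queryExecution.optimizedPlan carries the effective row cap via maxRows, so a single logical-plan match covers the cases the old planLimit had to enumerate physically (CollectLimitExec, TakeOrderedAndProjectExec, and AdaptiveSparkPlanExec wrapping).

// OptimizedPlanLimitSketch.scala -- sketch only; mirrors the patched helper.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.GlobalLimit
import org.apache.spark.sql.execution.QueryExecution

object OptimizedPlanLimitSketch {
  // Same shape as the patched SparkDatasetHelper.optimizedPlanLimit: read the
  // limit off the optimized logical plan instead of the physical plan, so the
  // result does not depend on which physical operator or AQE wrapper Spark
  // happens to choose.
  def optimizedPlanLimit(qe: QueryExecution): Option[Long] =
    qe.optimizedPlan match {
      case g: GlobalLimit => g.maxRows
      case _ => None
    }

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").getOrCreate()
    // ORDER BY keeps the optimizer from folding the LIMIT into a local
    // relation, leaving GlobalLimit(2, ...) at the root of the optimized plan.
    val qe = spark.sql("SELECT * FROM range(10) ORDER BY id LIMIT 2").queryExecution
    assert(optimizedPlanLimit(qe) == Some(2L))
    spark.stop()
  }
}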