Skip to content

Commit

Permalink
fix and add test
Browse files Browse the repository at this point in the history
  • Loading branch information
wForget committed Feb 1, 2024
1 parent cd23093 commit 368ecce
Show file tree
Hide file tree
Showing 2 changed files with 92 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ import java.util.concurrent.atomic.AtomicBoolean

import org.apache.spark.SparkContext
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.SQLConfHelper
import org.apache.spark.sql.catalyst.util.truncatedString
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, QueryStageExec}
Expand All @@ -34,33 +32,32 @@ import org.apache.spark.status.ElementTrackingStore

import org.apache.kyuubi.sql.KyuubiSQLConf.COLLECT_METRICS_PRETTY_DISPLAY_ENABLED

private class CollectMetricsPrettyDisplayListener extends SparkListener with SQLConfHelper {

private def session: SparkSession = SparkSession.active
private def kvstore: ElementTrackingStore =
session.sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore]
private class CollectMetricsPrettyDisplayListener extends SparkListener {

override def onOtherEvent(event: SparkListenerEvent): Unit = {
if (conf.getConf(COLLECT_METRICS_PRETTY_DISPLAY_ENABLED)) {
event match {
case e: SparkListenerSQLExecutionEnd =>
val qe = e.qe
if (qe.observedMetrics.nonEmpty) {
val executionId =
Option(session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(
_.toLong).getOrElse(e.executionId)

val sparkPlanInfo = fromSparkPlan(qe.executedPlan)

val planGraph = SparkPlanGraph(sparkPlanInfo)
val graphToStore = new SparkPlanGraphWrapper(
executionId,
toStoredNodes(planGraph.nodes),
planGraph.edges)
kvstore.write(graphToStore)
}
case _ =>
}
event match {
case e: SparkListenerSQLExecutionEnd
if e.qe.sparkSession.conf.get(COLLECT_METRICS_PRETTY_DISPLAY_ENABLED) =>
val qe = e.qe
if (qe.observedMetrics.nonEmpty) {
val session = qe.sparkSession
val executionId =
Option(session.sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)).map(
_.toLong).getOrElse(e.executionId)

val sparkPlanInfo = fromSparkPlan(qe.executedPlan)

val planGraph = SparkPlanGraph(sparkPlanInfo)
val graphToStore = new SparkPlanGraphWrapper(
executionId,
toStoredNodes(planGraph.nodes),
planGraph.edges)

val kvstore: ElementTrackingStore =
session.sparkContext.statusStore.store.asInstanceOf[ElementTrackingStore]
kvstore.write(graphToStore)
}
case _ =>
}
}

Expand All @@ -87,7 +84,7 @@ private class CollectMetricsPrettyDisplayListener extends SparkListener with SQL
val metrics: Map[String, Any] =
c.collectedMetrics.getValuesMap[Any](c.metricsSchema.fieldNames)
val metricsString = redactMapString(metrics, SQLConf.get.maxToStringFields)
s"CollectMetricsExec(${c.name}) $metricsString"
s"CollectMetrics(${c.name}) $metricsString"
case p => p.simpleString(SQLConf.get.maxToStringFields)
}
new SparkPlanInfo(
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql

import java.util.concurrent.atomic.AtomicLong

import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
import org.apache.spark.sql.execution.ui.SparkListenerSQLExecutionEnd
import org.scalatest.concurrent.PatienceConfiguration.Timeout
import org.scalatest.time.SpanSugar._

import org.apache.kyuubi.sql.KyuubiSQLConf.COLLECT_METRICS_PRETTY_DISPLAY_ENABLED

/**
 * Checks that, with COLLECT_METRICS_PRETTY_DISPLAY_ENABLED switched on, the plan graph
 * kept in the SQL status store for an observed query contains a single `CollectMetrics`
 * node whose description carries the observer name and the collected metric values.
 */
class CollectMetricsPrettyDisplaySuite extends KyuubiSparkSQLExtensionTest {

  // Test tables (including `t1`) are presumably created by setupData() from the
  // parent suite -- confirm there if the fixture changes.
  override protected def beforeAll(): Unit = {
    super.beforeAll()
    setupData()
  }

  test("collect metrics pretty display") {
    withSQLConf(COLLECT_METRICS_PRETTY_DISPLAY_ENABLED.key -> "true") {
      // Starts at -1 so the assertion below can tell "no SQL execution end seen yet"
      // apart from a real (non-negative) execution id.
      val lastEndedExecution = new AtomicLong(-1)

      // Remembers the id of the most recent SQL execution that ended.
      val idRecorder = new SparkListener {
        override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
          case ended: SparkListenerSQLExecutionEnd =>
            lastEndedExecution.set(ended.executionId)
          case _ =>
        }
      }

      val sc = spark.sparkContext
      sc.addSparkListener(idRecorder)
      try {
        import org.apache.spark.sql.functions._
        spark.table("t1").observe("observer1", sum(col("c1")), count(lit(1))).collect()

        // Listener events are delivered asynchronously, so poll until the execution id
        // has been recorded and the plan graph is available in the status store.
        eventually(Timeout(3.seconds)) {
          val id = lastEndedExecution.get()
          assert(id >= 0)

          val graph = spark.sharedState.statusStore.planGraph(id)
          val descriptions = graph.allNodes.collect {
            case node if node.name == "CollectMetrics" => node.desc
          }
          assert(descriptions.size == 1)
          assert(descriptions.head ==
            "CollectMetrics(observer1) [sum(c1)=5050, count(1)=100]")
        }
      } finally {
        // Always detach the listener so it does not leak into other tests.
        sc.removeSparkListener(idRecorder)
      }
    }
  }
}

0 comments on commit 368ecce

Please sign in to comment.