Skip to content

Commit

Permalink
[SPARK-44776][CONNECT] Add ProducedRowCount to SparkListenerConnectOp…
Browse files Browse the repository at this point in the history
…erationFinished

### What changes were proposed in this pull request?
Add ProducedRowCount field to SparkListenerConnectOperationFinished

### Why are the changes needed?
Needed to show the number of rows produced by an operation.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added a unit test.

Closes apache#42454 from gjxdxh/SPARK-44776.

Authored-by: Lingkai Kong <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
  • Loading branch information
gjxdxh authored and HyukjinKwon committed Aug 22, 2023
1 parent 8d4ca0a commit 4646991
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 101 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
errorOnDuplicatedFieldNames = false)

var numSent = 0
var totalNumRows: Long = 0
def sendBatch(bytes: Array[Byte], count: Long): Unit = {
val response = proto.ExecutePlanResponse.newBuilder().setSessionId(sessionId)
val batch = proto.ExecutePlanResponse.ArrowBatch
Expand All @@ -120,14 +121,15 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
response.setArrowBatch(batch)
responseObserver.onNext(response.build())
numSent += 1
totalNumRows += count
}

dataframe.queryExecution.executedPlan match {
case LocalTableScanExec(_, rows) =>
executePlan.eventsManager.postFinished()
converter(rows.iterator).foreach { case (bytes, count) =>
sendBatch(bytes, count)
}
executePlan.eventsManager.postFinished(Some(totalNumRows))
case _ =>
SQLExecution.withNewExecutionId(dataframe.queryExecution, Some("collectArrow")) {
val rows = dataframe.queryExecution.executedPlan.execute()
Expand Down Expand Up @@ -162,8 +164,7 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
resultFunc = () => ())
// Collect errors and propagate them to the main thread.
.andThen {
case Success(_) =>
executePlan.eventsManager.postFinished()
case Success(_) => // do nothing
case Failure(throwable) =>
signal.synchronized {
error = Some(throwable)
Expand Down Expand Up @@ -200,8 +201,9 @@ private[execution] class SparkConnectPlanExecution(executeHolder: ExecuteHolder)
currentPartitionId += 1
}
ThreadUtils.awaitReady(future, Duration.Inf)
executePlan.eventsManager.postFinished(Some(totalNumRows))
} else {
executePlan.eventsManager.postFinished()
executePlan.eventsManager.postFinished(Some(totalNumRows))
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2509,7 +2509,7 @@ class SparkConnectPlanner(val sessionHolder: SessionHolder) extends Logging {
.putAllArgs(getSqlCommand.getArgsMap)
.addAllPosArgs(getSqlCommand.getPosArgsList)))
}
executeHolder.eventsManager.postFinished()
executeHolder.eventsManager.postFinished(Some(rows.size))
// Exactly one SQL Command Result Batch
responseObserver.onNext(
ExecutePlanResponse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {

private var canceled = Option.empty[Boolean]

private var producedRowCount = Option.empty[Long]

/**
* @return
* Last event posted by the Connect request
Expand All @@ -95,6 +97,13 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {
*/
private[connect] def hasError: Option[Boolean] = error

/**
* @return
* The number of rows produced by the Connect request, as reported in @link
* org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished.
* None until postFinished is called with a row count.
*/
private[connect] def getProducedRowCount: Option[Long] = producedRowCount

/**
* Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationStarted.
*/
Expand Down Expand Up @@ -192,13 +201,23 @@ case class ExecuteEventsManager(executeHolder: ExecuteHolder, clock: Clock) {

/**
* Post @link org.apache.spark.sql.connect.service.SparkListenerConnectOperationFinished.
* @param producedRowsCountOpt
* Number of rows that are returned to the user. None is expected when the operation does not
* return any rows.
*/
def postFinished(producedRowsCountOpt: Option[Long] = None): Unit = {
// Finishing is only valid from the Started or ReadyForExecution states.
assertStatus(
List(ExecuteStatus.Started, ExecuteStatus.ReadyForExecution),
ExecuteStatus.Finished)
// Record the count so getProducedRowCount can report it afterwards.
producedRowCount = producedRowsCountOpt

val finishedEvent = SparkListenerConnectOperationFinished(
jobTag,
operationId,
clock.getTimeMillis(),
producedRowCount)
listenerBus.post(finishedEvent)
}

/**
Expand Down Expand Up @@ -395,13 +414,17 @@ case class SparkListenerConnectOperationFailed(
* 36 characters UUID assigned by Connect during a request.
* @param eventTime:
* The time in ms when the event was generated.
* @param producedRowCount:
* Number of rows that are returned to the user. None is expected when the operation does not
* return any rows.
* @param extraTags:
* Additional metadata during the request.
*/
// Listener event posted when a Connect operation finishes successfully.
case class SparkListenerConnectOperationFinished(
jobTag: String,
operationId: String,
eventTime: Long,
producedRowCount: Option[Long] = None, // None when the operation returns no rows
extraTags: Map[String, String] = Map.empty) // additional request metadata
extends SparkListenerEvent

Expand Down
Loading

0 comments on commit 4646991

Please sign in to comment.