Commit
fix scalafmt
zeotuan committed May 3, 2024
1 parent c9e0209 commit fa76c39
Showing 20 changed files with 2,348 additions and 1,756 deletions.
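The hunks below are formatting-only changes from scalafmt: long class headers, string concatenations, and scaladoc blocks are rewrapped to fit the column limit (100 characters in Spark's style; the exact scalafmt settings are not shown in this commit), with no behavioral changes. As a minimal sketch of the dominant pattern — mirroring the BarrierCoordinator change further down, but with hypothetical names so it compiles on its own:

```scala
// Toy traits standing in for ThreadSafeRpcEndpoint and Logging.
trait Endpoint
trait Logging

// Before scalafmt, the inheritance clause sat on the same line as the closing
// parenthesis of the constructor:
//   class Coordinator(timeoutInSecs: Long) extends Endpoint with Logging {}
//
// After scalafmt, each inheritance clause moves to its own continuation line:
class Coordinator(timeoutInSecs: Long)
    extends Endpoint
    with Logging {}
```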
@@ -26,8 +26,8 @@ trait LogKey {
}

/**
* Various keys used for mapped diagnostic contexts(MDC) in logging.
* All structured logging keys should be defined here for standardization.
* Various keys used for mapped diagnostic contexts(MDC) in logging. All structured logging keys
* should be defined here for standardization.
*/
object LogKeys {
case object ACCUMULATOR_ID extends LogKey
118 changes: 70 additions & 48 deletions core/src/main/scala/org/apache/spark/BarrierCoordinator.scala
@@ -40,16 +40,18 @@ private case class ContextBarrierId(stageId: Int, stageAttemptId: Int) {

/**
* A coordinator that handles all global sync requests from BarrierTaskContext. Each global sync
* request is generated by `BarrierTaskContext.barrier()`, and identified by
* stageId + stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon
* all the requests for a group of `barrier()` calls are received. If the coordinator is unable to
* collect enough global sync requests within a configured time, fail all the requests and return
* an Exception with timeout message.
* request is generated by `BarrierTaskContext.barrier()`, and identified by stageId +
* stageAttemptId + barrierEpoch. Reply all the blocking global sync requests upon all the
* requests for a group of `barrier()` calls are received. If the coordinator is unable to collect
* enough global sync requests within a configured time, fail all the requests and return an
* Exception with timeout message.
*/
private[spark] class BarrierCoordinator(
timeoutInSecs: Long,
listenerBus: LiveListenerBus,
override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging {
override val rpcEnv: RpcEnv)
extends ThreadSafeRpcEndpoint
with Logging {

// TODO SPARK-25030 Create a Timer() in the mainClass submitted to SparkSubmit makes it unable to
// fetch result, we shall fix the issue.
@@ -90,21 +92,22 @@ private[spark] class BarrierCoordinator(
* Provide the current state of a barrier() call. A state is created when a new stage attempt
* sends out a barrier() call, and recycled on stage completed.
*
* @param barrierId Identifier of the barrier stage that make a barrier() call.
* @param numTasks Number of tasks of the barrier stage, all barrier() calls from the stage shall
* collect `numTasks` requests to succeed.
* @param barrierId
* Identifier of the barrier stage that make a barrier() call.
* @param numTasks
* Number of tasks of the barrier stage, all barrier() calls from the stage shall collect
* `numTasks` requests to succeed.
*/
private class ContextBarrierState(
val barrierId: ContextBarrierId,
val numTasks: Int) {
private class ContextBarrierState(val barrierId: ContextBarrierId, val numTasks: Int) {

// There may be multiple barrier() calls from a barrier stage attempt, `barrierEpoch` is used
// to identify each barrier() call. It shall get increased when a barrier() call succeeds, or
// reset when a barrier() call fails due to timeout.
private var barrierEpoch: Int = 0

// An Array of RPCCallContexts for barrier tasks that have made a blocking runBarrier() call
private val requesters: ArrayBuffer[RpcCallContext] = new ArrayBuffer[RpcCallContext](numTasks)
private val requesters: ArrayBuffer[RpcCallContext] =
new ArrayBuffer[RpcCallContext](numTasks)

// Messages from each barrier task that have made a blocking runBarrier() call.
// The messages will be replied to all tasks once sync finished.
@@ -124,9 +127,11 @@ private[spark] class BarrierCoordinator(
timerTask = new TimerTask {
override def run(): Unit = state.synchronized {
// Timeout current barrier() call, fail all the sync requests.
requesters.foreach(_.sendFailure(new SparkException("The coordinator didn't get all " +
s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
s"$timeoutInSecs second(s).")))
requesters.foreach(
_.sendFailure(
new SparkException("The coordinator didn't get all " +
s"barrier sync requests for barrier epoch $barrierEpoch from $barrierId within " +
s"$timeoutInSecs second(s).")))
cleanupBarrierStage(barrierId)
}
}
@@ -149,25 +154,31 @@ private[spark] class BarrierCoordinator(
val curReqMethod = request.requestMethod
requestMethods.add(curReqMethod)
if (requestMethods.size > 1) {
val error = new SparkException(s"Different barrier sync types found for the " +
s"sync $barrierId: ${requestMethods.mkString(", ")}. Please use the " +
s"same barrier sync type within a single sync.")
val error = new SparkException(
s"Different barrier sync types found for the " +
s"sync $barrierId: ${requestMethods.mkString(", ")}. Please use the " +
s"same barrier sync type within a single sync.")
(requesters :+ requester).foreach(_.sendFailure(error))
clear()
return
}

// Require the number of tasks is correctly set from the BarrierTaskContext.
require(request.numTasks == numTasks, s"Number of tasks of $barrierId is " +
s"${request.numTasks} from Task $taskId, previously it was $numTasks.")
require(
request.numTasks == numTasks,
s"Number of tasks of $barrierId is " +
s"${request.numTasks} from Task $taskId, previously it was $numTasks.")

// Check whether the epoch from the barrier tasks matches current barrierEpoch.
logInfo(log"Current barrier epoch for ${MDC(BARRIER_ID, barrierId)}" +
log" is ${MDC(BARRIER_EPOCH, barrierEpoch)}.")
logInfo(
log"Current barrier epoch for ${MDC(BARRIER_ID, barrierId)}" +
log" is ${MDC(BARRIER_EPOCH, barrierEpoch)}.")
if (epoch != barrierEpoch) {
requester.sendFailure(new SparkException(s"The request to sync of $barrierId with " +
s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
"properly killed."))
requester.sendFailure(
new SparkException(
s"The request to sync of $barrierId with " +
s"barrier epoch $barrierEpoch has already finished. Maybe task $taskId is not " +
"properly killed."))
} else {
// If this is the first sync message received for a barrier() call, start timer to ensure
// we may timeout for the sync.
@@ -186,9 +197,10 @@ private[spark] class BarrierCoordinator(
requesters.foreach(_.reply(messages.clone()))
// Finished current barrier() call successfully, clean up ContextBarrierState and
// increase the barrier epoch.
logInfo(log"Barrier sync epoch ${MDC(BARRIER_EPOCH, barrierEpoch)}" +
log" from ${MDC(BARRIER_ID, barrierId)} received all updates from" +
log" tasks, finished successfully.")
logInfo(
log"Barrier sync epoch ${MDC(BARRIER_EPOCH, barrierEpoch)}" +
log" from ${MDC(BARRIER_ID, barrierId)} received all updates from" +
log" tasks, finished successfully.")
barrierEpoch += 1
requesters.clear()
requestMethods.clear()
@@ -219,7 +231,8 @@ private[spark] class BarrierCoordinator(
case request @ RequestToSync(numTasks, stageId, stageAttemptId, _, _, _, _, _) =>
// Get or init the ContextBarrierState correspond to the stage attempt.
val barrierId = ContextBarrierId(stageId, stageAttemptId)
states.computeIfAbsent(barrierId,
states.computeIfAbsent(
barrierId,
(key: ContextBarrierId) => new ContextBarrierState(key, numTasks))
val barrierState = states.get(barrierId)

@@ -234,27 +247,36 @@ private[spark] class BarrierCoordinator(
private[spark] sealed trait BarrierCoordinatorMessage extends Serializable

/**
* A global sync request message from BarrierTaskContext. Each request is
* identified by stageId + stageAttemptId + barrierEpoch.
* A global sync request message from BarrierTaskContext. Each request is identified by stageId +
* stageAttemptId + barrierEpoch.
*
* @param numTasks The number of global sync requests the BarrierCoordinator shall receive
* @param stageId ID of current stage
* @param stageAttemptId ID of current stage attempt
* @param taskAttemptId Unique ID of current task
* @param barrierEpoch ID of a runBarrier() call, a task may consist multiple runBarrier() calls
* @param partitionId ID of the current partition the task is assigned to
* @param message Message sent from the BarrierTaskContext
* @param requestMethod The BarrierTaskContext method that was called to trigger BarrierCoordinator
* @param numTasks
* The number of global sync requests the BarrierCoordinator shall receive
* @param stageId
* ID of current stage
* @param stageAttemptId
* ID of current stage attempt
* @param taskAttemptId
* Unique ID of current task
* @param barrierEpoch
* ID of a runBarrier() call, a task may consist multiple runBarrier() calls
* @param partitionId
* ID of the current partition the task is assigned to
* @param message
* Message sent from the BarrierTaskContext
* @param requestMethod
* The BarrierTaskContext method that was called to trigger BarrierCoordinator
*/
private[spark] case class RequestToSync(
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
message: String,
requestMethod: RequestMethod.Value) extends BarrierCoordinatorMessage
numTasks: Int,
stageId: Int,
stageAttemptId: Int,
taskAttemptId: Long,
barrierEpoch: Int,
partitionId: Int,
message: String,
requestMethod: RequestMethod.Value)
extends BarrierCoordinatorMessage

private[spark] object RequestMethod extends Enumeration {
val BARRIER, ALL_GATHER = Value
80 changes: 45 additions & 35 deletions core/src/main/scala/org/apache/spark/BarrierTaskContext.scala
@@ -19,9 +19,11 @@ package org.apache.spark

import java.util.{Properties, TimerTask}
import java.util.concurrent.{ScheduledThreadPoolExecutor, TimeUnit}

import scala.concurrent.duration._
import scala.jdk.CollectionConverters._
import scala.util.{Failure, Try, Success => ScalaSuccess}
import scala.util.{Failure, Success => ScalaSuccess, Try}

import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.internal.{LogEntry, Logging, MDC, MessageWithContext}
@@ -34,14 +36,15 @@ import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.util._

/**
* :: Experimental ::
* A [[TaskContext]] with extra contextual info and tooling for tasks in a barrier stage.
* Use [[BarrierTaskContext#get]] to obtain the barrier context for a running barrier task.
* :: Experimental :: A [[TaskContext]] with extra contextual info and tooling for tasks in a
* barrier stage. Use [[BarrierTaskContext#get]] to obtain the barrier context for a running
* barrier task.
*/
@Experimental
@Since("2.4.0")
class BarrierTaskContext private[spark] (
taskContext: TaskContext) extends TaskContext with Logging {
class BarrierTaskContext private[spark] (taskContext: TaskContext)
extends TaskContext
with Logging {

import BarrierTaskContext._

@@ -56,13 +59,15 @@ class BarrierTaskContext private[spark] (
private var barrierEpoch = 0

private def logProgressInfo(msg: MessageWithContext, startTime: Option[Long]): Unit = {
val waitMsg = startTime.fold(log"")(st => log", waited " +
log"for ${MDC(TOTAL_TIME, System.currentTimeMillis() - st)} ms,")
logInfo(log"Task ${MDC(TASK_ATTEMPT_ID, taskAttemptId())}" +
log" from Stage ${MDC(STAGE_ID, stageId())}" +
log"(Attempt ${MDC(STAGE_ATTEMPT, stageAttemptNumber())}) " +
msg + waitMsg +
log" current barrier epoch is ${MDC(BARRIER_EPOCH, barrierEpoch)}.")
val waitMsg = startTime.fold(log"")(st =>
log", waited " +
log"for ${MDC(TOTAL_TIME, System.currentTimeMillis() - st)} ms,")
logInfo(
log"Task ${MDC(TASK_ATTEMPT_ID, taskAttemptId())}" +
log" from Stage ${MDC(STAGE_ID, stageId())}" +
log"(Attempt ${MDC(STAGE_ATTEMPT, stageAttemptNumber())}) " +
msg + waitMsg +
log" current barrier epoch is ${MDC(BARRIER_EPOCH, barrierEpoch)}.")
}

private def runBarrier(message: String, requestMethod: RequestMethod.Value): Array[String] = {
@@ -74,17 +79,23 @@ class BarrierTaskContext private[spark] (
override def run(): Unit = {
logProgressInfo(
log"waiting under the global sync since ${MDC(TIME, startTime)}",
Some(startTime)
)
Some(startTime))
}
}
// Log the update of global sync every 1 minute.
timer.scheduleAtFixedRate(timerTask, 1, 1, TimeUnit.MINUTES)

try {
val abortableRpcFuture = barrierCoordinator.askAbortable[Array[String]](
message = RequestToSync(numPartitions(), stageId(), stageAttemptNumber(), taskAttemptId(),
barrierEpoch, partitionId(), message, requestMethod),
message = RequestToSync(
numPartitions(),
stageId(),
stageAttemptNumber(),
taskAttemptId(),
barrierEpoch,
partitionId(),
message,
requestMethod),
// Set a fixed timeout for RPC here, so users shall get a SparkException thrown by
// BarrierCoordinator on timeout, instead of RPCTimeoutException from the RPC framework.
timeout = new RpcTimeout(365.days, "barrierTimeout"))
@@ -124,17 +135,16 @@ class BarrierTaskContext private[spark] (
}

/**
* :: Experimental ::
* Sets a global barrier and waits until all tasks in this stage hit this barrier. Similar to
* MPI_Barrier function in MPI, the barrier() function call blocks until all tasks in the same
* stage have reached this routine.
* :: Experimental :: Sets a global barrier and waits until all tasks in this stage hit this
* barrier. Similar to MPI_Barrier function in MPI, the barrier() function call blocks until all
* tasks in the same stage have reached this routine.
*
* CAUTION! In a barrier stage, each task must have the same number of barrier() calls, in all
* possible code branches. Otherwise, you may get the job hanging or a SparkException after
* timeout. Some examples of '''misuses''' are listed below:
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* 1. Only call barrier() function on a subset of all the tasks in the same barrier stage, it
* shall lead to timeout of the function call.
* {{{
* rdd.barrier().mapPartitions { iter =>
* val context = BarrierTaskContext.get()
* if (context.partitionId() == 0) {
Expand All @@ -144,7 +154,7 @@ class BarrierTaskContext private[spark] (
* }
* iter
* }
* }}}
* }}}
*
* 2. Include barrier() function in a try-catch code block, this may lead to timeout of the
* second function call.
@@ -168,22 +178,22 @@ class BarrierTaskContext private[spark] (
def barrier(): Unit = runBarrier("", RequestMethod.BARRIER)

/**
* :: Experimental ::
* Blocks until all tasks in the same stage have reached this routine. Each task passes in
* a message and returns with a list of all the messages passed in by each of those tasks.
* :: Experimental :: Blocks until all tasks in the same stage have reached this routine. Each
* task passes in a message and returns with a list of all the messages passed in by each of
* those tasks.
*
* CAUTION! The allGather method requires the same precautions as the barrier method
*
* The message is type String rather than Array[Byte] because it is more convenient for
* the user at the cost of worse performance.
* The message is type String rather than Array[Byte] because it is more convenient for the user
* at the cost of worse performance.
*/
@Experimental
@Since("3.0.0")
def allGather(message: String): Array[String] = runBarrier(message, RequestMethod.ALL_GATHER)

/**
* :: Experimental ::
* Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered by partition ID.
* :: Experimental :: Returns [[BarrierTaskInfo]] for all tasks in this barrier stage, ordered
* by partition ID.
*/
@Experimental
@Since("2.4.0")
@@ -276,10 +286,10 @@ class BarrierTaskContext private[spark] (
@Experimental
@Since("2.4.0")
object BarrierTaskContext {

/**
* :: Experimental ::
* Returns the currently active BarrierTaskContext. This can be called inside of user functions to
* access contextual information about running barrier tasks.
* :: Experimental :: Returns the currently active BarrierTaskContext. This can be called inside
* of user functions to access contextual information about running barrier tasks.
*/
@Experimental
@Since("2.4.0")
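The reformatted barrier() scaladoc above keeps its inline usage example, but allGather() has none. The following is a hedged usage sketch, not part of this commit, assuming an existing SparkContext named sc and enough executor slots to run all four tasks concurrently:

```scala
import org.apache.spark.BarrierTaskContext

// Illustrative only: every task in the barrier stage calls allGather(), which
// blocks until all tasks arrive and then returns the messages contributed by
// the tasks. Assumes a SparkContext `sc` with at least 4 available slots.
val gathered = sc
  .parallelize(1 to 4, numSlices = 4)
  .barrier()
  .mapPartitions { iter =>
    val context = BarrierTaskContext.get()
    val messages = context.allGather(s"partition-${context.partitionId()}")
    Iterator.single(messages.mkString(","))
  }
  .collect()
```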