Skip to content

Commit 8d500b8

Browse files
cloud-fanterana
authored and committed
[SPARK-56111][CORE] Add SparkContext.isDriver() and use it across the codebase
### What changes were proposed in this pull request?

Add a `SparkContext.isDriver(executorId)` utility method and replace all manual `executorId == SparkContext.DRIVER_IDENTIFIER` / `executorId != SparkContext.DRIVER_IDENTIFIER` checks across the codebase with it.

### Why are the changes needed?

The driver-detection pattern `executorId == SparkContext.DRIVER_IDENTIFIER` is duplicated in ~20 places across core, SQL, streaming, profiler, and Kubernetes modules. Centralizing it in a single method improves readability and reduces duplication.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Existing tests — this is a pure refactoring with no behavior change.

### Was this patch authored or co-authored using generative AI tooling?

Generated-by: Claude Opus 4.6

Closes apache#54922 from cloud-fan/SPARK-53915-followup.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Dongjoon Hyun <dongjoon@apache.org>
1 parent 7a47aba commit 8d500b8

File tree

20 files changed

+33
-32
lines changed

20 files changed

+33
-32
lines changed

connector/profiler/src/main/scala/org/apache/spark/profiler/SparkAsyncProfiler.scala

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@ import one.profiler.{AsyncProfiler, AsyncProfilerLoader}
2323
import org.apache.hadoop.fs.{FileSystem, FSDataOutputStream, Path}
2424
import org.apache.hadoop.fs.permission.FsPermission
2525

26-
import org.apache.spark.SparkConf
27-
import org.apache.spark.SparkContext.DRIVER_IDENTIFIER
26+
import org.apache.spark.{SparkConf, SparkContext}
2827
import org.apache.spark.deploy.SparkHadoopUtil
2928
import org.apache.spark.internal.Logging
3029
import org.apache.spark.internal.LogKeys.PATH
@@ -45,7 +44,7 @@ private[spark] class SparkAsyncProfiler(conf: SparkConf, executorId: String) ext
4544
private def getAppId: Option[String] = conf.getOption("spark.app.id")
4645
private def getAttemptId: Option[String] = conf.getOption("spark.app.attempt.id")
4746

48-
private val profileFile = if (executorId == DRIVER_IDENTIFIER) {
47+
private val profileFile = if (SparkContext.isDriver(executorId)) {
4948
s"profile-$executorId.jfr"
5049
} else {
5150
s"profile-exec-$executorId.jfr"

core/src/main/scala/org/apache/spark/SparkContext.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ class SparkContext(config: SparkConf) extends Logging {
754754
*/
755755
private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
756756
try {
757-
if (executorId == SparkContext.DRIVER_IDENTIFIER) {
757+
if (SparkContext.isDriver(executorId)) {
758758
Some(Utils.getThreadDump())
759759
} else {
760760
env.blockManager.master.getExecutorEndpointRef(executorId) match {
@@ -786,7 +786,7 @@ class SparkContext(config: SparkConf) extends Logging {
786786
*/
787787
private[spark] def getExecutorHeapHistogram(executorId: String): Option[Array[String]] = {
788788
try {
789-
if (executorId == SparkContext.DRIVER_IDENTIFIER) {
789+
if (SparkContext.isDriver(executorId)) {
790790
Some(Utils.getHeapHistogram())
791791
} else {
792792
env.blockManager.master.getExecutorEndpointRef(executorId) match {
@@ -3163,6 +3163,11 @@ object SparkContext extends Logging {
31633163
/** Separator of tags in SPARK_JOB_TAGS property */
31643164
private[spark] val SPARK_JOB_TAGS_SEP = ","
31653165

3166+
/** Returns true if the given executor ID identifies the driver. */
3167+
private[spark] def isDriver(executorId: String): Boolean = {
3168+
DRIVER_IDENTIFIER == executorId
3169+
}
3170+
31663171
// Same rules apply to Spark Connect execution tags, see ExecuteHolder.throwIfInvalidTag
31673172
private[spark] def throwIfInvalidTag(tag: String) = {
31683173
if (tag == null) {

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ class SparkEnv (
248248
Preconditions.checkState(null == _shuffleManager,
249249
"Shuffle manager already initialized to %s", _shuffleManager)
250250
try {
251-
_shuffleManager = ShuffleManager.create(conf, executorId == SparkContext.DRIVER_IDENTIFIER)
251+
_shuffleManager = ShuffleManager.create(conf, SparkContext.isDriver(executorId))
252252
} finally {
253253
// Signal that the ShuffleManager has been initialized
254254
shuffleManagerInitLatch.countDown()
@@ -356,7 +356,7 @@ object SparkEnv extends Logging {
356356
listenerBus: LiveListenerBus = null,
357357
mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
358358

359-
val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
359+
val isDriver = SparkContext.isDriver(executorId)
360360

361361
// Listener bus is only used on the driver
362362
if (isDriver) {

core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ private class SharedMessageLoop(
116116
.getOrElse(math.max(2, availableCores))
117117

118118
conf.get(EXECUTOR_ID).map { id =>
119-
val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
119+
val role = if (SparkContext.isDriver(id)) "driver" else "executor"
120120
conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads)
121121
}.getOrElse(modNumThreads)
122122
}

core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ private[netty] class NettyRpcEnv(
5050
securityManager: SecurityManager,
5151
numUsableCores: Int) extends RpcEnv(conf) with Logging {
5252
val role = conf.get(EXECUTOR_ID).map { id =>
53-
if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor"
53+
if (SparkContext.isDriver(id)) "driver" else "executor"
5454
}
5555

5656
private[netty] val transportConf = SparkTransportConf.fromSparkConf(

core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ private[spark] class EventLoggingListener(
249249

250250
override def onExecutorMetricsUpdate(event: SparkListenerExecutorMetricsUpdate): Unit = {
251251
if (shouldLogStageExecutorMetrics) {
252-
if (event.execId == SparkContext.DRIVER_IDENTIFIER) {
252+
if (SparkContext.isDriver(event.execId)) {
253253
logEvent(event)
254254
}
255255
event.executorUpdates.foreach { case (stageKey1, newPeaks) =>

core/src/main/scala/org/apache/spark/shuffle/ShuffleBlockPusher.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -502,7 +502,7 @@ private[spark] object ShuffleBlockPusher {
502502
private val BLOCK_PUSHER_POOL: ExecutorService = {
503503
val conf = SparkEnv.get.conf
504504
if (Utils.isPushBasedShuffleEnabled(conf,
505-
isDriver = SparkContext.DRIVER_IDENTIFIER == SparkEnv.get.executorId)) {
505+
isDriver = SparkContext.isDriver(SparkEnv.get.executorId))) {
506506
val numThreads = conf.get(SHUFFLE_NUM_PUSH_THREADS)
507507
.getOrElse(conf.getInt(SparkLauncher.EXECUTOR_CORES, 1))
508508
ThreadUtils.newDaemonFixedThreadPool(numThreads, "shuffle-block-push-thread")

core/src/main/scala/org/apache/spark/status/AppStatusListener.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -362,11 +362,11 @@ private[spark] class AppStatusListener(
362362
// Implicitly exclude every available executor for the stage associated with this node
363363
Option(liveStages.get((stageId, stageAttemptId))).foreach { stage =>
364364
val executorIds = liveExecutors.values.filter(exec => exec.host == hostId
365-
&& exec.executorId != SparkContext.DRIVER_IDENTIFIER).map(_.executorId).toSeq
365+
&& !SparkContext.isDriver(exec.executorId)).map(_.executorId).toSeq
366366
setStageExcludedStatus(stage, now, executorIds: _*)
367367
}
368368
liveExecutors.values.filter(exec => exec.hostname == hostId
369-
&& exec.executorId != SparkContext.DRIVER_IDENTIFIER).foreach { exec =>
369+
&& !SparkContext.isDriver(exec.executorId)).foreach { exec =>
370370
addExcludedStageTo(exec, stageId, now)
371371
}
372372
}
@@ -413,7 +413,7 @@ private[spark] class AppStatusListener(
413413

414414
// Implicitly (un)exclude every executor associated with the node.
415415
liveExecutors.values.foreach { exec =>
416-
if (exec.hostname == host && exec.executorId != SparkContext.DRIVER_IDENTIFIER) {
416+
if (exec.hostname == host && !SparkContext.isDriver(exec.executorId)) {
417417
updateExecExclusionStatus(exec, excluded, now)
418418
}
419419
}

core/src/main/scala/org/apache/spark/status/AppStatusStore.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ private[spark] class AppStatusStore(
103103
}
104104

105105
private def replaceExec(origin: v1.ExecutorSummary): v1.ExecutorSummary = {
106-
if (origin.id == SparkContext.DRIVER_IDENTIFIER) {
106+
if (SparkContext.isDriver(origin.id)) {
107107
replaceDriverGcTime(origin, extractGcTime(origin), extractAppTime)
108108
} else {
109109
origin

core/src/main/scala/org/apache/spark/status/api/v1/OneApplicationResource.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ private[v1] class AbstractApplicationResource extends BaseAppResource {
182182
}
183183

184184
private def checkExecutorId(execId: String): Unit = {
185-
if (execId != SparkContext.DRIVER_IDENTIFIER && !execId.forall(Character.isDigit)) {
185+
if (!SparkContext.isDriver(execId) && !execId.forall(Character.isDigit)) {
186186
throw new BadParameterException(
187187
s"Invalid executorId: neither '${SparkContext.DRIVER_IDENTIFIER}' nor number.")
188188
}

0 commit comments

Comments
 (0)