From 1fd880700afaedefacae043d7eaa12041c4f91ff Mon Sep 17 00:00:00 2001
From: Chao Sun <sunchao@apache.org>
Date: Wed, 7 Feb 2024 00:57:38 -0800
Subject: [PATCH] Pass max on-heap/off-heap storage memory to the BlockManager constructor
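
SparkEnv now computes the maximum on-heap storage memory (via
UnifiedMemoryManager.getMaxMemory, widened from private to
private[spark]) and the maximum off-heap memory (from
MEMORY_OFFHEAP_SIZE), and passes both explicitly to the BlockManager
constructor. This replaces the two lazy vals that read
memoryManager.maxOnHeapStorageMemory and
memoryManager.maxOffHeapStorageMemory, so BlockManager no longer has
to consult its MemoryManager for these reporting-only values.

A new BlockManager.apply factory in the companion object keeps the old
eleven-argument shape, deriving the two sizes from the given
MemoryManager; test suites only need to drop `new` at their call sites.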

---
 .../scala/org/apache/spark/SparkEnv.scala     |  7 ++++-
 .../spark/memory/UnifiedMemoryManager.scala   |  2 +-
 .../apache/spark/storage/BlockManager.scala   | 30 ++++++++++++++-----
 .../BlockManagerReplicationSuite.scala        |  4 +--
 .../spark/storage/BlockManagerSuite.scala     | 10 +++----
 .../streaming/ReceivedBlockHandlerSuite.scala |  2 +-
 6 files changed, 37 insertions(+), 18 deletions(-)
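
Reviewer note: a minimal sketch of the call shapes after this patch.
Identifiers such as `conf`, `memManager`, and `transfer` stand in for
values the surrounding code already builds; they are illustrative, not
new API.

    // In SparkEnv.create: both maximums are now fixed up front from the
    // configuration, before the BlockManager is constructed.
    val maxOnHeapMemory = UnifiedMemoryManager.getMaxMemory(conf)
    val maxOffHeapMemory = conf.get(MEMORY_OFFHEAP_SIZE)

    // In tests: the companion factory keeps the old signature and fills
    // in the sizes from the MemoryManager, so only `new` is dropped.
    val store = BlockManager(name, rpcEnv, master, serializerManager, conf,
      memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, None)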

diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index 6bdccae719053..5f0de27950783 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -416,6 +416,9 @@ object SparkEnv extends Logging {
       new NettyBlockTransferService(conf, securityManager, serializerManager, bindAddress,
         advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint)
 
+    val maxOnHeapMemory = UnifiedMemoryManager.getMaxMemory(conf)
+    val maxOffHeapMemory = conf.get(MEMORY_OFFHEAP_SIZE)
+
     // NB: blockManager is not valid until initialize() is called later.
     //     SPARK-45762 introduces a change where the ShuffleManager is initialized later
     //     in the SparkContext and Executor, to allow for custom ShuffleManagers defined
@@ -432,7 +435,9 @@ object SparkEnv extends Logging {
       _shuffleManager = null,
       blockTransferService,
       securityManager,
-      externalShuffleClient)
+      externalShuffleClient,
+      maxOnHeapMemory,
+      maxOffHeapMemory)
 
     val metricsSystem = if (isDriver) {
       // Don't start metrics system right now for Driver.
diff --git a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
index 73805c11e0371..1ed7d4e495e18 100644
--- a/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/memory/UnifiedMemoryManager.scala
@@ -210,7 +210,7 @@ object UnifiedMemoryManager {
   /**
    * Return the total amount of memory shared between execution and storage, in bytes.
    */
-  private def getMaxMemory(conf: SparkConf): Long = {
+  private[spark] def getMaxMemory(conf: SparkConf): Long = {
     val systemMemory = conf.get(TEST_MEMORY)
     val reservedMemory = conf.getLong(TEST_RESERVED_MEMORY.key,
       if (conf.contains(IS_TESTING)) 0 else RESERVED_SYSTEM_MEMORY_BYTES)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index b5d1c7ed69c8f..e5f698042bb60 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -190,7 +190,9 @@ private[spark] class BlockManager(
     private val _shuffleManager: ShuffleManager,
     val blockTransferService: BlockTransferService,
     securityManager: SecurityManager,
-    externalBlockStoreClient: Option[ExternalBlockStoreClient])
+    externalBlockStoreClient: Option[ExternalBlockStoreClient],
+    val maxOnHeapMemory: Long,
+    val maxOffHeapMemory: Long)
   extends BlockDataManager with BlockEvictionHandler with Logging {
 
   // We initialize the ShuffleManager later in SparkContext and Executor, to allow
@@ -236,13 +238,6 @@ private[spark] class BlockManager(
   }
   private[spark] val diskStore = new DiskStore(conf, diskBlockManager, securityManager)
 
-  // Note: depending on the memory manager, `maxMemory` may actually vary over time.
-  // However, since we use this only for reporting and logging, what we actually want here is
-  // the absolute maximum value that `maxMemory` can ever possibly reach. We may need
-  // to revisit whether reporting this value as the "max" is intuitive to the user.
-  private lazy val maxOnHeapMemory = memoryManager.maxOnHeapStorageMemory
-  private lazy val maxOffHeapMemory = memoryManager.maxOffHeapStorageMemory
-
   private[spark] val externalShuffleServicePort = StorageUtils.externalShuffleServicePort(conf)
 
   var blockManagerId: BlockManagerId = _
@@ -2157,6 +2152,25 @@ private[spark] class BlockManager(
 
 
 private[spark] object BlockManager {
+  // scalastyle:off argcount
+  def apply(
+      executorId: String,
+      rpcEnv: RpcEnv,
+      master: BlockManagerMaster,
+      serializerManager: SerializerManager,
+      conf: SparkConf,
+      memoryManager: MemoryManager,
+      mapOutputTracker: MapOutputTracker,
+      shuffleManager: ShuffleManager,
+      blockTransferService: BlockTransferService,
+      securityManager: SecurityManager,
+      externalBlockStoreClient: Option[ExternalBlockStoreClient]): BlockManager =
+    new BlockManager(executorId, rpcEnv, master, serializerManager, conf, memoryManager,
+      mapOutputTracker, shuffleManager, blockTransferService, securityManager,
+      externalBlockStoreClient, memoryManager.maxOnHeapStorageMemory,
+      memoryManager.maxOffHeapStorageMemory)
+  // scalastyle:on argcount
+
   private val ID_GENERATOR = new IdGenerator
 
   def blockIdsToLocations(
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
index 1fbc900727c4c..30c5525cb9fd5 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala
@@ -80,7 +80,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
     val transfer = new NettyBlockTransferService(
       conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
     val memManager = memoryManager.getOrElse(UnifiedMemoryManager(conf, numCores = 1))
-    val store = new BlockManager(name, rpcEnv, master, serializerManager, conf,
+    val store = BlockManager(name, rpcEnv, master, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, None)
     memManager.setMemoryStore(store.memoryStore)
     store.initialize("app-id")
@@ -242,7 +242,7 @@ trait BlockManagerReplicationBehavior extends SparkFunSuite
     conf.set(TEST_MEMORY, 10000L)
     val memManager = UnifiedMemoryManager(conf, numCores = 1)
     val serializerManager = new SerializerManager(serializer, conf)
-    val failableStore = new BlockManager("failable-store", rpcEnv, master, serializerManager, conf,
+    val failableStore = BlockManager("failable-store", rpcEnv, master, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, None)
     memManager.setMemoryStore(failableStore.memoryStore)
     failableStore.initialize("app-id")
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 17dff20dd993b..f69b1f64b05d4 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -143,7 +143,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     } else {
       None
     }
-    val blockManager = new BlockManager(name, rpcEnv, master, serializerManager, bmConf,
+    val blockManager = BlockManager(name, rpcEnv, master, serializerManager, bmConf,
       memManager, mapOutputTracker, shuffleManager, transfer, bmSecurityMgr, externalShuffleClient)
     memManager.setMemoryStore(blockManager.memoryStore)
     allStores += blockManager
@@ -1344,7 +1344,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     val transfer = new NettyBlockTransferService(
       conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
     val memoryManager = UnifiedMemoryManager(conf, numCores = 1)
-    val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
+    val store = BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       serializerManager, conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, None)
     allStores += store
@@ -1393,7 +1393,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
       val transfer = new NettyBlockTransferService(
         conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
       val memoryManager = UnifiedMemoryManager(conf, numCores = 1)
-      val blockManager = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
+      val blockManager = BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
         serializerManager, conf, memoryManager, mapOutputTracker,
         shuffleManager, transfer, securityMgr, None)
       try {
@@ -2248,7 +2248,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     val transfer = new NettyBlockTransferService(
       conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
     val memoryManager = UnifiedMemoryManager(conf, numCores = 1)
-    val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
+    val store = BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       serializerManager, conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, None)
     allStores += store
@@ -2272,7 +2272,7 @@ class BlockManagerSuite extends SparkFunSuite with Matchers with PrivateMethodTe
     val transfer = new NettyBlockTransferService(
       conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
     val memoryManager = UnifiedMemoryManager(conf, numCores = 1)
-    val store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
+    val store = BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master,
       serializerManager, conf, memoryManager, mapOutputTracker,
       shuffleManager, transfer, securityMgr, None)
     allStores += store
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
index 1bf74e6e9a36a..4ddb184360115 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala
@@ -290,7 +290,7 @@ abstract class BaseReceivedBlockHandlerSuite(enableEncryption: Boolean)
     val memManager = new UnifiedMemoryManager(conf, maxMem, maxMem / 2, 1)
     val transfer = new NettyBlockTransferService(
       conf, securityMgr, serializerManager, "localhost", "localhost", 0, 1)
-    val blockManager = new BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf,
+    val blockManager = BlockManager(name, rpcEnv, blockManagerMaster, serializerManager, conf,
       memManager, mapOutputTracker, shuffleManager, transfer, securityMgr, None)
     memManager.setMemoryStore(blockManager.memoryStore)
     blockManager.initialize("app-id")