From 11c5a5e4045bdeee133cd8e6663e35a4458b8898 Mon Sep 17 00:00:00 2001
From: Sundeep Narravula
Date: Tue, 25 Feb 2014 16:44:17 -0800
Subject: [PATCH] This patch addresses two cache issues

1. The 'cache tablename' directive overwrites existing memory tables and can
   leak memory when the cache command is run on an already cached table.
   Caching a new partition likewise leaks memory if the partition already
   exists.
2. Using the 'cache' directive on a partitioned table corrupts the table:
   only the last cached partition remains accessible, and the rest are
   leaked.
---
 .../scala/shark/execution/SparkLoadTask.scala | 19 +++++++--
 .../memstore2/MemoryMetadataManager.scala     | 39 ++++++++++++-------
 2 files changed, 41 insertions(+), 17 deletions(-)

diff --git a/src/main/scala/shark/execution/SparkLoadTask.scala b/src/main/scala/shark/execution/SparkLoadTask.scala
index 0b47b8de..7dd481fb 100644
--- a/src/main/scala/shark/execution/SparkLoadTask.scala
+++ b/src/main/scala/shark/execution/SparkLoadTask.scala
@@ -304,8 +304,15 @@ class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHe
     if (work.cacheMode != CacheType.TACHYON) {
       val memoryTable = getOrCreateMemoryTable(hiveTable)
       work.commandType match {
-        case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) =>
-          memoryTable.put(tablePartitionRDD, tableStats.toMap)
+        case (SparkLoadWork.CommandTypes.OVERWRITE | SparkLoadWork.CommandTypes.NEW_ENTRY) => {
+          val prevRDDandStatsOpt = memoryTable.put(tablePartitionRDD, tableStats.toMap)
+          // If an RDD was already cached for this table, unpersist it to
+          // prevent a memory leak when the table is overwritten.
+          if (prevRDDandStatsOpt.isDefined) {
+            val (prevRdd, _) = prevRDDandStatsOpt.get
+            RDDUtils.unpersistRDD(prevRdd)
+          }
+        }
         case SparkLoadWork.CommandTypes.INSERT => {
           memoryTable.update(tablePartitionRDD, tableStats)
         }
@@ -398,7 +405,13 @@ class SparkLoadTask extends HiveTask[SparkLoadWork] with Serializable with LogHe
             (work.commandType == SparkLoadWork.CommandTypes.INSERT)) {
           partitionedTable.updatePartition(partitionKey, tablePartitionRDD, tableStats)
         } else {
-          partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap)
+          val prevRDDandStatsOpt =
+            partitionedTable.putPartition(partitionKey, tablePartitionRDD, tableStats.toMap)
+          if (prevRDDandStatsOpt.isDefined) {
+            // Unpersist the overwritten partition's RDD to prevent a memory leak.
+            val (prevRdd, _) = prevRDDandStatsOpt.get
+            RDDUtils.unpersistRDD(prevRdd)
+          }
         }
       }
     }
diff --git a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
index b0e1c8e5..35a9ebba 100755
--- a/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
+++ b/src/main/scala/shark/memstore2/MemoryMetadataManager.scala
@@ -54,10 +54,16 @@ class MemoryMetadataManager extends LogHelper {
       databaseName: String,
       tableName: String,
       cacheMode: CacheType.CacheType): MemoryTable = {
-    val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
-    val newTable = new MemoryTable(databaseName, tableName, cacheMode)
-    _tables.put(tableKey, newTable)
-    newTable
+    val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
+    // Reuse an existing table under this key to avoid leaking its cached RDDs.
+    if (containsTable(databaseName, tableName)) {
+      logInfo("Attempt to create new table when one already exists - " + tableKey)
+      _tables.get(tableKey).get.asInstanceOf[MemoryTable]
+    } else {
+      val newTable = new MemoryTable(databaseName, tableName, cacheMode)
+      _tables.put(tableKey, newTable)
+      newTable
+    }
   }
 
   def createPartitionedMemoryTable(
@@ -67,16 +73,21 @@ class MemoryMetadataManager extends LogHelper {
       tblProps: JavaMap[String, String]
     ): PartitionedMemoryTable = {
     val tableKey = MemoryMetadataManager.makeTableKey(databaseName, tableName)
-    val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode)
-    // Determine the cache policy to use and read any user-specified cache settings.
-    val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname,
-      SharkTblProperties.CACHE_POLICY.defaultVal)
-    val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname,
-      SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt
-    newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize)
-
-    _tables.put(tableKey, newTable)
-    newTable
+    // Reuse an existing table under this key to avoid leaking its cached RDDs.
+    if (containsTable(databaseName, tableName)) {
+      logInfo("Attempt to create new table when one already exists - " + tableKey)
+      _tables.get(tableKey).get.asInstanceOf[PartitionedMemoryTable]
+    } else {
+      val newTable = new PartitionedMemoryTable(databaseName, tableName, cacheMode)
+      // Determine the cache policy to use and read any user-specified cache settings.
+      val cachePolicyStr = tblProps.getOrElse(SharkTblProperties.CACHE_POLICY.varname,
+        SharkTblProperties.CACHE_POLICY.defaultVal)
+      val maxCacheSize = tblProps.getOrElse(SharkTblProperties.MAX_PARTITION_CACHE_SIZE.varname,
+        SharkTblProperties.MAX_PARTITION_CACHE_SIZE.defaultVal).toInt
+      newTable.setPartitionCachePolicy(cachePolicyStr, maxCacheSize)
+      _tables.put(tableKey, newTable)
+      newTable
+    }
   }
 
   def getTable(databaseName: String, tableName: String): Option[Table] = {
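
Note: the fix amounts to two patterns, (1) reuse the table already registered under a key instead of replacing it, and (2) unpersist whatever RDD a put()/putPartition() displaces. The self-contained Scala sketch below illustrates both under simplified assumptions; SimpleRDD, SimpleMemoryTable, and SimpleMetadataManager are hypothetical stand-ins, not Shark's actual classes or APIs.

// Hypothetical, simplified stand-ins for Shark's RDD / MemoryTable /
// MemoryMetadataManager; only meant to illustrate the two fixes above.
import scala.collection.concurrent.TrieMap

class SimpleRDD(val name: String) {
  def unpersist(): Unit = println(s"unpersisting $name")
}

class SimpleMemoryTable(val key: String) {
  private var entry: Option[(SimpleRDD, Map[Int, Long])] = None

  // put() hands back the displaced entry (if any) so the caller can unpersist it.
  def put(rdd: SimpleRDD, stats: Map[Int, Long]): Option[(SimpleRDD, Map[Int, Long])] = {
    val prev = entry
    entry = Some((rdd, stats))
    prev
  }
}

class SimpleMetadataManager {
  private val tables = new TrieMap[String, SimpleMemoryTable]()

  // Fix 1: reuse the table already registered under this key instead of
  // silently replacing it, which would leak the RDDs it references.
  def getOrCreateTable(key: String): SimpleMemoryTable =
    tables.getOrElseUpdate(key, new SimpleMemoryTable(key))
}

object CacheOverwriteSketch {
  def main(args: Array[String]): Unit = {
    val manager = new SimpleMetadataManager
    val table = manager.getOrCreateTable("default.users")

    table.put(new SimpleRDD("users_v1"), Map.empty)

    // Fix 2: when an OVERWRITE / NEW_ENTRY load replaces a cached RDD,
    // unpersist the old one so its blocks are actually released.
    table.put(new SimpleRDD("users_v2"), Map.empty).foreach { case (prevRdd, _) =>
      prevRdd.unpersist()
    }
  }
}

Having put() return the displaced entry, as in the sketch, keeps the unpersist decision with the caller, which is how SparkLoadTask handles both the table and partition paths in the diff above.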