[SPARK-47818][CONNECT] Introduce plan cache in SparkConnectPlanner to improve performance of Analyze requests

### What changes were proposed in this pull request?

When a DataFrame is built step by step, each step generates a new DataFrame with an empty schema, which is computed lazily on access. If user code frequently accesses the schema of these new DataFrames, for example through `df.columns`, it sends a large number of Analyze requests to the server. Each request reanalyzes the entire plan, which leads to poor performance, especially when constructing highly complex plans.

Introducing a plan cache in SparkConnectPlanner reduces the overhead of this repeated analysis. If the resolved logical plan of a subtree is cached, it does not have to be recomputed when a larger plan built on top of it is analyzed, which saves significant computation.
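
The lookup-or-transform idea can be sketched in a few lines of plain Python. This is only an illustration, not the patch itself: `PlanCacheSketch`, `use_plan_cache` and `analyze` are hypothetical names, plans are represented as strings, and the least-recently-used eviction is an assumption of the sketch (the eviction policy of the real cache may differ).

```python
from collections import OrderedDict
from typing import List, Optional


class PlanCacheSketch:
    """Toy stand-in for a per-session plan cache. All names are hypothetical."""

    def __init__(self, max_size: int, enabled: bool = True) -> None:
        self.enabled = enabled
        self._max_size = max_size
        # A non-positive size means there is no cache at all.
        self._cache: Optional[OrderedDict] = OrderedDict() if max_size > 0 else None

    def use_plan_cache(self, rel, has_plan_id, cache_plan, transform):
        # Reads and writes both require an existing, enabled cache and a plan ID.
        usable = self._cache is not None and self.enabled and has_plan_id
        if usable and rel in self._cache:
            self._cache.move_to_end(rel)  # mark the entry as recently used
            return self._cache[rel]
        plan = transform(rel)  # cache miss: do the expensive work
        if usable and cache_plan:
            self._cache[rel] = plan
            if len(self._cache) > self._max_size:
                self._cache.popitem(last=False)  # evict the oldest entry
        return plan


analyzed: List[str] = []


def analyze(rel: str) -> str:
    """Pretend analysis step that records every relation it had to process."""
    analyzed.append(rel)
    return f"resolved({rel})"


cache = PlanCacheSketch(max_size=2)
first = cache.use_plan_cache("range(10)", True, True, analyze)
second = cache.use_plan_cache("range(10)", True, True, analyze)
print(len(analyzed))  # prints 1: the second call is served from the cache
```

A relation is only stored when the caller asks for it (`cache_plan`), so one-off plans do not push reusable ones out of the bounded cache.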

A minimal example of the problem:

```
import pyspark.sql.functions as F
df = spark.range(10)
for i in range(200):
  if str(i) not in df.columns: # <-- The df.columns call causes a new Analyze request in every iteration
    df = df.withColumn(str(i), F.col("id") + i)
df.show()
```

With this patch, the performance of the above code improved from ~110s to ~5s.
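
A rough cost model suggests why the loop above benefits so much. The sketch below is a toy calculation under stated assumptions, not a measurement: it assumes the plan after step `i` is a chain of `i + 1` nodes, that analyzing one node costs one unit, and that with the cache the previous step's plan is always found, so only the newest node has to be processed. The function name `analysis_cost` is made up for illustration.

```python
def analysis_cost(steps: int, use_cache: bool) -> int:
    """Count node visits when the schema is requested once per step."""
    total = 0
    for i in range(steps):
        depth = i + 1  # nodes in the plan at this step
        if use_cache and i > 0:
            total += 1  # the child plan is cached; only the new node is analyzed
        else:
            total += depth  # the whole chain is reanalyzed
    return total


print(analysis_cost(200, use_cache=False))  # prints 20100
print(analysis_cost(200, use_cache=True))   # prints 200
```

In this model the uncached work grows quadratically with the number of steps, while the cached work grows linearly. The real speedup depends on factors the model ignores, such as request overhead and the cost of each analysis rule.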

### Why are the changes needed?

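In cases like the one above, the plan cache avoids reanalyzing the entire plan on every schema access, which reduced the running time of the example from ~110s to ~5s.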

### Does this PR introduce _any_ user-facing change?

Yes, a static conf `spark.connect.session.planCache.maxSize` and a dynamic conf `spark.connect.session.planCache.enabled` are added.

* `spark.connect.session.planCache.maxSize`: Sets the maximum number of cached resolved logical plans in a Spark Connect session. If set to a value less than or equal to zero, the plan cache is disabled.
* `spark.connect.session.planCache.enabled`: When true, the cache of resolved logical plans is enabled if `spark.connect.session.planCache.maxSize` is greater than zero. When false, the cache is disabled even if `spark.connect.session.planCache.maxSize` is greater than zero. The caching is best-effort and not guaranteed.
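
How the two confs combine can be written as a small truth function. This is a sketch of the rule described above, not code from the patch; `plan_cache_active` is a hypothetical helper, and the defaults of 5 and true are the ones declared for the two confs in this commit.

```python
def plan_cache_active(max_size: int = 5, enabled: bool = True) -> bool:
    """Return True only when both confs allow caching."""
    return max_size > 0 and enabled


# Print the outcome for a few combinations of the two settings.
for max_size, enabled in [(5, True), (5, False), (0, True), (-1, True)]:
    print(max_size, enabled, plan_cache_active(max_size, enabled))
```

Because `maxSize` is a static conf, it presumably has to be set when the Spark Connect server starts, whereas `enabled` is a dynamic conf and should be changeable per session at runtime, for example with `spark.conf.set("spark.connect.session.planCache.enabled", "false")` in PySpark.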

### How was this patch tested?

Some new tests are added in SparkConnectSessionHolderSuite.scala.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes apache#46012 from xi-db/SPARK-47818-plan-cache.

Lead-authored-by: Xi Lyu <[email protected]>
Co-authored-by: Xi Lyu <[email protected]>
Signed-off-by: Hyukjin Kwon <[email protected]>
2 people authored and HyukjinKwon committed Apr 16, 2024
1 parent 6762d1f commit a1fc6d5
Showing 5 changed files with 345 additions and 104 deletions.
@@ -273,4 +273,22 @@ object Connect {
.version("4.0.0")
.timeConf(TimeUnit.MILLISECONDS)
.createWithDefaultString("2s")

val CONNECT_SESSION_PLAN_CACHE_SIZE =
buildStaticConf("spark.connect.session.planCache.maxSize")
.doc("Sets the maximum number of cached resolved logical plans in Spark Connect Session." +
" If set to a value less or equal than zero will disable the plan cache.")
.version("4.0.0")
.intConf
.createWithDefault(5)

val CONNECT_SESSION_PLAN_CACHE_ENABLED =
buildConf("spark.connect.session.planCache.enabled")
.doc("When true, the cache of resolved logical plans is enabled if" +
s" '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is greater than zero." +
s" When false, the cache is disabled even if '${CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" +
" greater than zero. The caching is best-effort and not guaranteed.")
.version("4.0.0")
.booleanConf
.createWithDefault(true)
}
@@ -115,95 +115,118 @@ class SparkConnectPlanner(
private lazy val pythonExec =
sys.env.getOrElse("PYSPARK_PYTHON", sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", "python3"))

// The root of the query plan is a relation and we apply the transformations to it.
def transformRelation(rel: proto.Relation): LogicalPlan = {
val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin)
case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
case proto.Relation.RelTypeCase.WITH_RELATIONS
if isValidSQLWithRefs(rel.getWithRelations) =>
transformSqlWithRefs(rel.getWithRelations)
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
transformLocalRelation(rel.getLocalRelation)
case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace)
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe)
case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
case proto.Relation.RelTypeCase.APPROX_QUANTILE =>
transformStatApproxQuantile(rel.getApproxQuantile)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems)
case proto.Relation.RelTypeCase.SAMPLE_BY =>
transformStatSampleBy(rel.getSampleBy)
case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
case proto.Relation.RelTypeCase.TO_DF =>
transformToDF(rel.getToDf)
case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED =>
transformWithColumnsRenamed(rel.getWithColumnsRenamed)
case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns)
case proto.Relation.RelTypeCase.WITH_WATERMARK =>
transformWithWatermark(rel.getWithWatermark)
case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
transformCachedLocalRelation(rel.getCachedLocalRelation)
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
transformMapPartitions(rel.getMapPartitions)
case proto.Relation.RelTypeCase.GROUP_MAP =>
transformGroupMap(rel.getGroupMap)
case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION =>
transformCommonInlineUserDefinedTableFunction(rel.getCommonInlineUserDefinedTableFunction)
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")

// Catalog API (internal-only)
case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog)

// Handle plugins for Spark Connect Relation types.
case proto.Relation.RelTypeCase.EXTENSION =>
transformRelationPlugin(rel.getExtension)
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
}

if (rel.hasCommon && rel.getCommon.hasPlanId) {
plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
}
plan
/**
* The root of the query plan is a relation and we apply the transformations to it. The resolved
* logical plan will not get cached. If the result needs to be cached, use
* `transformRelation(rel, cachePlan = true)` instead.
* @param rel
* The relation to transform.
* @return
* The resolved logical plan.
*/
def transformRelation(rel: proto.Relation): LogicalPlan =
transformRelation(rel, cachePlan = false)

/**
* The root of the query plan is a relation and we apply the transformations to it.
* @param rel
* The relation to transform.
* @param cachePlan
* Set to true for a performance optimization, if the plan is likely to be reused, e.g. built
* upon by further dataset transformation. The default is false.
* @return
* The resolved logical plan.
*/
def transformRelation(rel: proto.Relation, cachePlan: Boolean): LogicalPlan = {
sessionHolder.usePlanCache(rel, cachePlan) { rel =>
val plan = rel.getRelTypeCase match {
// DataFrame API
case proto.Relation.RelTypeCase.SHOW_STRING => transformShowString(rel.getShowString)
case proto.Relation.RelTypeCase.HTML_STRING => transformHtmlString(rel.getHtmlString)
case proto.Relation.RelTypeCase.READ => transformReadRel(rel.getRead)
case proto.Relation.RelTypeCase.PROJECT => transformProject(rel.getProject)
case proto.Relation.RelTypeCase.FILTER => transformFilter(rel.getFilter)
case proto.Relation.RelTypeCase.LIMIT => transformLimit(rel.getLimit)
case proto.Relation.RelTypeCase.OFFSET => transformOffset(rel.getOffset)
case proto.Relation.RelTypeCase.TAIL => transformTail(rel.getTail)
case proto.Relation.RelTypeCase.JOIN => transformJoinOrJoinWith(rel.getJoin)
case proto.Relation.RelTypeCase.AS_OF_JOIN => transformAsOfJoin(rel.getAsOfJoin)
case proto.Relation.RelTypeCase.DEDUPLICATE => transformDeduplicate(rel.getDeduplicate)
case proto.Relation.RelTypeCase.SET_OP => transformSetOperation(rel.getSetOp)
case proto.Relation.RelTypeCase.SORT => transformSort(rel.getSort)
case proto.Relation.RelTypeCase.DROP => transformDrop(rel.getDrop)
case proto.Relation.RelTypeCase.AGGREGATE => transformAggregate(rel.getAggregate)
case proto.Relation.RelTypeCase.SQL => transformSql(rel.getSql)
case proto.Relation.RelTypeCase.WITH_RELATIONS
if isValidSQLWithRefs(rel.getWithRelations) =>
transformSqlWithRefs(rel.getWithRelations)
case proto.Relation.RelTypeCase.LOCAL_RELATION =>
transformLocalRelation(rel.getLocalRelation)
case proto.Relation.RelTypeCase.SAMPLE => transformSample(rel.getSample)
case proto.Relation.RelTypeCase.RANGE => transformRange(rel.getRange)
case proto.Relation.RelTypeCase.SUBQUERY_ALIAS =>
transformSubqueryAlias(rel.getSubqueryAlias)
case proto.Relation.RelTypeCase.REPARTITION => transformRepartition(rel.getRepartition)
case proto.Relation.RelTypeCase.FILL_NA => transformNAFill(rel.getFillNa)
case proto.Relation.RelTypeCase.DROP_NA => transformNADrop(rel.getDropNa)
case proto.Relation.RelTypeCase.REPLACE => transformReplace(rel.getReplace)
case proto.Relation.RelTypeCase.SUMMARY => transformStatSummary(rel.getSummary)
case proto.Relation.RelTypeCase.DESCRIBE => transformStatDescribe(rel.getDescribe)
case proto.Relation.RelTypeCase.COV => transformStatCov(rel.getCov)
case proto.Relation.RelTypeCase.CORR => transformStatCorr(rel.getCorr)
case proto.Relation.RelTypeCase.APPROX_QUANTILE =>
transformStatApproxQuantile(rel.getApproxQuantile)
case proto.Relation.RelTypeCase.CROSSTAB =>
transformStatCrosstab(rel.getCrosstab)
case proto.Relation.RelTypeCase.FREQ_ITEMS => transformStatFreqItems(rel.getFreqItems)
case proto.Relation.RelTypeCase.SAMPLE_BY =>
transformStatSampleBy(rel.getSampleBy)
case proto.Relation.RelTypeCase.TO_SCHEMA => transformToSchema(rel.getToSchema)
case proto.Relation.RelTypeCase.TO_DF =>
transformToDF(rel.getToDf)
case proto.Relation.RelTypeCase.WITH_COLUMNS_RENAMED =>
transformWithColumnsRenamed(rel.getWithColumnsRenamed)
case proto.Relation.RelTypeCase.WITH_COLUMNS => transformWithColumns(rel.getWithColumns)
case proto.Relation.RelTypeCase.WITH_WATERMARK =>
transformWithWatermark(rel.getWithWatermark)
case proto.Relation.RelTypeCase.CACHED_LOCAL_RELATION =>
transformCachedLocalRelation(rel.getCachedLocalRelation)
case proto.Relation.RelTypeCase.HINT => transformHint(rel.getHint)
case proto.Relation.RelTypeCase.UNPIVOT => transformUnpivot(rel.getUnpivot)
case proto.Relation.RelTypeCase.REPARTITION_BY_EXPRESSION =>
transformRepartitionByExpression(rel.getRepartitionByExpression)
case proto.Relation.RelTypeCase.MAP_PARTITIONS =>
transformMapPartitions(rel.getMapPartitions)
case proto.Relation.RelTypeCase.GROUP_MAP =>
transformGroupMap(rel.getGroupMap)
case proto.Relation.RelTypeCase.CO_GROUP_MAP =>
transformCoGroupMap(rel.getCoGroupMap)
case proto.Relation.RelTypeCase.APPLY_IN_PANDAS_WITH_STATE =>
transformApplyInPandasWithState(rel.getApplyInPandasWithState)
case proto.Relation.RelTypeCase.COMMON_INLINE_USER_DEFINED_TABLE_FUNCTION =>
transformCommonInlineUserDefinedTableFunction(
rel.getCommonInlineUserDefinedTableFunction)
case proto.Relation.RelTypeCase.CACHED_REMOTE_RELATION =>
transformCachedRemoteRelation(rel.getCachedRemoteRelation)
case proto.Relation.RelTypeCase.COLLECT_METRICS =>
transformCollectMetrics(rel.getCollectMetrics, rel.getCommon.getPlanId)
case proto.Relation.RelTypeCase.PARSE => transformParse(rel.getParse)
case proto.Relation.RelTypeCase.RELTYPE_NOT_SET =>
throw new IndexOutOfBoundsException("Expected Relation to be set, but is empty.")

// Catalog API (internal-only)
case proto.Relation.RelTypeCase.CATALOG => transformCatalog(rel.getCatalog)

// Handle plugins for Spark Connect Relation types.
case proto.Relation.RelTypeCase.EXTENSION =>
transformRelationPlugin(rel.getExtension)
case _ => throw InvalidPlanInput(s"${rel.getUnknown} not supported.")
}
if (rel.hasCommon && rel.getCommon.hasPlanId) {
plan.setTagValue(LogicalPlan.PLAN_ID_TAG, rel.getCommon.getPlanId)
}
plan
}
}

@DeveloperApi
@@ -27,14 +27,17 @@ import scala.jdk.CollectionConverters._
import scala.util.Try

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder
import com.google.common.cache.{Cache, CacheBuilder}

import org.apache.spark.{SparkException, SparkSQLException}
import org.apache.spark.{SparkEnv, SparkException, SparkSQLException}
import org.apache.spark.api.python.PythonFunction.PythonAccumulator
import org.apache.spark.connect.proto
import org.apache.spark.internal.Logging
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.connect.common.InvalidPlanInput
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.planner.PythonStreamingQueryListener
import org.apache.spark.sql.connect.planner.StreamingForeachBatchHelper
import org.apache.spark.sql.connect.service.SessionHolder.{ERROR_CACHE_SIZE, ERROR_CACHE_TIMEOUT_SEC}
@@ -50,6 +53,27 @@ case class SessionKey(userId: String, sessionId: String)
case class SessionHolder(userId: String, sessionId: String, session: SparkSession)
extends Logging {

// Cache which stores recently resolved logical plans to improve the performance of plan analysis.
// Only plans that explicitly specify "cachePlan = true" in transformRelation will be cached.
// Analyzing a large plan may be expensive, and it is not uncommon to build the plan step-by-step
// with several analysis during the process. This cache aids the recursive analysis process by
// memorizing `LogicalPlan`s which may be a sub-tree in a subsequent plan.
private lazy val planCache: Option[Cache[proto.Relation, LogicalPlan]] = {
if (SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE) <= 0) {
logWarning(
s"Session plan cache is disabled due to non-positive cache size." +
s" Current value of '${Connect.CONNECT_SESSION_PLAN_CACHE_SIZE.key}' is" +
s" ${SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE)}.")
None
} else {
Some(
CacheBuilder
.newBuilder()
.maximumSize(SparkEnv.get.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_SIZE))
.build[proto.Relation, LogicalPlan]())
}
}

// Time when the session was started.
private val startTimeMs: Long = System.currentTimeMillis()

@@ -388,6 +412,57 @@ case class SessionHolder(userId: String, sessionId: String, session: SparkSessio
*/
private[connect] val pythonAccumulator: Option[PythonAccumulator] =
Try(session.sparkContext.collectionAccumulator[Array[Byte]]).toOption

/**
* Transform a relation into a logical plan, using the plan cache if enabled. The plan cache is
* enabled only if `spark.connect.session.planCache.maxSize` is greater than zero AND
* `spark.connect.session.planCache.enabled` is true.
* @param rel
* The relation to transform.
* @param cachePlan
* Whether to cache the result logical plan.
* @param transform
* Function to transform the relation into a logical plan.
* @return
* The logical plan.
*/
private[connect] def usePlanCache(rel: proto.Relation, cachePlan: Boolean)(
transform: proto.Relation => LogicalPlan): LogicalPlan = {
val planCacheEnabled =
Option(session).forall(_.conf.get(Connect.CONNECT_SESSION_PLAN_CACHE_ENABLED, true))
// We only cache plans that have a plan ID.
val hasPlanId = rel.hasCommon && rel.getCommon.hasPlanId

def getPlanCache(rel: proto.Relation): Option[LogicalPlan] =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
Option(cache.getIfPresent(rel)) match {
case Some(plan) =>
logDebug(s"Using cached plan for relation '$rel': $plan")
Some(plan)
case None => None
}
case _ => None
}
def putPlanCache(rel: proto.Relation, plan: LogicalPlan): Unit =
planCache match {
case Some(cache) if planCacheEnabled && hasPlanId =>
cache.put(rel, plan)
case _ =>
}

getPlanCache(rel)
.getOrElse({
val plan = transform(rel)
if (cachePlan) {
putPlanCache(rel, plan)
}
plan
})
}

// For testing. Expose the plan cache for testing purposes.
private[service] def getPlanCache: Option[Cache[proto.Relation, LogicalPlan]] = planCache
}

object SessionHolder {