[SPARK-42454][SQL] SPJ: encapsulate all SPJ related parameters in BatchScanExec

### What changes were proposed in this pull request?
Pull out the SPJ-related attributes of BatchScanExec into a case class.

### Why are the changes needed?
We plan to evolve the SPJ parameters further to support more SPJ features, so we want to stabilize the definition of BatchScanExec and avoid touching the many places in the code where it is pattern-matched/unapplied.
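As a rough illustration (a toy sketch with simplified types, not the real Spark definitions), encapsulating the knobs behind one field means a new SPJ parameter changes only the wrapper, not the arity of the `unapply` that every pattern match depends on:

```scala
// Toy stand-in for StoragePartitionJoinParams: all SPJ knobs in one place.
case class SpjParams(
    keyGroupedPartitioning: Option[Seq[String]] = None,
    applyPartialClustering: Boolean = false)

// Toy stand-in for BatchScanExec: the SPJ knobs are a single stable field.
case class ScanExec(scan: String, spjParams: SpjParams = SpjParams())

object Demo extends App {
  val plan: Any = ScanExec("avro")
  plan match {
    // Adding a field to SpjParams later leaves this two-slot pattern intact;
    // previously each new SPJ parameter grew the pattern by one more `_`.
    case ScanExec(scan, _) => println(s"matched scan: $scan")
    case _ => println("no match")
  }
}
```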

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing unit tests verify that there is no behavior change.

Closes apache#41990 from szehon-ho/spj_refactor.

Authored-by: Szehon Ho <[email protected]>
Signed-off-by: Chao Sun <[email protected]>
szehon-ho authored and sunchao committed Jul 15, 2023
1 parent 0a0c367 commit c63ba6c
Showing 9 changed files with 59 additions and 33 deletions.
@@ -59,7 +59,7 @@ class AvroRowReaderSuite

val df = spark.read.format("avro").load(dir.getCanonicalPath)
val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
}
val filePath = fileScan.get.fileIndex.inputFiles(0)
val fileSize = new File(new URI(filePath)).length
@@ -2778,7 +2778,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
})

val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2812,7 +2812,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -2893,7 +2893,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper
.where("value = 'a'")

val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
if (filtersPushdown) {
@@ -37,23 +37,19 @@ case class BatchScanExec(
output: Seq[AttributeReference],
@transient scan: Scan,
runtimeFilters: Seq[Expression],
- keyGroupedPartitioning: Option[Seq[Expression]] = None,
ordering: Option[Seq[SortOrder]] = None,
@transient table: Table,
- commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
- applyPartialClustering: Boolean = false,
- replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
+ spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
+ ) extends DataSourceV2ScanExecBase {

- @transient lazy val batch = if (scan == null) null else scan.toBatch
+ @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch

// TODO: unify the equal/hashCode implementation for all data source v2 query plans.
override def equals(other: Any): Boolean = other match {
case other: BatchScanExec =>
this.batch != null && this.batch == other.batch &&
this.runtimeFilters == other.runtimeFilters &&
- this.commonPartitionValues == other.commonPartitionValues &&
- this.replicatePartitions == other.replicatePartitions &&
- this.applyPartialClustering == other.applyPartialClustering
+ this.spjParams == other.spjParams
case _ =>
false
}
@@ -119,11 +115,11 @@

override def outputPartitioning: Partitioning = {
super.outputPartitioning match {
- case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+ case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
// We allow duplicated partition values if
// `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
- val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
- Seq.fill(numSplits)(partValue)
+ val newPartValues = spjParams.commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
}
k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
case p => p
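For context on the `flatMap` in the hunk above: each common partition value carries the number of splits it should expand to, and values are duplicated accordingly when partially-clustered distribution is enabled. A self-contained sketch of the same expansion (a plain String stands in for InternalRow):

```scala
// (partition value, numSplits) pairs, as in commonPartitionValues.
val commonPartitionValues = Seq(("a", 2), ("b", 1), ("c", 3))

// Same shape as the flatMap in outputPartitioning: each value is repeated
// numSplits times, so the resulting partitioning may hold duplicates.
val newPartValues = commonPartitionValues.flatMap {
  case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
}

assert(newPartValues == Seq("a", "a", "b", "c", "c", "c")) // numPartitions = 6
```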
@@ -148,15 +144,17 @@
s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
"is enabled")

- val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+ val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
+ groupSplits = true).get

// This means the input partitions are not grouped by partition values. We'll need to
// check `groupByPartitionValues` and decide whether to group and replicate splits
// within a partition.
- if (commonPartitionValues.isDefined && applyPartialClustering) {
+ if (spjParams.commonPartitionValues.isDefined &&
+ spjParams.applyPartialClustering) {
// A mapping from the common partition values to how many splits the partition
// should contain. Note this no longer maintain the partition key ordering.
- val commonPartValuesMap = commonPartitionValues
+ val commonPartValuesMap = spjParams.commonPartitionValues
.get
.map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
.toMap
@@ -168,7 +166,7 @@
assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
"common partition values from Spark plan")

- val newSplits = if (replicatePartitions) {
+ val newSplits = if (spjParams.replicatePartitions) {
// We need to also replicate partitions according to the other side of join
Seq.fill(numSplits.get)(splits)
} else {
@@ -184,11 +182,12 @@

// Now fill missing partition keys with empty partitions
val partitionMapping = nestGroupedPartitions.toMap
- finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
- // Use empty partition for those partition values that are not present.
- partitionMapping.getOrElse(
- InternalRowComparableWrapper(partValue, p.expressions),
- Seq.fill(numSplits)(Seq.empty))
+ finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+ case (partValue, numSplits) =>
+ // Use empty partition for those partition values that are not present.
+ partitionMapping.getOrElse(
+ InternalRowComparableWrapper(partValue, p.expressions),
+ Seq.fill(numSplits)(Seq.empty))
}
} else {
val partitionMapping = groupedPartitions.map { case (row, parts) =>
@@ -222,6 +221,9 @@
rdd
}

+ override def keyGroupedPartitioning: Option[Seq[Expression]] =
+ spjParams.keyGroupedPartitioning

override def doCanonicalize(): BatchScanExec = {
this.copy(
output = output.map(QueryPlan.normalizeExpressions(_, output)),
@@ -241,3 +243,24 @@
s"BatchScan ${table.name()}".trim
}
}

+ case class StoragePartitionJoinParams(
+ keyGroupedPartitioning: Option[Seq[Expression]] = None,
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+ applyPartialClustering: Boolean = false,
+ replicatePartitions: Boolean = false) {
+ override def equals(other: Any): Boolean = other match {
+ case other: StoragePartitionJoinParams =>
+ this.commonPartitionValues == other.commonPartitionValues &&
+ this.replicatePartitions == other.replicatePartitions &&
+ this.applyPartialClustering == other.applyPartialClustering
+ case _ =>
+ false
+ }
+
+ override def hashCode(): Int = Objects.hashCode(
+ commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+ applyPartialClustering: java.lang.Boolean,
+ replicatePartitions: java.lang.Boolean)
+ }
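Note that `keyGroupedPartitioning` is deliberately left out of the overridden `equals`/`hashCode` above, matching the pre-refactor `BatchScanExec.equals`, which also compared only the batch, runtime filters, and common-partition-value fields. A toy mirror of that equality semantics (simplified types, illustrative only):

```scala
// Toy mirror of StoragePartitionJoinParams: equality skips the key field.
case class Params(
    key: Option[Seq[String]] = None,
    common: Option[Seq[(String, Int)]] = None) {
  override def equals(other: Any): Boolean = other match {
    case p: Params => this.common == p.common // `key` intentionally ignored
    case _ => false
  }
  override def hashCode(): Int = common.hashCode()
}

// Two instances differing only in `key` still compare equal.
assert(Params(key = Some(Seq("x"))) == Params(key = None))
```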

@@ -142,7 +142,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
case _ => false
}
val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
- relation.keyGroupedPartitioning, relation.ordering, relation.relation.table)
+ relation.ordering, relation.relation.table,
+ StoragePartitionJoinParams(relation.keyGroupedPartitioning))
withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil

case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
@@ -531,9 +531,11 @@ case class EnsureRequirements(
replicatePartitions: Boolean): SparkPlan = plan match {
case scan: BatchScanExec =>
scan.copy(
- commonPartitionValues = Some(values),
- applyPartialClustering = applyPartialClustering,
- replicatePartitions = replicatePartitions
+ spjParams = scan.spjParams.copy(
+ commonPartitionValues = Some(values),
+ applyPartialClustering = applyPartialClustering,
+ replicatePartitions = replicatePartitions
+ )
)
case node =>
node.mapChildren(child => populatePartitionValues(
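The nested `copy` above is what lets the two producers of SPJ state stay decoupled: DataSourceV2Strategy seeds `keyGroupedPartitioning` when the scan node is created, and EnsureRequirements later fills in the join-time fields without clobbering it. A toy sketch of the idiom (simplified types):

```scala
// Toy stand-in for StoragePartitionJoinParams.
case class SpjParams(
    keyGroupedPartitioning: Option[Seq[String]] = None,
    commonPartitionValues: Option[Seq[(String, Int)]] = None,
    applyPartialClustering: Boolean = false,
    replicatePartitions: Boolean = false)

// Planning time (DataSourceV2Strategy): only the partitioning is known.
val planned = SpjParams(keyGroupedPartitioning = Some(Seq("bucket(4, id)")))

// EnsureRequirements: copy fills the join-time fields, preserving the rest.
val ensured = planned.copy(
  commonPartitionValues = Some(Seq(("a", 2), ("b", 1))),
  applyPartialClustering = true)

assert(ensured.keyGroupedPartitioning.contains(Seq("bucket(4, id)")))
```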
@@ -1014,7 +1014,7 @@ class FileBasedDataSourceSuite extends QueryTest
})

val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: FileScan, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1055,7 +1055,7 @@ class FileBasedDataSourceSuite extends QueryTest
assert(filterCondition.isDefined)

val fileScan = df.queryExecution.executedPlan collectFirst {
- case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+ case BatchScanExec(_, f: FileScan, _, _, _, _) => f
}
assert(fileScan.nonEmpty)
assert(fileScan.get.partitionFilters.isEmpty)
@@ -170,7 +170,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared
override def getScanExecPartitionSize(plan: SparkPlan): Long = {
plan.collectFirst {
case p: FileSourceScanExec => p.selectedPartitions.length
- case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) =>
+ case BatchScanExec(_, scan: FileScan, _, _, _, _) =>
scan.fileIndex.listFiles(scan.partitionFilters, scan.dataFilters).length
}.get
}
@@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase
assert(getScanExecPartitionSize(plan) == expectedPartitionCount)

val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
- case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => scan.partitionFilters
+ case BatchScanExec(_, scan: FileScan, _, _, _, _) => scan.partitionFilters
}
val pushedDownPartitionFilters = plan.collectFirst(collectFn)
.map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
@@ -42,7 +42,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
val fileSourceScanSchemata =
collect(df.queryExecution.executedPlan) {
- case BatchScanExec(_, scan: OrcScan, _, _, _, _, _, _, _) => scan.readDataSchema
+ case BatchScanExec(_, scan: OrcScan, _, _, _, _) => scan.readDataSchema
}
assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
