diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
index 046ff4ef088d8..cc0e178c617af 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroRowReaderSuite.scala
@@ -59,7 +59,7 @@ class AvroRowReaderSuite
       val df = spark.read.format("avro").load(dir.getCanonicalPath)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       val filePath = fileScan.get.fileIndex.inputFiles(0)
       val fileSize = new File(new URI(filePath)).length
diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
index c5e52292caf2d..35e9f43289c16 100644
--- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
+++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala
@@ -2778,7 +2778,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       })
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
@@ -2812,7 +2812,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
       assert(filterCondition.isDefined)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.isEmpty)
@@ -2893,7 +2893,7 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper {
         .where("value = 'a'")
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: AvroScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: AvroScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       if (filtersPushdown) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
index d43331d57c47a..4b53819739262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/BatchScanExec.scala
@@ -37,23 +37,19 @@ case class BatchScanExec(
     output: Seq[AttributeReference],
     @transient scan: Scan,
     runtimeFilters: Seq[Expression],
-    keyGroupedPartitioning: Option[Seq[Expression]] = None,
     ordering: Option[Seq[SortOrder]] = None,
     @transient table: Table,
-    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
-    applyPartialClustering: Boolean = false,
-    replicatePartitions: Boolean = false) extends DataSourceV2ScanExecBase {
+    spjParams: StoragePartitionJoinParams = StoragePartitionJoinParams()
+  ) extends DataSourceV2ScanExecBase {
 
-  @transient lazy val batch = if (scan == null) null else scan.toBatch
+  @transient lazy val batch: Batch = if (scan == null) null else scan.toBatch
 
   // TODO: unify the equal/hashCode implementation for all data source v2 query plans.
   override def equals(other: Any): Boolean = other match {
     case other: BatchScanExec =>
       this.batch != null && this.batch == other.batch &&
         this.runtimeFilters == other.runtimeFilters &&
-        this.commonPartitionValues == other.commonPartitionValues &&
-        this.replicatePartitions == other.replicatePartitions &&
-        this.applyPartialClustering == other.applyPartialClustering
+        this.spjParams == other.spjParams
     case _ =>
       false
   }
@@ -119,11 +115,11 @@ case class BatchScanExec(
 
   override def outputPartitioning: Partitioning = {
     super.outputPartitioning match {
-      case k: KeyGroupedPartitioning if commonPartitionValues.isDefined =>
+      case k: KeyGroupedPartitioning if spjParams.commonPartitionValues.isDefined =>
         // We allow duplicated partition values if
        // `spark.sql.sources.v2.bucketing.partiallyClusteredDistribution.enabled` is true
-        val newPartValues = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-          Seq.fill(numSplits)(partValue)
+        val newPartValues = spjParams.commonPartitionValues.get.flatMap {
+          case (partValue, numSplits) => Seq.fill(numSplits)(partValue)
         }
         k.copy(numPartitions = newPartValues.length, partitionValues = newPartValues)
       case p => p
@@ -148,15 +144,17 @@ case class BatchScanExec(
               s"${SQLConf.V2_BUCKETING_PARTIALLY_CLUSTERED_DISTRIBUTION_ENABLED.key} " +
               "is enabled")
 
-          val groupedPartitions = groupPartitions(finalPartitions.map(_.head), true).get
+          val groupedPartitions = groupPartitions(finalPartitions.map(_.head),
+            groupSplits = true).get
 
           // This means the input partitions are not grouped by partition values. We'll need to
           // check `groupByPartitionValues` and decide whether to group and replicate splits
           // within a partition.
-          if (commonPartitionValues.isDefined && applyPartialClustering) {
+          if (spjParams.commonPartitionValues.isDefined &&
+              spjParams.applyPartialClustering) {
             // A mapping from the common partition values to how many splits the partition
             // should contain. Note this no longer maintain the partition key ordering.
-            val commonPartValuesMap = commonPartitionValues
+            val commonPartValuesMap = spjParams.commonPartitionValues
               .get
               .map(t => (InternalRowComparableWrapper(t._1, p.expressions), t._2))
               .toMap
@@ -168,7 +166,7 @@ case class BatchScanExec(
               assert(numSplits.isDefined, s"Partition value $partValue does not exist in " +
                 "common partition values from Spark plan")
 
-              val newSplits = if (replicatePartitions) {
+              val newSplits = if (spjParams.replicatePartitions) {
                 // We need to also replicate partitions according to the other side of join
                 Seq.fill(numSplits.get)(splits)
               } else {
@@ -184,11 +182,12 @@ case class BatchScanExec(
 
             // Now fill missing partition keys with empty partitions
             val partitionMapping = nestGroupedPartitions.toMap
-            finalPartitions = commonPartitionValues.get.flatMap { case (partValue, numSplits) =>
-              // Use empty partition for those partition values that are not present.
-              partitionMapping.getOrElse(
-                InternalRowComparableWrapper(partValue, p.expressions),
-                Seq.fill(numSplits)(Seq.empty))
+            finalPartitions = spjParams.commonPartitionValues.get.flatMap {
+              case (partValue, numSplits) =>
+                // Use empty partition for those partition values that are not present.
+                partitionMapping.getOrElse(
+                  InternalRowComparableWrapper(partValue, p.expressions),
+                  Seq.fill(numSplits)(Seq.empty))
             }
           } else {
             val partitionMapping = groupedPartitions.map { case (row, parts) =>
@@ -222,6 +221,9 @@ case class BatchScanExec(
     rdd
   }
 
+  override def keyGroupedPartitioning: Option[Seq[Expression]] =
+    spjParams.keyGroupedPartitioning
+
   override def doCanonicalize(): BatchScanExec = {
     this.copy(
       output = output.map(QueryPlan.normalizeExpressions(_, output)),
@@ -241,3 +243,24 @@ case class BatchScanExec(
     s"BatchScan ${table.name()}".trim
   }
 }
+
+case class StoragePartitionJoinParams(
+    keyGroupedPartitioning: Option[Seq[Expression]] = None,
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]] = None,
+    applyPartialClustering: Boolean = false,
+    replicatePartitions: Boolean = false) {
+  override def equals(other: Any): Boolean = other match {
+    case other: StoragePartitionJoinParams =>
+      this.commonPartitionValues == other.commonPartitionValues &&
+        this.replicatePartitions == other.replicatePartitions &&
+        this.applyPartialClustering == other.applyPartialClustering
+    case _ =>
+      false
+  }
+
+  override def hashCode(): Int = Objects.hashCode(
+    commonPartitionValues: Option[Seq[(InternalRow, Int)]],
+    applyPartialClustering: java.lang.Boolean,
+    replicatePartitions: java.lang.Boolean)
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
index 542ac2e674864..abd70f322c839 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala
@@ -142,7 +142,8 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat
         case _ => false
       }
       val batchExec = BatchScanExec(relation.output, relation.scan, runtimeFilters,
-        relation.keyGroupedPartitioning, relation.ordering, relation.relation.table)
+        relation.ordering, relation.relation.table,
+        StoragePartitionJoinParams(relation.keyGroupedPartitioning))
       withProjectAndFilter(project, postScanFilters, batchExec, !batchExec.supportsColumnar) :: Nil
 
     case PhysicalOperation(p, f, r: StreamingDataSourceV2Relation)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
index 457a9e0a868f5..42c880e7c6262 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -531,9 +531,11 @@ case class EnsureRequirements(
       replicatePartitions: Boolean): SparkPlan = plan match {
     case scan: BatchScanExec =>
       scan.copy(
-        commonPartitionValues = Some(values),
-        applyPartialClustering = applyPartialClustering,
-        replicatePartitions = replicatePartitions
+        spjParams = scan.spjParams.copy(
+          commonPartitionValues = Some(values),
+          applyPartialClustering = applyPartialClustering,
+          replicatePartitions = replicatePartitions
+        )
       )
     case node =>
       node.mapChildren(child => populatePartitionValues(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
index d69a68f57262a..93275487f29c5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala
@@ -1014,7 +1014,7 @@ class FileBasedDataSourceSuite extends QueryTest
       })
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: FileScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.nonEmpty)
@@ -1055,7 +1055,7 @@ class FileBasedDataSourceSuite extends QueryTest
       assert(filterCondition.isDefined)
 
       val fileScan = df.queryExecution.executedPlan collectFirst {
-        case BatchScanExec(_, f: FileScan, _, _, _, _, _, _, _) => f
+        case BatchScanExec(_, f: FileScan, _, _, _, _) => f
       }
       assert(fileScan.nonEmpty)
       assert(fileScan.get.partitionFilters.isEmpty)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
index 1b30205a41864..3a70bfc7f4a4a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PruneFileSourcePartitionsSuite.scala
@@ -170,7 +170,7 @@ class PruneFileSourcePartitionsSuite extends PrunePartitionSuiteBase with Shared
   override def getScanExecPartitionSize(plan: SparkPlan): Long = {
     plan.collectFirst {
       case p: FileSourceScanExec => p.selectedPartitions.length
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) =>
+      case BatchScanExec(_, scan: FileScan, _, _, _, _) =>
         scan.fileIndex.listFiles(scan.partitionFilters, scan.dataFilters).length
     }.get
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
index 9a61e6517f749..430e9f848e4aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/PrunePartitionSuiteBase.scala
@@ -95,7 +95,7 @@ abstract class PrunePartitionSuiteBase extends StatisticsCollectionTestBase {
     assert(getScanExecPartitionSize(plan) == expectedPartitionCount)
 
     val collectFn: PartialFunction[SparkPlan, Seq[Expression]] = collectPartitionFiltersFn orElse {
-      case BatchScanExec(_, scan: FileScan, _, _, _, _, _, _, _) => scan.partitionFilters
+      case BatchScanExec(_, scan: FileScan, _, _, _, _) => scan.partitionFilters
     }
     val pushedDownPartitionFilters = plan.collectFirst(collectFn)
       .map(exps => exps.filterNot(e => e.isInstanceOf[IsNotNull]))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 1fba772f5a822..8d503d64e30ca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -42,7 +42,7 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
   override def checkScanSchemata(df: DataFrame, expectedSchemaCatalogStrings: String*): Unit = {
     val fileSourceScanSchemata =
       collect(df.queryExecution.executedPlan) {
-        case BatchScanExec(_, scan: OrcScan, _, _, _, _, _, _, _) => scan.readDataSchema
+        case BatchScanExec(_, scan: OrcScan, _, _, _, _) => scan.readDataSchema
       }
     assert(fileSourceScanSchemata.size === expectedSchemaCatalogStrings.size,
       s"Found ${fileSourceScanSchemata.size} file sources in dataframe, " +
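
Reviewer note, not part of the patch: the standalone sketch below models the equality contract that the new StoragePartitionJoinParams encodes. The object and class names (SpjParamsEqualitySketch, Params) and the String/Int placeholder types are invented for the example, and java.util.Objects.hash stands in for the hashing helper used in the real class; only the field names and the shape of equals/hashCode come from the diff. The point it illustrates: keyGroupedPartitioning rides along on the params object but is excluded from equals/hashCode, mirroring the previous BatchScanExec.equals, which also ignored it, so two scans that differ only in that field still compare equal.

// Simplified, runnable model of the StoragePartitionJoinParams equality contract shown above.
object SpjParamsEqualitySketch {
  case class Params(
      keyGroupedPartitioning: Option[Seq[String]] = None,
      commonPartitionValues: Option[Seq[(Int, Int)]] = None,
      applyPartialClustering: Boolean = false,
      replicatePartitions: Boolean = false) {
    // Equality intentionally ignores keyGroupedPartitioning, as in the patch.
    override def equals(other: Any): Boolean = other match {
      case o: Params =>
        this.commonPartitionValues == o.commonPartitionValues &&
          this.replicatePartitions == o.replicatePartitions &&
          this.applyPartialClustering == o.applyPartialClustering
      case _ => false
    }

    // Hash only the fields that participate in equality.
    override def hashCode(): Int = java.util.Objects.hash(
      commonPartitionValues,
      Boolean.box(applyPartialClustering),
      Boolean.box(replicatePartitions))
  }

  def main(args: Array[String]): Unit = {
    // Two params objects that differ only in keyGroupedPartitioning...
    val a = Params(keyGroupedPartitioning = Some(Seq("id")), commonPartitionValues = Some(Seq((1, 2))))
    val b = Params(keyGroupedPartitioning = None, commonPartitionValues = Some(Seq((1, 2))))
    // ...still compare equal and hash identically, so plan-level equality is unaffected by it.
    assert(a == b && a.hashCode == b.hashCode)
  }
}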