Skip to content

Commit

Permalink
[SPARK-37768][SQL][FOLLOWUP] Schema pruning for the metadata struct
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Follow-up PR of apache#34575. Support the metadata struct schema pruning for all file formats.

### Why are the changes needed?
Performance improvements.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Existing UTs and a new UT.

Closes apache#35147 from Yaohua628/spark-37768.

Authored-by: yaohua <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
Yaohua628 authored and cloud-fan committed Jan 18, 2022
1 parent 450418b commit 54f91d3
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ object SchemaPruning extends SQLConfHelper {
* 1. The schema field ordering at original schema is still preserved in pruned schema.
* 2. The top-level fields are not pruned here.
*/
def pruneDataSchema(
dataSchema: StructType,
def pruneSchema(
schema: StructType,
requestedRootFields: Seq[RootField]): StructType = {
val resolver = conf.resolver
// Merge the requested root fields into a single schema. Note the ordering of the fields
Expand All @@ -44,10 +44,10 @@ object SchemaPruning extends SQLConfHelper {
.map { root: RootField => StructType(Array(root.field)) }
.reduceLeft(_ merge _)
val mergedDataSchema =
StructType(dataSchema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
StructType(schema.map(d => mergedSchema.find(m => resolver(m.name, d.name)).getOrElse(d)))
// Sort the fields of mergedDataSchema according to their order in dataSchema,
// recursively. This makes mergedDataSchema a pruned schema of dataSchema
sortLeftFieldsByRight(mergedDataSchema, dataSchema).asInstanceOf[StructType]
sortLeftFieldsByRight(mergedDataSchema, schema).asInstanceOf[StructType]
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ case class AttributeReference(
AttributeReference(name, dataType, nullable, newMetadata)(exprId, qualifier)
}

override def withDataType(newType: DataType): Attribute = {
override def withDataType(newType: DataType): AttributeReference = {
AttributeReference(name, newType, nullable, metadata)(exprId, qualifier)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ object ObjectSerializerPruning extends Rule[LogicalPlan] {

if (conf.serializerNestedSchemaPruningEnabled && rootFields.nonEmpty) {
// Prunes nested fields in serializers.
val prunedSchema = SchemaPruning.pruneDataSchema(
val prunedSchema = SchemaPruning.pruneSchema(
StructType.fromAttributes(prunedSerializer.map(_.toAttribute)), rootFields)
val nestedPrunedSerializer = prunedSerializer.zipWithIndex.map { case (serializer, idx) =>
pruneSerializer(serializer, prunedSchema(idx).dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
// `derivedFromAtt` doesn't affect the result of pruned schema.
SchemaPruning.RootField(field = f, derivedFromAtt = true)
}
val prunedSchema = SchemaPruning.pruneDataSchema(schema, requestedRootFields)
val prunedSchema = SchemaPruning.pruneSchema(schema, requestedRootFields)
assert(prunedSchema === expectedSchema)
}

Expand Down Expand Up @@ -140,7 +140,7 @@ class SchemaPruningSuite extends SparkFunSuite with SQLHelper {
assert(field.metadata.getString("foo") == "bar")

val schema = StructType(Seq(field))
val prunedSchema = SchemaPruning.pruneDataSchema(schema, rootFields)
val prunedSchema = SchemaPruning.pruneSchema(schema, rootFields)
assert(prunedSchema.head.metadata.getString("foo") == "bar")
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -213,11 +213,10 @@ object FileSourceStrategy extends Strategy with PredicateHelper with Logging {
val outputSchema = readDataColumns.toStructType
logInfo(s"Output Data Schema: ${outputSchema.simpleString(5)}")

val metadataStructOpt = requiredAttributes.collectFirst {
val metadataStructOpt = l.output.collectFirst {
case MetadataAttribute(attr) => attr
}

// TODO (yaohua): should be able to prune the metadata struct only containing what needed
val metadataColumns = metadataStructOpt.map { metadataStruct =>
metadataStruct.dataType.asInstanceOf[StructType].fields.map { field =>
MetadataAttribute(field.name, field.dataType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,58 +31,68 @@ import org.apache.spark.sql.util.SchemaUtils._
* By "physical column", we mean a column as defined in the data source format like Parquet format
* or ORC format. For example, in Spark SQL, a root-level Parquet column corresponds to a SQL
* column, and a nested Parquet column corresponds to a [[StructField]].
*
* Also prunes the unnecessary metadata columns if any for all file formats.
*/
object SchemaPruning extends Rule[LogicalPlan] {
import org.apache.spark.sql.catalyst.expressions.SchemaPruning._

override def apply(plan: LogicalPlan): LogicalPlan =
if (conf.nestedSchemaPruningEnabled) {
apply0(plan)
} else {
plan
}

private def apply0(plan: LogicalPlan): LogicalPlan =
plan transformDown {
case op @ PhysicalOperation(projects, filters,
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _))
if canPruneRelation(hadoopFsRelation) =>

prunePhysicalColumns(l.output, projects, filters, hadoopFsRelation.dataSchema,
prunedDataSchema => {
l @ LogicalRelation(hadoopFsRelation: HadoopFsRelation, _, _, _)) =>
prunePhysicalColumns(l, projects, filters, hadoopFsRelation,
(prunedDataSchema, prunedMetadataSchema) => {
val prunedHadoopRelation =
hadoopFsRelation.copy(dataSchema = prunedDataSchema)(hadoopFsRelation.sparkSession)
buildPrunedRelation(l, prunedHadoopRelation)
buildPrunedRelation(l, prunedHadoopRelation, prunedMetadataSchema)
}).getOrElse(op)
}

/**
* This method returns optional logical plan. `None` is returned if no nested field is required or
* all nested fields are required.
*
* This method will prune both the data schema and the metadata schema
*/
private def prunePhysicalColumns(
output: Seq[AttributeReference],
relation: LogicalRelation,
projects: Seq[NamedExpression],
filters: Seq[Expression],
dataSchema: StructType,
leafNodeBuilder: StructType => LeafNode): Option[LogicalPlan] = {
hadoopFsRelation: HadoopFsRelation,
leafNodeBuilder: (StructType, StructType) => LeafNode): Option[LogicalPlan] = {

val (normalizedProjects, normalizedFilters) =
normalizeAttributeRefNames(output, projects, filters)
normalizeAttributeRefNames(relation.output, projects, filters)
val requestedRootFields = identifyRootFields(normalizedProjects, normalizedFilters)

// If requestedRootFields includes a nested field, continue. Otherwise,
// return op
if (requestedRootFields.exists { root: RootField => !root.derivedFromAtt }) {
val prunedDataSchema = pruneDataSchema(dataSchema, requestedRootFields)

// If the data schema is different from the pruned data schema, continue. Otherwise,
// return op. We effect this comparison by counting the number of "leaf" fields in
// each schemata, assuming the fields in prunedDataSchema are a subset of the fields
// in dataSchema.
if (countLeaves(dataSchema) > countLeaves(prunedDataSchema)) {
val prunedRelation = leafNodeBuilder(prunedDataSchema)
val projectionOverSchema = ProjectionOverSchema(prunedDataSchema)
val prunedDataSchema = if (canPruneDataSchema(hadoopFsRelation)) {
pruneSchema(hadoopFsRelation.dataSchema, requestedRootFields)
} else {
hadoopFsRelation.dataSchema
}

val metadataSchema =
relation.output.collect { case MetadataAttribute(attr) => attr }.toStructType
val prunedMetadataSchema = if (metadataSchema.nonEmpty) {
pruneSchema(metadataSchema, requestedRootFields)
} else {
metadataSchema
}

// If the data schema is different from the pruned data schema
// OR
// the metadata schema is different from the pruned metadata schema, continue.
// Otherwise, return None.
if (countLeaves(hadoopFsRelation.dataSchema) > countLeaves(prunedDataSchema) ||
countLeaves(metadataSchema) > countLeaves(prunedMetadataSchema)) {
val prunedRelation = leafNodeBuilder(prunedDataSchema, prunedMetadataSchema)
val projectionOverSchema =
ProjectionOverSchema(prunedDataSchema.merge(prunedMetadataSchema))
Some(buildNewProjection(projects, normalizedProjects, normalizedFilters,
prunedRelation, projectionOverSchema))
} else {
Expand All @@ -96,9 +106,10 @@ object SchemaPruning extends Rule[LogicalPlan] {
/**
* Checks to see if the given relation can be pruned. Currently we support Parquet and ORC v1.
*/
private def canPruneRelation(fsRelation: HadoopFsRelation) =
fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
fsRelation.fileFormat.isInstanceOf[OrcFileFormat]
private def canPruneDataSchema(fsRelation: HadoopFsRelation): Boolean =
conf.nestedSchemaPruningEnabled && (
fsRelation.fileFormat.isInstanceOf[ParquetFileFormat] ||
fsRelation.fileFormat.isInstanceOf[OrcFileFormat])

/**
* Normalizes the names of the attribute references in the given projects and filters to reflect
Expand Down Expand Up @@ -162,29 +173,25 @@ object SchemaPruning extends Rule[LogicalPlan] {
*/
private def buildPrunedRelation(
outputRelation: LogicalRelation,
prunedBaseRelation: HadoopFsRelation) = {
val prunedOutput = getPrunedOutput(outputRelation.output, prunedBaseRelation.schema)
// also add the metadata output if any
// TODO: should be able to prune the metadata schema
val metaOutput = outputRelation.output.collect {
case MetadataAttribute(attr) => attr
}
outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput ++ metaOutput)
prunedBaseRelation: HadoopFsRelation,
prunedMetadataSchema: StructType) = {
val finalSchema = prunedBaseRelation.schema.merge(prunedMetadataSchema)
val prunedOutput = getPrunedOutput(outputRelation.output, finalSchema)
outputRelation.copy(relation = prunedBaseRelation, output = prunedOutput)
}

// Prune the given output to make it consistent with `requiredSchema`.
private def getPrunedOutput(
output: Seq[AttributeReference],
requiredSchema: StructType): Seq[AttributeReference] = {
// We need to replace the expression ids of the pruned relation output attributes
// with the expression ids of the original relation output attributes so that
// references to the original relation's output are not broken
val outputIdMap = output.map(att => (att.name, att.exprId)).toMap
// We need to update the data type of the output attributes to use the pruned ones.
// so that references to the original relation's output are not broken
val nameAttributeMap = output.map(att => (att.name, att)).toMap
requiredSchema
.toAttributes
.map {
case att if outputIdMap.contains(att.name) =>
att.withExprId(outputIdMap(att.name))
case att if nameAttributeMap.contains(att.name) =>
nameAttributeMap(att.name).withDataType(att.dataType)
case att => att
}
}
Expand All @@ -203,6 +210,4 @@ object SchemaPruning extends Rule[LogicalPlan] {
case _ => 1
}
}


}
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ object PushDownUtils extends PredicateHelper {
case r: SupportsPushDownRequiredColumns if SQLConf.get.nestedSchemaPruningEnabled =>
val rootFields = SchemaPruning.identifyRootFields(projects, filters)
val prunedSchema = if (rootFields.nonEmpty) {
SchemaPruning.pruneDataSchema(relation.schema, rootFields)
SchemaPruning.pruneSchema(relation.schema, rootFields)
} else {
new StructType()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import java.sql.Timestamp
import java.text.SimpleDateFormat

import org.apache.spark.sql.{AnalysisException, Column, DataFrame, QueryTest, Row}
import org.apache.spark.sql.execution.FileSourceScanExec
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession
Expand Down Expand Up @@ -384,4 +385,51 @@ class FileMetadataStructSuite extends QueryTest with SharedSparkSession {
}
}
}

metadataColumnsTest("prune metadata schema in projects", schema) { (df, f0, f1) =>
val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_NAME)
val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
case p: FileSourceScanExec => p.metadataColumns
}.get
assert(fileSourceScanMetaCols.size == 1)
assert(fileSourceScanMetaCols.head.name == "file_name")

checkAnswer(
prunedDF,
Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_NAME)),
Row("lily", 31, 54321L, f1(METADATA_FILE_NAME)))
)
}

metadataColumnsTest("prune metadata schema in filters", schema) { (df, f0, f1) =>
val prunedDF = df.select("name", "age", "info.id")
.where(col(METADATA_FILE_PATH).contains("data/f0"))

val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
case p: FileSourceScanExec => p.metadataColumns
}.get
assert(fileSourceScanMetaCols.size == 1)
assert(fileSourceScanMetaCols.head.name == "file_path")

checkAnswer(
prunedDF,
Seq(Row("jack", 24, 12345L))
)
}

metadataColumnsTest("prune metadata schema in projects and filters", schema) { (df, f0, f1) =>
val prunedDF = df.select("name", "age", "info.id", METADATA_FILE_SIZE)
.where(col(METADATA_FILE_PATH).contains("data/f0"))

val fileSourceScanMetaCols = prunedDF.queryExecution.sparkPlan.collectFirst {
case p: FileSourceScanExec => p.metadataColumns
}.get
assert(fileSourceScanMetaCols.size == 2)
assert(fileSourceScanMetaCols.map(_.name).toSet == Set("file_size", "file_path"))

checkAnswer(
prunedDF,
Seq(Row("jack", 24, 12345L, f0(METADATA_FILE_SIZE)))
)
}
}

0 comments on commit 54f91d3

Please sign in to comment.