From 2c1c4d2614ae1ff902c244209f7ec3c79102d3e0 Mon Sep 17 00:00:00 2001 From: Chenhao Li Date: Tue, 24 Dec 2024 14:54:02 +0800 Subject: [PATCH] [SPARK-50644][SQL] Read variant struct in Parquet reader ### What changes were proposed in this pull request? It adds support for variant struct in Parquet reader. The concept of variant struct was introduced in https://github.com/apache/spark/pull/49235. It includes all the extracted fields from a variant column that the query requests. ### Why are the changes needed? By producing variant struct in Parquet reader, we can avoid reading/rebuilding the full variant and achieve more efficient variant processing. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? Unit test. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #49263 from chenhao-db/spark_variant_struct_reader. Authored-by: Chenhao Li Signed-off-by: Wenchen Fan --- .../spark/types/variant/ShreddingUtils.java | 9 +- .../spark/types/variant/VariantSchema.java | 6 + .../parquet/ParquetColumnVector.java | 24 +- .../parquet/ParquetReadSupport.scala | 9 + .../parquet/ParquetRowConverter.scala | 26 +- .../parquet/ParquetSchemaConverter.scala | 4 + .../parquet/SparkShreddingUtils.scala | 597 +++++++++++++++++- .../spark/sql/VariantShreddingSuite.scala | 185 +++++- 8 files changed, 820 insertions(+), 40 deletions(-) diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java index 59e16b77ab01d..6a04bf9a2b259 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/ShreddingUtils.java @@ -49,9 +49,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) { throw malformedVariant(); } byte[] metadata = row.getBinary(schema.topLevelMetadataIdx); - if (schema.variantIdx >= 0 && schema.typedIdx < 0) { - // The variant is unshredded. We are not required to do anything special, but we can have an - // optimization to avoid `rebuild`. + if (schema.isUnshredded()) { + // `rebuild` is unnecessary for unshredded variant. if (row.isNullAt(schema.variantIdx)) { throw malformedVariant(); } @@ -65,8 +64,8 @@ public static Variant rebuild(ShreddedRow row, VariantSchema schema) { // Rebuild a variant value from the shredded data according to the reconstruction algorithm in // https://github.com/apache/parquet-format/blob/master/VariantShredding.md. // Append the result to `builder`. - private static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema, - VariantBuilder builder) { + public static void rebuild(ShreddedRow row, byte[] metadata, VariantSchema schema, + VariantBuilder builder) { int typedIdx = schema.typedIdx; int variantIdx = schema.variantIdx; if (typedIdx >= 0 && !row.isNullAt(typedIdx)) { diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java index 551e46214859a..d1e6cc3a727fa 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantSchema.java @@ -138,6 +138,12 @@ public VariantSchema(int typedIdx, int variantIdx, int topLevelMetadataIdx, int this.arraySchema = arraySchema; } + // Return whether the variant column is unshrededed. 
The user is not required to do anything
+  // special, but can apply certain optimizations for unshredded variants.
+  public boolean isUnshredded() {
+    return topLevelMetadataIdx >= 0 && variantIdx >= 0 && typedIdx < 0;
+  }
+
   @Override
   public String toString() {
     return "VariantSchema{" +
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
index 0b9a25fc46a0f..7fb8be7caf286 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnVector.java
@@ -35,7 +35,6 @@
 import org.apache.spark.sql.types.StructType;
 import org.apache.spark.sql.types.VariantType;
 import org.apache.spark.types.variant.VariantSchema;
-import org.apache.spark.unsafe.types.VariantVal;
 
 /**
  * Contains necessary information representing a Parquet column, either of primitive or nested type.
@@ -49,6 +48,9 @@ final class ParquetColumnVector {
   // contains only one child that reads the underlying file content. This `ParquetColumnVector`
   // should assemble Spark variant values from the file content.
   private VariantSchema variantSchema;
+  // Only meaningful if `variantSchema` is not null. See `SparkShreddingUtils.getFieldsToExtract`
+  // for its meaning.
+  private FieldToExtract[] fieldsToExtract;
 
   /**
    * Repetition & Definition levels
@@ -117,6 +119,7 @@ final class ParquetColumnVector {
         fileContent, capacity, memoryMode, missingColumns, false, null);
       children.add(contentVector);
       variantSchema = SparkShreddingUtils.buildVariantSchema(fileContentCol.sparkType());
+      fieldsToExtract = SparkShreddingUtils.getFieldsToExtract(column.sparkType(), variantSchema);
       repetitionLevels = contentVector.repetitionLevels;
       definitionLevels = contentVector.definitionLevels;
     } else if (isPrimitive) {
@@ -188,20 +191,11 @@ void assemble() {
     if (variantSchema != null) {
       children.get(0).assemble();
       WritableColumnVector fileContent = children.get(0).getValueVector();
-      int numRows = fileContent.getElementsAppended();
-      vector.reset();
-      vector.reserve(numRows);
-      WritableColumnVector valueChild = vector.getChild(0);
-      WritableColumnVector metadataChild = vector.getChild(1);
-      for (int i = 0; i < numRows; ++i) {
-        if (fileContent.isNullAt(i)) {
-          vector.appendStruct(true);
-        } else {
-          vector.appendStruct(false);
-          VariantVal v = SparkShreddingUtils.rebuild(fileContent.getStruct(i), variantSchema);
-          valueChild.appendByteArray(v.getValue(), 0, v.getValue().length);
-          metadataChild.appendByteArray(v.getMetadata(), 0, v.getMetadata().length);
-        }
+      if (fieldsToExtract == null) {
+        SparkShreddingUtils.assembleVariantBatch(fileContent, vector, variantSchema);
+      } else {
+        SparkShreddingUtils.assembleVariantStructBatch(fileContent, vector, variantSchema,
+          fieldsToExtract);
       }
       return;
     }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 8dde02a4673f0..af0bf0d51f077 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -35,6 +35,7 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.InternalRow
 import 
org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.datasources.VariantMetadata import org.apache.spark.sql.internal.{LegacyBehaviorPolicy, SQLConf} import org.apache.spark.sql.types._ @@ -221,6 +222,9 @@ object ParquetReadSupport extends Logging { clipParquetMapType( parquetType.asGroupType(), t.keyType, t.valueType, caseSensitive, useFieldId) + case t: StructType if VariantMetadata.isVariantStruct(t) => + clipVariantSchema(parquetType.asGroupType(), t) + case t: StructType => clipParquetGroup(parquetType.asGroupType(), t, caseSensitive, useFieldId) @@ -390,6 +394,11 @@ object ParquetReadSupport extends Logging { .named(parquetRecord.getName) } + private def clipVariantSchema(parquetType: GroupType, variantStruct: StructType): GroupType = { + // TODO(SHREDDING): clip `parquetType` to retain the necessary columns. + parquetType + } + /** * Clips a Parquet [[GroupType]] which corresponds to a Catalyst [[StructType]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 3ed7fe37ccd96..550c2af43a706 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.errors.QueryExecutionErrors -import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.{DataSourceUtils, VariantMetadata} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{UTF8String, VariantVal} @@ -498,6 +498,9 @@ private[parquet] class ParquetRowConverter( case t: MapType => new ParquetMapConverter(parquetType.asGroupType(), t, updater) + case t: StructType if VariantMetadata.isVariantStruct(t) => + new ParquetVariantConverter(t, parquetType.asGroupType(), updater) + case t: StructType => val wrappedUpdater = { // SPARK-30338: avoid unnecessary InternalRow copying for nested structs: @@ -536,12 +539,7 @@ private[parquet] class ParquetRowConverter( case t: VariantType => if (SQLConf.get.getConf(SQLConf.VARIANT_ALLOW_READING_SHREDDED)) { - // Infer a Spark type from `parquetType`. This piece of code is copied from - // `ParquetArrayConverter`. 
- val messageType = Types.buildMessage().addField(parquetType).named("foo") - val column = new ColumnIOFactory().getColumnIO(messageType) - val parquetSparkType = schemaConverter.convertField(column.getChild(0)).sparkType - new ParquetVariantConverter(parquetType.asGroupType(), parquetSparkType, updater) + new ParquetVariantConverter(t, parquetType.asGroupType(), updater) } else { new ParquetUnshreddedVariantConverter(parquetType.asGroupType(), updater) } @@ -909,13 +907,14 @@ private[parquet] class ParquetRowConverter( /** Parquet converter for Variant (shredded or unshredded) */ private final class ParquetVariantConverter( - parquetType: GroupType, - parquetSparkType: DataType, - updater: ParentContainerUpdater) + targetType: DataType, parquetType: GroupType, updater: ParentContainerUpdater) extends ParquetGroupConverter(updater) { private[this] var currentRow: Any = _ + private[this] val parquetSparkType = SparkShreddingUtils.parquetTypeToSparkType(parquetType) private[this] val variantSchema = SparkShreddingUtils.buildVariantSchema(parquetSparkType) + private[this] val fieldsToExtract = + SparkShreddingUtils.getFieldsToExtract(targetType, variantSchema) // A struct converter that reads the underlying file data. private[this] val fileConverter = new ParquetRowConverter( schemaConverter, @@ -932,7 +931,12 @@ private[parquet] class ParquetRowConverter( override def end(): Unit = { fileConverter.end() - val v = SparkShreddingUtils.rebuild(currentRow.asInstanceOf[InternalRow], variantSchema) + val row = currentRow.asInstanceOf[InternalRow] + val v = if (fieldsToExtract == null) { + SparkShreddingUtils.assembleVariant(row, variantSchema) + } else { + SparkShreddingUtils.assembleVariantStruct(row, variantSchema, fieldsToExtract) + } updater.set(v) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index 7f1b49e737900..64c2a3126ca9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -28,6 +28,7 @@ import org.apache.parquet.schema.Type.Repetition._ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.execution.datasources.VariantMetadata import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -185,6 +186,9 @@ class ParquetToSparkSchemaConverter( } else { convertVariantField(groupColumn) } + case groupColumn: GroupColumnIO if targetType.exists(VariantMetadata.isVariantStruct) => + val col = convertGroupField(groupColumn) + col.copy(sparkType = targetType.get, variantFileType = Some(col)) case groupColumn: GroupColumnIO => convertGroupField(groupColumn, targetType) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala index f38e188ed042c..a83ca78455faa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/SparkShreddingUtils.scala @@ -17,12 +17,23 @@ package org.apache.spark.sql.execution.datasources.parquet +import 
org.apache.parquet.io.ColumnIOFactory +import org.apache.parquet.schema.{Type => ParquetType, Types => ParquetTypes} + import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} -import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.expressions.variant._ +import org.apache.spark.sql.catalyst.expressions.variant.VariantPathParser.PathSegment +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.execution.RowToColumnConverter +import org.apache.spark.sql.execution.datasources.VariantMetadata +import org.apache.spark.sql.execution.vectorized.WritableColumnVector import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ +import org.apache.spark.types.variant.VariantUtil.Type import org.apache.spark.unsafe.types._ case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.ShreddedRow { @@ -45,6 +56,369 @@ case class SparkShreddedRow(row: SpecializedGetters) extends ShreddingUtils.Shre override def numElements(): Int = row.asInstanceOf[ArrayData].numElements() } +// The search result of a `PathSegment` in a `VariantSchema`. +case class SchemaPathSegment( + rawPath: PathSegment, + // Whether this path segment is an object or array extraction. + isObject: Boolean, + // `schema.typedIdx`, if the path exists in the schema (for object extraction, the schema + // should contain an object `typed_value` containing the requested field; similar for array + // extraction). Negative otherwise. + typedIdx: Int, + // For object extraction, it is the index of the desired field in `schema.objectSchema`. If the + // requested field doesn't exist, both `extractionIdx/typedIdx` are set to negative. + // For array extraction, it is the array index. The information is already stored in `rawPath`, + // but accessing a raw int should be more efficient than `rawPath`, which is an `Either`. + extractionIdx: Int) + +// Represent a single field in a variant struct (see `VariantMetadata` for definition), that is, a +// single requested field that the scan should produce by extracting from the variant column. +case class FieldToExtract(path: Array[SchemaPathSegment], reader: ParquetVariantReader) + +// A helper class to cast from scalar `typed_value` into a scalar `dataType`. Need a custom +// expression because it has different error reporting code than `Cast`. +case class ScalarCastHelper( + child: Expression, + dataType: DataType, + castArgs: VariantCastArgs) extends UnaryExpression { + // The expression is only for the internal use of `ScalarReader`, which can guarantee the child + // is not nullable. + assert(!child.nullable) + + // If `cast` is null, it means the cast always fails because the type combination is not allowed. + private val cast = if (Cast.canAnsiCast(child.dataType, dataType)) { + Cast(child, dataType, castArgs.zoneStr, EvalMode.TRY) + } else { + null + } + // Cast the input to string. Only used for reporting an invalid cast. 
+ private val castToString = Cast(child, StringType, castArgs.zoneStr, EvalMode.ANSI) + + override def nullable: Boolean = !castArgs.failOnError + override def withNewChildInternal(newChild: Expression): UnaryExpression = copy(child = newChild) + + // No need to define the interpreted version of `eval`: the codegen must succeed. + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + // Throw an error or do nothing, depending on `castArgs.failOnError`. + val invalidCastCode = if (castArgs.failOnError) { + val castToStringCode = castToString.genCode(ctx) + val typeObj = ctx.addReferenceObj("dataType", dataType) + val cls = classOf[ScalarCastHelper].getName + s""" + ${castToStringCode.code} + $cls.throwInvalidVariantCast(${castToStringCode.value}, $typeObj); + """ + } else { + "" + } + if (cast != null) { + val castCode = cast.genCode(ctx) + val code = code""" + ${castCode.code} + boolean ${ev.isNull} = ${castCode.isNull}; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${castCode.value}; + if (${ev.isNull}) { $invalidCastCode } + """ + ev.copy(code = code) + } else { + val code = code""" + boolean ${ev.isNull} = true; + ${CodeGenerator.javaType(dataType)} ${ev.value} = ${CodeGenerator.defaultValue(dataType)}; + if (${ev.isNull}) { $invalidCastCode } + """ + ev.copy(code = code) + } + } +} + +object ScalarCastHelper { + // A helper function for codegen. The java compiler doesn't allow throwing a `Throwable` in a + // method without `throws` annotation. + def throwInvalidVariantCast(value: UTF8String, dataType: DataType): Any = + throw QueryExecutionErrors.invalidVariantCast(value.toString, dataType) +} + +// The base class to read Parquet variant values into a Spark type. +// For convenience, we also allow creating an instance of the base class itself. None of its +// functions can be used, but it can serve as a container of `targetType` and `castArgs`. +class ParquetVariantReader( + val schema: VariantSchema, val targetType: DataType, val castArgs: VariantCastArgs) { + // Read from a row containing a Parquet variant value (shredded or unshredded) and return a value + // of `targetType`. The row schema is described by `schema`. + // This function throws MALFORMED_VARIANT if the variant is missing. If the variant can be + // legally missing (the only possible situation is struct fields in object `typed_value`), the + // caller should check for it and avoid calling this function if the variant is missing. + def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) { + if (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx)) { + // Both `typed_value` and `value` are null, meaning the variant is missing. + throw QueryExecutionErrors.malformedVariant() + } + val v = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata) + VariantGet.cast(v, targetType, castArgs) + } else { + readFromTyped(row, topLevelMetadata) + } + } + + // Subclasses should override it to produce the read result when `typed_value` is not null. + protected def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any = + throw QueryExecutionErrors.unreachableError() + + // A util function to rebuild the variant in binary format from a Parquet variant value. 
+ protected final def rebuildVariant(row: InternalRow, topLevelMetadata: Array[Byte]): Variant = { + val builder = new VariantBuilder(false) + ShreddingUtils.rebuild(SparkShreddedRow(row), topLevelMetadata, schema, builder) + builder.result() + } + + // A util function to throw error or return null when an invalid cast happens. + protected final def invalidCast(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (castArgs.failOnError) { + throw QueryExecutionErrors.invalidVariantCast( + rebuildVariant(row, topLevelMetadata).toJson(castArgs.zoneId), targetType) + } else { + null + } + } +} + +object ParquetVariantReader { + // Create a reader for `targetType`. If `schema` is null, meaning that the extraction path doesn't + // exist in `typed_value`, it returns an instance of `ParquetVariantReader`. As described in the + // class comment, the reader is only a container of `targetType` and `castArgs` in this case. + def apply(schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs, + isTopLevelUnshredded: Boolean = false): ParquetVariantReader = targetType match { + case _ if schema == null => new ParquetVariantReader(schema, targetType, castArgs) + case s: StructType => new StructReader(schema, s, castArgs) + case a: ArrayType => new ArrayReader(schema, a, castArgs) + case m@MapType(_: StringType, _, _) => new MapReader(schema, m, castArgs) + case v: VariantType => new VariantReader(schema, v, castArgs, isTopLevelUnshredded) + case s: AtomicType => new ScalarReader(schema, s, castArgs) + case _ => + // Type check should have rejected map with non-string type. + throw QueryExecutionErrors.unreachableError(s"Invalid target type: `${targetType.sql}`") + } +} + +// Read Parquet variant values into a Spark struct type. It reads unshredded fields (fields that are +// not in the typed object) from the `value`, and reads the shredded fields from the object +// `typed_value`. +// `value` must not contain any shredded field according to the shredding spec, but this requirement +// is not enforced. If `value` does contain a shredded field, no error will occur, and the field in +// object `typed_value` will be the final result. +private[this] final class StructReader( + schema: VariantSchema, targetType: StructType, castArgs: VariantCastArgs) + extends ParquetVariantReader(schema, targetType, castArgs) { + // For each field in `targetType`, store the index of the field with the same name in object + // `typed_value`, or -1 if it doesn't exist in object `typed_value`. + private[this] val fieldInputIndices: Array[Int] = targetType.fields.map { f => + val inputIdx = if (schema.objectSchemaMap != null) schema.objectSchemaMap.get(f.name) else null + if (inputIdx != null) inputIdx.intValue() else -1 + } + // For each field in `targetType`, store the reader from the corresponding field in object + // `typed_value`, or null if it doesn't exist in object `typed_value`. + private[this] val fieldReaders: Array[ParquetVariantReader] = + targetType.fields.zip(fieldInputIndices).map { case (f, inputIdx) => + if (inputIdx >= 0) { + val fieldSchema = schema.objectSchema(inputIdx).schema + ParquetVariantReader(fieldSchema, f.dataType, castArgs) + } else { + null + } + } + // If all fields in `targetType` can be found in object `typed_value`, then the reader doesn't + // need to read from `value`. 
+ private[this] val needUnshreddedObject: Boolean = fieldInputIndices.exists(_ < 0) + + override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata) + val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length) + val result = new GenericInternalRow(fieldInputIndices.length) + var unshreddedObject: Variant = null + if (needUnshreddedObject && schema.variantIdx >= 0 && !row.isNullAt(schema.variantIdx)) { + unshreddedObject = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata) + if (unshreddedObject.getType != Type.OBJECT) throw QueryExecutionErrors.malformedVariant() + } + val numFields = fieldInputIndices.length + var i = 0 + while (i < numFields) { + val inputIdx = fieldInputIndices(i) + if (inputIdx >= 0) { + // Shredded field must not be null. + if (obj.isNullAt(inputIdx)) throw QueryExecutionErrors.malformedVariant() + val fieldSchema = schema.objectSchema(inputIdx).schema + val fieldInput = obj.getStruct(inputIdx, fieldSchema.numFields) + // Only read from the shredded field if it is not missing. + if ((fieldSchema.typedIdx >= 0 && !fieldInput.isNullAt(fieldSchema.typedIdx)) || + (fieldSchema.variantIdx >= 0 && !fieldInput.isNullAt(fieldSchema.variantIdx))) { + result.update(i, fieldReaders(i).read(fieldInput, topLevelMetadata)) + } + } else if (unshreddedObject != null) { + val fieldName = targetType.fields(i).name + val fieldType = targetType.fields(i).dataType + val unshreddedField = unshreddedObject.getFieldByKey(fieldName) + if (unshreddedField != null) { + result.update(i, VariantGet.cast(unshreddedField, fieldType, castArgs)) + } + } + i += 1 + } + result + } +} + +// Read Parquet variant values into a Spark array type. +private[this] final class ArrayReader( + schema: VariantSchema, targetType: ArrayType, castArgs: VariantCastArgs) + extends ParquetVariantReader(schema, targetType, castArgs) { + private[this] val elementReader = if (schema.arraySchema != null) { + ParquetVariantReader(schema.arraySchema, targetType.elementType, castArgs) + } else { + null + } + + override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (schema.arraySchema == null) return invalidCast(row, topLevelMetadata) + val elementNumFields = schema.arraySchema.numFields + val arr = row.getArray(schema.typedIdx) + val size = arr.numElements() + val result = new Array[Any](size) + var i = 0 + while (i < size) { + // Shredded array element must not be null. + if (arr.isNullAt(i)) throw QueryExecutionErrors.malformedVariant() + result(i) = elementReader.read(arr.getStruct(i, elementNumFields), topLevelMetadata) + i += 1 + } + new GenericArrayData(result) + } +} + +// Read Parquet variant values into a Spark map type with string key type. The input must be object +// for a valid cast. The resulting map contains shredded fields from object `typed_value` and +// unshredded fields from object `value`. +// `value` must not contain any shredded field according to the shredding spec. Unlike +// `StructReader`, this requirement is enforced in `MapReader`. If `value` does contain a shredded +// field, throw a MALFORMED_VARIANT error. The purpose is to avoid duplicate map keys. +private[this] final class MapReader( + schema: VariantSchema, targetType: MapType, castArgs: VariantCastArgs) + extends ParquetVariantReader(schema, targetType, castArgs) { + // Readers that convert each shredded field into the map value type. 
+  private[this] val valueReaders = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f =>
+      ParquetVariantReader(f.schema, targetType.valueType, castArgs)
+    }
+  } else {
+    null
+  }
+  // `UTF8String` representation of shredded field names. Do the `String -> UTF8String` once, so
+  // that `readFromTyped` doesn't need to do it repeatedly.
+  private[this] val shreddedFieldNames = if (schema.objectSchema != null) {
+    schema.objectSchema.map { f => UTF8String.fromString(f.fieldName) }
+  } else {
+    null
+  }
+
+  override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any = {
+    if (schema.objectSchema == null) return invalidCast(row, topLevelMetadata)
+    val obj = row.getStruct(schema.typedIdx, schema.objectSchema.length)
+    val numShreddedFields = valueReaders.length
+    var unshreddedObject: Variant = null
+    if (schema.variantIdx >= 0 && !row.isNullAt(schema.variantIdx)) {
+      unshreddedObject = new Variant(row.getBinary(schema.variantIdx), topLevelMetadata)
+      if (unshreddedObject.getType != Type.OBJECT) throw QueryExecutionErrors.malformedVariant()
+    }
+    val numUnshreddedFields = if (unshreddedObject != null) unshreddedObject.objectSize() else 0
+    var keyArray = new Array[UTF8String](numShreddedFields + numUnshreddedFields)
+    var valueArray = new Array[Any](numShreddedFields + numUnshreddedFields)
+    var mapLength = 0
+    var i = 0
+    while (i < numShreddedFields) {
+      // Shredded field must not be null.
+      if (obj.isNullAt(i)) throw QueryExecutionErrors.malformedVariant()
+      val fieldSchema = schema.objectSchema(i).schema
+      val fieldInput = obj.getStruct(i, fieldSchema.numFields)
+      // Only add the shredded field to the map if it is not missing.
+      if ((fieldSchema.typedIdx >= 0 && !fieldInput.isNullAt(fieldSchema.typedIdx)) ||
+          (fieldSchema.variantIdx >= 0 && !fieldInput.isNullAt(fieldSchema.variantIdx))) {
+        keyArray(mapLength) = shreddedFieldNames(i)
+        valueArray(mapLength) = valueReaders(i).read(fieldInput, topLevelMetadata)
+        mapLength += 1
+      }
+      i += 1
+    }
+    i = 0
+    while (i < numUnshreddedFields) {
+      val field = unshreddedObject.getFieldAtIndex(i)
+      if (schema.objectSchemaMap.containsKey(field.key)) {
+        throw QueryExecutionErrors.malformedVariant()
+      }
+      keyArray(mapLength) = UTF8String.fromString(field.key)
+      valueArray(mapLength) = VariantGet.cast(field.value, targetType.valueType, castArgs)
+      mapLength += 1
+      i += 1
+    }
+    // Need to shrink the arrays if there are missing shredded fields.
+    if (mapLength < keyArray.length) {
+      keyArray = keyArray.slice(0, mapLength)
+      valueArray = valueArray.slice(0, mapLength)
+    }
+    ArrayBasedMapData(keyArray, valueArray)
+  }
+}
+
+// Read Parquet variant values into a Spark variant type (the binary format).
+private[this] final class VariantReader(
+    schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs,
+    // An optional optimization: the user can set it to true if the Parquet variant column is
+    // unshredded and the extraction path is empty. We are not required to do anything special,
+    // but we can avoid rebuilding the variant as an optimization.
+ private[this] val isTopLevelUnshredded: Boolean) + extends ParquetVariantReader(schema, targetType, castArgs) { + override def read(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (isTopLevelUnshredded) { + if (row.isNullAt(schema.variantIdx)) throw QueryExecutionErrors.malformedVariant() + return new VariantVal(row.getBinary(schema.variantIdx), topLevelMetadata) + } + val v = rebuildVariant(row, topLevelMetadata) + new VariantVal(v.getValue, v.getMetadata) + } +} + +// Read Parquet variant values into a Spark scalar type. When `typed_value` is not null but not a +// scalar, all other target types should return an invalid cast, but only the string target type can +// still build a string from array/object `typed_value`. For scalar `typed_value`, it depends on +// `ScalarCastHelper` to perform the cast. +// According to the shredding spec, scalar `typed_value` and `value` must not be non-null at the +// same time. The requirement is not enforced in this reader. If they are both non-null, no error +// will occur, and the reader will read from `typed_value`. +private[this] final class ScalarReader( + schema: VariantSchema, targetType: DataType, castArgs: VariantCastArgs) + extends ParquetVariantReader(schema, targetType, castArgs) { + private[this] val castProject = if (schema.scalarSchema != null) { + val scalarType = SparkShreddingUtils.scalarSchemaToSparkType(schema.scalarSchema) + // Read the cast input from ordinal `schema.typedIdx` in the input row. The cast input is never + // null, because `readFromTyped` is only called when `typed_value` is not null. + val input = BoundReference(schema.typedIdx, scalarType, nullable = false) + MutableProjection.create(Seq(ScalarCastHelper(input, targetType, castArgs))) + } else { + null + } + + override def readFromTyped(row: InternalRow, topLevelMetadata: Array[Byte]): Any = { + if (castProject == null) { + return if (targetType.isInstanceOf[StringType]) { + UTF8String.fromString(rebuildVariant(row, topLevelMetadata).toJson(castArgs.zoneId)) + } else { + invalidCast(row, topLevelMetadata) + } + } + val result = castProject(row) + if (result.isNullAt(0)) null else result.get(0, targetType) + } +} + case object SparkShreddingUtils { val VariantValueFieldName = "value"; val TypedValueFieldName = "typed_value"; @@ -126,6 +500,11 @@ case object SparkShreddingUtils { var objectSchema: Array[VariantSchema.ObjectField] = null var arraySchema: VariantSchema = null + // The struct must not be empty or contain duplicate field names. The latter is enforced in the + // loop below (`if (typedIdx != -1)` and other similar checks). + if (schema.fields.isEmpty) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } schema.fields.zipWithIndex.foreach { case (f, i) => f.name match { case TypedValueFieldName => @@ -135,8 +514,11 @@ case object SparkShreddingUtils { typedIdx = i f.dataType match { case StructType(fields) => - objectSchema = - new Array[VariantSchema.ObjectField](fields.length) + // The struct must not be empty or contain duplicate field names. + if (fields.isEmpty || fields.map(_.name).distinct.length != fields.length) { + throw QueryCompilationErrors.invalidVariantShreddingSchema(schema) + } + objectSchema = new Array[VariantSchema.ObjectField](fields.length) fields.zipWithIndex.foreach { case (field, fieldIdx) => field.dataType match { case s: StructType => @@ -188,6 +570,32 @@ case object SparkShreddingUtils { scalarSchema, objectSchema, arraySchema) } + // Convert a scalar variant schema into a Spark scalar type. 
+ def scalarSchemaToSparkType(scalar: VariantSchema.ScalarType): DataType = scalar match { + case _: VariantSchema.StringType => StringType + case it: VariantSchema.IntegralType => it.size match { + case VariantSchema.IntegralSize.BYTE => ByteType + case VariantSchema.IntegralSize.SHORT => ShortType + case VariantSchema.IntegralSize.INT => IntegerType + case VariantSchema.IntegralSize.LONG => LongType + } + case _: VariantSchema.FloatType => FloatType + case _: VariantSchema.DoubleType => DoubleType + case _: VariantSchema.BooleanType => BooleanType + case _: VariantSchema.BinaryType => BinaryType + case dt: VariantSchema.DecimalType => DecimalType(dt.precision, dt.scale) + case _: VariantSchema.DateType => DateType + case _: VariantSchema.TimestampType => TimestampType + case _: VariantSchema.TimestampNTZType => TimestampNTZType + } + + // Convert a Parquet type into a Spark data type. + def parquetTypeToSparkType(parquetType: ParquetType): DataType = { + val messageType = ParquetTypes.buildMessage().addField(parquetType).named("foo") + val column = new ColumnIOFactory().getColumnIO(messageType) + new ParquetToSparkSchemaConverter().convertField(column.getChild(0)).sparkType + } + class SparkShreddedResult(schema: VariantSchema) extends VariantShreddingWriter.ShreddedResult { // Result is stored as an InternalRow. val row = new GenericInternalRow(schema.numFields) @@ -243,8 +651,187 @@ case object SparkShreddingUtils { .row } - def rebuild(row: InternalRow, schema: VariantSchema): VariantVal = { + // Return a list of fields to extract. `targetType` must be either variant or variant struct. + // If it is variant, return null because the target is the full variant and there is no field to + // extract. If it is variant struct, return a list of fields matching the variant struct fields. + def getFieldsToExtract(targetType: DataType, inputSchema: VariantSchema): Array[FieldToExtract] = + targetType match { + case _: VariantType => null + case s: StructType if VariantMetadata.isVariantStruct(s) => + s.fields.map { f => + val metadata = VariantMetadata.fromMetadata(f.metadata) + val rawPath = metadata.parsedPath() + val schemaPath = new Array[SchemaPathSegment](rawPath.length) + var schema = inputSchema + // Search `rawPath` in `schema` to produce `schemaPath`. If a raw path segment cannot be + // found at a certain level of the file type, then `typedIdx` will be -1 starting from + // this position, and the final `schema` will be null. 
+ for (i <- rawPath.indices) { + val isObject = rawPath(i).isLeft + var typedIdx = -1 + var extractionIdx = -1 + rawPath(i) match { + case scala.util.Left(key) if schema != null && schema.objectSchema != null => + val fieldIdx = schema.objectSchemaMap.get(key) + if (fieldIdx != null) { + typedIdx = schema.typedIdx + extractionIdx = fieldIdx + schema = schema.objectSchema(fieldIdx).schema + } else { + schema = null + } + case scala.util.Right(index) if schema != null && schema.arraySchema != null => + typedIdx = schema.typedIdx + extractionIdx = index + schema = schema.arraySchema + case _ => + schema = null + } + schemaPath(i) = SchemaPathSegment(rawPath(i), isObject, typedIdx, extractionIdx) + } + val reader = ParquetVariantReader(schema, f.dataType, VariantCastArgs( + metadata.failOnError, + Some(metadata.timeZoneId), + DateTimeUtils.getZoneId(metadata.timeZoneId)), + isTopLevelUnshredded = schemaPath.isEmpty && inputSchema.isUnshredded) + FieldToExtract(schemaPath, reader) + } + case _ => + throw QueryExecutionErrors.unreachableError(s"Invalid target type: `${targetType.sql}`") + } + + // Extract a single variant struct field from a Parquet variant value. It steps into `inputRow` + // according to the variant extraction path, and read the extracted value as the target type. + private def extractField( + inputRow: InternalRow, + topLevelMetadata: Array[Byte], + inputSchema: VariantSchema, + pathList: Array[SchemaPathSegment], + reader: ParquetVariantReader): Any = { + var pathIdx = 0 + val pathLen = pathList.length + var row = inputRow + var schema = inputSchema + while (pathIdx < pathLen) { + val path = pathList(pathIdx) + + if (path.typedIdx < 0) { + // The extraction doesn't exist in `typed_value`. Try to extract the remaining part of the + // path in `value`. + val variantIdx = schema.variantIdx + if (variantIdx < 0 || row.isNullAt(variantIdx)) return null + var v = new Variant(row.getBinary(variantIdx), topLevelMetadata) + while (pathIdx < pathLen) { + v = pathList(pathIdx).rawPath match { + case scala.util.Left(key) if v.getType == Type.OBJECT => v.getFieldByKey(key) + case scala.util.Right(index) if v.getType == Type.ARRAY => v.getElementAtIndex(index) + case _ => null + } + if (v == null) return null + pathIdx += 1 + } + return VariantGet.cast(v, reader.targetType, reader.castArgs) + } + + if (row.isNullAt(path.typedIdx)) return null + if (path.isObject) { + val obj = row.getStruct(path.typedIdx, schema.objectSchema.length) + // Object field must not be null. + if (obj.isNullAt(path.extractionIdx)) throw QueryExecutionErrors.malformedVariant() + schema = schema.objectSchema(path.extractionIdx).schema + row = obj.getStruct(path.extractionIdx, schema.numFields) + // Return null if the field is missing. + if ((schema.typedIdx < 0 || row.isNullAt(schema.typedIdx)) && + (schema.variantIdx < 0 || row.isNullAt(schema.variantIdx))) { + return null + } + } else { + val arr = row.getArray(path.typedIdx) + // Return null if the extraction index is out of bound. + if (path.extractionIdx >= arr.numElements()) return null + // Array element must not be null. + if (arr.isNullAt(path.extractionIdx)) throw QueryExecutionErrors.malformedVariant() + schema = schema.arraySchema + row = arr.getStruct(path.extractionIdx, schema.numFields) + } + pathIdx += 1 + } + reader.read(row, topLevelMetadata) + } + + // Assemble a variant (binary format) from a Parquet variant value. 
+ def assembleVariant(row: InternalRow, schema: VariantSchema): VariantVal = { val v = ShreddingUtils.rebuild(SparkShreddedRow(row), schema) new VariantVal(v.getValue, v.getMetadata) } + + // Assemble a variant struct, in which each field is extracted from the Parquet variant value. + def assembleVariantStruct( + inputRow: InternalRow, + schema: VariantSchema, + fields: Array[FieldToExtract]): InternalRow = { + if (inputRow.isNullAt(schema.topLevelMetadataIdx)) { + throw QueryExecutionErrors.malformedVariant() + } + val topLevelMetadata = inputRow.getBinary(schema.topLevelMetadataIdx) + val numFields = fields.length + val resultRow = new GenericInternalRow(numFields) + var fieldIdx = 0 + while (fieldIdx < numFields) { + resultRow.update(fieldIdx, extractField(inputRow, topLevelMetadata, schema, + fields(fieldIdx).path, fields(fieldIdx).reader)) + fieldIdx += 1 + } + resultRow + } + + // Assemble a batch of variant (binary format) from a batch of Parquet variant values. + def assembleVariantBatch( + input: WritableColumnVector, + output: WritableColumnVector, + schema: VariantSchema): Unit = { + val numRows = input.getElementsAppended + output.reset() + output.reserve(numRows) + val valueChild = output.getChild(0) + val metadataChild = output.getChild(1) + var i = 0 + while (i < numRows) { + if (input.isNullAt(i)) { + output.appendStruct(true) + } else { + output.appendStruct(false) + val v = SparkShreddingUtils.assembleVariant(input.getStruct(i), schema) + valueChild.appendByteArray(v.getValue, 0, v.getValue.length) + metadataChild.appendByteArray(v.getMetadata, 0, v.getMetadata.length) + } + i += 1 + } + } + + // Assemble a batch of variant struct from a batch of Parquet variant values. + def assembleVariantStructBatch( + input: WritableColumnVector, + output: WritableColumnVector, + schema: VariantSchema, + fields: Array[FieldToExtract]): Unit = { + val numRows = input.getElementsAppended + output.reset() + output.reserve(numRows) + val converter = new RowToColumnConverter(StructType(Array(StructField("", output.dataType())))) + val converterVectors = Array(output) + val converterRow = new GenericInternalRow(1) + output.reset() + output.reserve(input.getElementsAppended) + var i = 0 + while (i < numRows) { + if (input.isNullAt(i)) { + converterRow.update(0, null) + } else { + converterRow.update(0, assembleVariantStruct(input.getStruct(i), schema, fields)) + } + converter.convert(converterRow, converterVectors) + i += 1 + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala index 5d5c441052558..b6623bb57a716 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/VariantShreddingSuite.scala @@ -22,13 +22,21 @@ import java.sql.{Date, Timestamp} import java.time.LocalDateTime import org.apache.spark.SparkThrowable +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.variant.VariantExpressionEvalUtils import org.apache.spark.sql.execution.datasources.parquet.{ParquetTest, SparkShreddingUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.types.variant._ +import org.apache.spark.unsafe.types.{UTF8String, VariantVal} class VariantShreddingSuite extends QueryTest with SharedSparkSession with ParquetTest { + def parseJson(s: String): VariantVal = 
{ + val v = VariantBuilder.parseJson(s, false) + new VariantVal(v.getValue, v.getMetadata) + } + // Make a variant value binary by parsing a JSON string. def value(s: String): Array[Byte] = VariantBuilder.parseJson(s, false).getValue @@ -53,9 +61,21 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu def writeSchema(schema: DataType): StructType = StructType(Array(StructField("v", SparkShreddingUtils.variantShreddingSchema(schema)))) + def withPushConfigs(pushConfigs: Seq[Boolean] = Seq(true, false))(fn: => Unit): Unit = { + for (push <- pushConfigs) { + withSQLConf(SQLConf.PUSH_VARIANT_INTO_SCAN.key -> push.toString) { + fn + } + } + } + + def isPushEnabled: Boolean = SQLConf.get.getConf(SQLConf.PUSH_VARIANT_INTO_SCAN) + def testWithTempPath(name: String)(block: File => Unit): Unit = test(name) { - withTempPath { path => - block(path) + withPushConfigs() { + withTempPath { path => + block(path) + } } } @@ -63,6 +83,9 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu spark.createDataFrame(spark.sparkContext.parallelize(rows.map(Row(_)), numSlices = 1), schema) .write.mode("overwrite").parquet(path.getAbsolutePath) + def writeRows(path: File, schema: String, rows: Row*): Unit = + writeRows(path, StructType.fromDDL(schema), rows: _*) + def read(path: File): DataFrame = spark.read.schema("v variant").parquet(path.getAbsolutePath) @@ -150,10 +173,13 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu // Top-level variant must not be missing. writeRows(path, writeSchema(IntegerType), Row(metadata(Nil), null, null)) checkException(path, "v", "MALFORMED_VARIANT") + // Array-element variant must not be missing. writeRows(path, writeSchema(ArrayType(IntegerType)), Row(metadata(Nil), null, Array(Row(null, null)))) checkException(path, "v", "MALFORMED_VARIANT") + checkException(path, "variant_get(v, '$[0]')", "MALFORMED_VARIANT") + // Shredded field must not be null. // Construct the schema manually, because SparkShreddingUtils.variantShreddingSchema will make // `a` non-nullable, which would prevent us from writing the file. @@ -164,12 +190,163 @@ class VariantShreddingSuite extends QueryTest with SharedSparkSession with Parqu StructField("a", StructType(Seq( StructField("value", BinaryType), StructField("typed_value", BinaryType)))))))))))) - writeRows(path, schema, - Row(metadata(Seq("a")), null, Row(null))) + writeRows(path, schema, Row(metadata(Seq("a")), null, Row(null))) checkException(path, "v", "MALFORMED_VARIANT") + checkException(path, "variant_get(v, '$.a')", "MALFORMED_VARIANT") + // `value` must not contain any shredded field. writeRows(path, writeSchema(StructType.fromDDL("a int")), Row(metadata(Seq("a")), value("""{"a": 1}"""), Row(Row(null, null)))) checkException(path, "v", "MALFORMED_VARIANT") + checkException(path, "cast(v as map)", "MALFORMED_VARIANT") + if (isPushEnabled) { + checkExpr(path, "cast(v as struct)", Row(null)) + checkExpr(path, "variant_get(v, '$.a', 'int')", null) + } else { + checkException(path, "cast(v as struct)", "MALFORMED_VARIANT") + checkException(path, "variant_get(v, '$.a', 'int')", "MALFORMED_VARIANT") + } + + // Scalar reader reads from `typed_value` if both `value` and `typed_value` are not null. + // Cast from `value` succeeds, cast from `typed_value` fails. 
+ writeRows(path, "v struct", + Row(metadata(Nil), value("1"), "invalid")) + checkException(path, "cast(v as int)", "INVALID_VARIANT_CAST") + checkExpr(path, "try_cast(v as int)", null) + + // Cast from `value` fails, cast from `typed_value` succeeds. + writeRows(path, "v struct", + Row(metadata(Nil), value("\"invalid\""), "1")) + checkExpr(path, "cast(v as int)", 1) + checkExpr(path, "try_cast(v as int)", 1) + } + + testWithTempPath("extract from shredded object") { path => + val keys1 = Seq("a", "b", "c", "d") + val keys2 = Seq("a", "b", "c", "e", "f") + writeRows(path, "v struct, b struct," + + "c struct>>", + // {"a":1,"b":"2","c":3.3,"d":4.4}, d is in the left over value. + Row(metadata(keys1), shreddedValue("""{"d": 4.4}""", keys1), + Row(Row(null, 1), Row(value("\"2\"")), Row(Decimal("3.3")))), + // {"a":5.4,"b":-6,"e":{"f":[true]}}, e is in the left over value. + Row(metadata(keys2), shreddedValue("""{"e": {"f": [true]}}""", keys2), + Row(Row(value("5.4"), null), Row(value("-6")), Row(null))), + // [{"a":1}], the unshredded array at the top-level is put into `value` as a whole. + Row(metadata(Seq("a")), value("""[{"a": 1}]"""), null)) + + checkAnswer(read(path).selectExpr("variant_get(v, '$.a', 'int')", + "variant_get(v, '$.b', 'long')", "variant_get(v, '$.c', 'double')", + "variant_get(v, '$.d', 'decimal(9, 4)')"), + Seq(Row(1, 2L, 3.3, BigDecimal("4.4")), Row(5, -6L, null, null), Row(null, null, null, null))) + checkExpr(path, "variant_get(v, '$.e.f[0]', 'boolean')", null, true, null) + checkExpr(path, "variant_get(v, '$[0].a', 'boolean')", null, null, true) + checkExpr(path, "try_cast(v as struct)", + Row(1.0F, null), Row(5.4F, parseJson("""{"f": [true]}""")), null) + + // String "2" cannot be cast into boolean. + checkException(path, "variant_get(v, '$.b', 'boolean')", "INVALID_VARIANT_CAST") + // Decimal cannot be cast into date. + checkException(path, "variant_get(v, '$.c', 'date')", "INVALID_VARIANT_CAST") + // The value of `c` doesn't fit into `decimal(1, 1)`. + checkException(path, "variant_get(v, '$.c', 'decimal(1, 1)')", "INVALID_VARIANT_CAST") + checkExpr(path, "try_variant_get(v, '$.b', 'boolean')", null, true, null) + // Scalar cannot be cast into struct. + checkException(path, "variant_get(v, '$.a', 'struct')", "INVALID_VARIANT_CAST") + checkExpr(path, "try_variant_get(v, '$.a', 'struct')", null, null, null) + + checkExpr(path, "try_cast(v as map)", + Map("a" -> 1.0, "b" -> 2.0, "c" -> 3.3, "d" -> 4.4), + Map("a" -> 5.4, "b" -> -6.0, "e" -> null), null) + checkExpr(path, "try_cast(v as array)", null, null, Seq("""{"a":1}""")) + + val strings = Seq("""{"a":1,"b":"2","c":3.3,"d":4.4}""", + """{"a":5.4,"b":-6,"e":{"f":[true]}}""", """[{"a":1}]""") + checkExpr(path, "cast(v as string)", strings: _*) + checkExpr(path, "v", + VariantExpressionEvalUtils.castToVariant( + InternalRow(1, UTF8String.fromString("2"), Decimal("3.3000000000"), Decimal("4.4")), + StructType.fromDDL("a int, b string, c decimal(20, 10), d decimal(2, 1)") + ), + parseJson(strings(1)), + parseJson(strings(2)) + ) + } + + testWithTempPath("extract from shredded array") { path => + val keys = Seq("a", "b") + writeRows(path, "v struct>>>>", + // [{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}], b is in the left over value. 
+ Row(metadata(keys), null, Array( + Row(null, Row(Row(null, "2000-01-01"))), + Row(shreddedValue("""{"b": [7]}""", keys), Row(Row(null, "1000-01-01"))))), + // [null,{"a":null},{"a":"null"},{}] + Row(metadata(keys), null, Array( + Row(value("null"), null), + Row(null, Row(Row(value("null"), null))), + Row(null, Row(Row(null, "null"))), + Row(null, Row(Row(null, null)))))) + + val date1 = Date.valueOf("2000-01-01") + val date2 = Date.valueOf("1000-01-01") + checkExpr(path, "variant_get(v, '$[0].a', 'date')", date1, null) + // try_cast succeeds. + checkExpr(path, "try_variant_get(v, '$[1].a', 'date')", date2, null) + // The first array returns null because of out-of-bound index. + // The second array returns "null". + checkExpr(path, "variant_get(v, '$[2].a', 'string')", null, "null") + // Return null because of invalid cast. + checkExpr(path, "try_variant_get(v, '$[1].a', 'int')", null, null) + + checkExpr(path, "variant_get(v, '$[0].b[0]', 'int')", null, null) + checkExpr(path, "variant_get(v, '$[1].b[0]', 'int')", 7, null) + // Validate timestamp-related casts uses the session time zone correctly. + Seq("Etc/UTC", "America/Los_Angeles").foreach { tz => + withSQLConf(SQLConf.SESSION_LOCAL_TIMEZONE.key -> tz) { + val expected = sql("select timestamp'1000-01-01', timestamp_ntz'1000-01-01'").head() + checkAnswer(read(path).selectExpr("variant_get(v, '$[1].a', 'timestamp')", + "variant_get(v, '$[1].a', 'timestamp_ntz')"), Seq(expected, Row(null, null))) + } + } + checkException(path, "variant_get(v, '$[0]', 'int')", "INVALID_VARIANT_CAST") + // An out-of-bound array access produces null. It never causes an invalid cast. + checkExpr(path, "variant_get(v, '$[4]', 'int')", null, null) + + checkExpr(path, "cast(v as array>>)", + Seq(Row("2000-01-01", null), Row("1000-01-01", Seq(7))), + Seq(null, Row(null, null), Row("null", null), Row(null, null))) + checkExpr(path, "cast(v as array>)", + Seq(Map("a" -> "2000-01-01"), Map("a" -> "1000-01-01", "b" -> "[7]")), + Seq(null, Map("a" -> null), Map("a" -> "null"), Map())) + checkExpr(path, "try_cast(v as array>)", + Seq(Map("a" -> date1), Map("a" -> date2, "b" -> null)), + Seq(null, Map("a" -> null), Map("a" -> null), Map())) + + val strings = Seq("""[{"a":"2000-01-01"},{"a":"1000-01-01","b":[7]}]""", + """[null,{"a":null},{"a":"null"},{}]""") + checkExpr(path, "cast(v as string)", strings: _*) + checkExpr(path, "v", strings.map(parseJson): _*) + } + + testWithTempPath("missing fields") { path => + writeRows(path, "v struct, b struct>>", + Row(metadata(Nil), Row(Row(null, null), Row(null))), + Row(metadata(Nil), Row(Row(value("null"), null), Row(null))), + Row(metadata(Nil), Row(Row(null, 1), Row(null))), + Row(metadata(Nil), Row(Row(null, null), Row(2))), + Row(metadata(Nil), Row(Row(value("null"), null), Row(2))), + Row(metadata(Nil), Row(Row(null, 3), Row(4)))) + + val strings = Seq("{}", """{"a":null}""", """{"a":1}""", """{"b":2}""", """{"a":null,"b":2}""", + """{"a":3,"b":4}""") + checkExpr(path, "cast(v as string)", strings: _*) + checkExpr(path, "v", strings.map(parseJson): _*) + + checkExpr(path, "variant_get(v, '$.a', 'string')", null, null, "1", null, null, "3") + checkExpr(path, "variant_get(v, '$.a')", null, parseJson("null"), parseJson("1"), null, + parseJson("null"), parseJson("3")) } }
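
Example usage (not part of the patch): a minimal sketch of the read path this change optimizes,
assuming a local SparkSession and a writable temporary directory. With the
`PUSH_VARIANT_INTO_SCAN` conf enabled (the same flag exercised by the suite above), the requested
`variant_get` field should be turned into a variant struct, so the Parquet reader extracts `$.a`
directly instead of rebuilding the full variant.

  import org.apache.spark.sql.SparkSession
  import org.apache.spark.sql.internal.SQLConf

  val spark = SparkSession.builder().master("local[*]").getOrCreate()
  // Hypothetical output location; adjust as needed.
  val path = "/tmp/variant_struct_demo"

  // Write a Parquet file with a single variant column `v`.
  spark.sql("""SELECT parse_json('{"a": 1, "b": "x"}') AS v""")
    .write.mode("overwrite").parquet(path)

  // Push the extraction into the scan so the reader assembles a variant struct
  // containing only the requested field.
  spark.conf.set(SQLConf.PUSH_VARIANT_INTO_SCAN.key, "true")
  spark.read.parquet(path)
    .selectExpr("variant_get(v, '$.a', 'int') AS a")
    .show()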