From 369c40cdb24a0c62453dff1013e92ac05840c93a Mon Sep 17 00:00:00 2001
From: panbingkun
Date: Wed, 23 Oct 2024 10:28:27 +0200
Subject: [PATCH] [SPARK-50067][SQL] Codegen Support for SchemaOfCsv (by Invoke & RuntimeReplaceable)

### What changes were proposed in this pull request?
This PR aims to add codegen support for `schema_of_csv` by replacing `CodegenFallback` with `RuntimeReplaceable` and `Invoke`.

### Why are the changes needed?
- Improve codegen coverage.
- Simplify the code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA and existing UTs (e.g. CsvFunctionsSuite#`*schema_of_csv*`).

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48595 from panbingkun/SPARK-50067.

Authored-by: panbingkun
Signed-off-by: Max Gekk
---
 .../csv/CsvExpressionEvalUtils.scala          | 53 +++++++++++++++++++
 .../catalyst/expressions/csvExpressions.scala | 36 +++++--------
 .../json/JsonExpressionEvalUtils.scala        |  4 +-
 .../function_schema_of_csv.explain            |  2 +-
 4 files changed, 69 insertions(+), 26 deletions(-)
 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
new file mode 100644
index 0000000000000..abd0703fa7d70
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csv/CsvExpressionEvalUtils.scala
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.catalyst.expressions.csv
+
+import com.univocity.parsers.csv.CsvParser
+
+import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
+import org.apache.spark.sql.types.{DataType, NullType, StructType}
+import org.apache.spark.unsafe.types.UTF8String
+
+case class SchemaOfCsvEvaluator(options: Map[String, String]) {
+
+  @transient
+  private lazy val csvOptions: CSVOptions = {
+    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
+    // the unicode specification, which should not appear in Java's strings.
+    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
+    // scalastyle:off nonascii
+    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
+    // scalastyle:on nonascii
+    new CSVOptions(exprOptions, true, "UTC")
+  }
+
+  @transient
+  private lazy val csvParser: CsvParser = new CsvParser(csvOptions.asParserSettings)
+
+  @transient
+  private lazy val csvInferSchema = new CSVInferSchema(csvOptions)
+
+  final def evaluate(csv: UTF8String): Any = {
+    val row = csvParser.parseLine(csv.toString)
+    assert(row != null, "Parsed CSV record should not be null.")
+    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
+    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
+    val fieldTypes = csvInferSchema.inferRowType(startType, row)
+    val st = StructType(csvInferSchema.toStructFields(fieldTypes, header))
+    UTF8String.fromString(st.sql)
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
index cdad9938c5d03..5393b2bde93b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/csvExpressions.scala
@@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.io.CharArrayWriter
 
-import com.univocity.parsers.csv.CsvParser
-
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
 import org.apache.spark.sql.catalyst.csv._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.catalyst.util.TypeUtils._
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
@@ -170,7 +170,7 @@ case class CsvToStructs(
 case class SchemaOfCsv(
     child: Expression,
     options: Map[String, String])
-  extends UnaryExpression with CodegenFallback with QueryErrorsBase {
+  extends UnaryExpression with RuntimeReplaceable with QueryErrorsBase {
 
   def this(child: Expression) = this(child, Map.empty[String, String])
 
@@ -202,30 +202,20 @@ case class SchemaOfCsv(
     }
   }
 
-  override def eval(v: InternalRow): Any = {
-    // 'lineSep' is a plan-wise option so we set a noncharacter, according to
-    // the unicode specification, which should not appear in Java's strings.
-    // See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
-    // scalastyle:off nonascii
-    val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
-    // scalastyle:on nonascii
-    val parsedOptions = new CSVOptions(exprOptions, true, "UTC")
-    val parser = new CsvParser(parsedOptions.asParserSettings)
-    val row = parser.parseLine(csv.toString)
-    assert(row != null, "Parsed CSV record should not be null.")
-
-    val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
-    val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
-    val inferSchema = new CSVInferSchema(parsedOptions)
-    val fieldTypes = inferSchema.inferRowType(startType, row)
-    val st = StructType(inferSchema.toStructFields(fieldTypes, header))
-    UTF8String.fromString(st.sql)
-  }
-
   override def prettyName: String = "schema_of_csv"
 
   override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv =
     copy(child = newChild)
+
+  @transient
+  private lazy val evaluator: SchemaOfCsvEvaluator = SchemaOfCsvEvaluator(options)
+
+  override def replacement: Expression = Invoke(
+    Literal.create(evaluator, ObjectType(classOf[SchemaOfCsvEvaluator])),
+    "evaluate",
+    dataType,
+    Seq(child),
+    Seq(child.dataType))
 }
 
 /**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
index dd7d318f430b6..8d8ecc7805367 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/json/JsonExpressionEvalUtils.scala
@@ -64,7 +64,7 @@ case class JsonToStructsEvaluator(
     nullableSchema: DataType,
     nameOfCorruptRecord: String,
     timeZoneId: Option[String],
-    variantAllowDuplicateKeys: Boolean) extends Serializable {
+    variantAllowDuplicateKeys: Boolean) {
 
   // This converts parsed rows to the desired output by the given schema.
   @transient
@@ -117,7 +117,7 @@ case class JsonToStructsEvaluator(
 case class StructsToJsonEvaluator(
     options: Map[String, String],
     inputSchema: DataType,
-    timeZoneId: Option[String]) extends Serializable {
+    timeZoneId: Option[String]) {
 
   @transient
   private lazy val writer = new CharArrayWriter()
diff --git a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain
index ecd181a4292de..23cd52a6e5663 100644
--- a/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain
+++ b/sql/connect/common/src/test/resources/query-tests/explain-results/function_schema_of_csv.explain
@@ -1,2 +1,2 @@
-Project [schema_of_csv(1|abc, (sep,|)) AS schema_of_csv(1|abc)#0]
+Project [invoke(SchemaOfCsvEvaluator(Map(sep -> |)).evaluate(1|abc)) AS schema_of_csv(1|abc)#0]
 +- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]
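As a quick end-to-end sanity check, not part of this patch: a minimal sketch, assuming a local Spark build that contains this change, showing that the user-visible result of `schema_of_csv` stays the same after the `CodegenFallback` -> `RuntimeReplaceable`/`Invoke` rewrite. The object name `SchemaOfCsvSmokeCheck` is made up for illustration; only the public `schema_of_csv` API from `org.apache.spark.sql.functions` is used.

```scala
import scala.jdk.CollectionConverters._

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{lit, schema_of_csv}

object SchemaOfCsvSmokeCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("schema_of_csv-smoke")
      .getOrCreate()

    // schema_of_csv infers a DDL-style schema string from a literal CSV row.
    // With this patch the expression is replaced by an Invoke on
    // SchemaOfCsvEvaluator, but the output should be identical.
    val df = spark.range(1).select(
      schema_of_csv(lit("1|abc"), Map("sep" -> "|").asJava).as("csv_schema"))

    // Expected single row: STRUCT<_c0: INT, _c1: STRING>
    df.show(truncate = false)
    spark.stop()
  }
}
```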