Skip to content

Commit

Permalink
[SPARK-50067][SQL] Codegen Support for SchemaOfCsv(by Invoke & Runtim…
Browse files Browse the repository at this point in the history
…eReplaceable)

### What changes were proposed in this pull request?
The pr aims to add `Codegen` Support for `schema_of_csv`.

### Why are the changes needed?
- improve codegen coverage.
- simplified code.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Pass GA & Existed UT (eg: CsvFunctionsSuite#`*schema_of_csv*`)

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#48595 from panbingkun/SPARK-50067.

Authored-by: panbingkun <[email protected]>
Signed-off-by: Max Gekk <[email protected]>
  • Loading branch information
panbingkun authored and MaxGekk committed Oct 23, 2024
1 parent 2cb7a16 commit 369c40c
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 26 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions.csv

import com.univocity.parsers.csv.CsvParser

import org.apache.spark.sql.catalyst.csv.{CSVInferSchema, CSVOptions}
import org.apache.spark.sql.types.{DataType, NullType, StructType}
import org.apache.spark.unsafe.types.UTF8String

case class SchemaOfCsvEvaluator(options: Map[String, String]) {

@transient
private lazy val csvOptions: CSVOptions = {
// 'lineSep' is a plan-wise option so we set a noncharacter, according to
// the unicode specification, which should not appear in Java's strings.
// See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
// scalastyle:off nonascii
val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
// scalastyle:on nonascii
new CSVOptions(exprOptions, true, "UTC")
}

@transient
private lazy val csvParser: CsvParser = new CsvParser(csvOptions.asParserSettings)

@transient
private lazy val csvInferSchema = new CSVInferSchema(csvOptions)

final def evaluate(csv: UTF8String): Any = {
val row = csvParser.parseLine(csv.toString)
assert(row != null, "Parsed CSV record should not be null.")
val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val fieldTypes = csvInferSchema.inferRowType(startType, row)
val st = StructType(csvInferSchema.toStructFields(fieldTypes, header))
UTF8String.fromString(st.sql)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,14 @@ package org.apache.spark.sql.catalyst.expressions

import java.io.CharArrayWriter

import com.univocity.parsers.csv.CsvParser

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{DataTypeMismatch, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.csv._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
import org.apache.spark.sql.catalyst.expressions.csv.SchemaOfCsvEvaluator
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.TypeUtils._
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryErrorsBase}
Expand Down Expand Up @@ -170,7 +170,7 @@ case class CsvToStructs(
case class SchemaOfCsv(
child: Expression,
options: Map[String, String])
extends UnaryExpression with CodegenFallback with QueryErrorsBase {
extends UnaryExpression with RuntimeReplaceable with QueryErrorsBase {

def this(child: Expression) = this(child, Map.empty[String, String])

Expand Down Expand Up @@ -202,30 +202,20 @@ case class SchemaOfCsv(
}
}

override def eval(v: InternalRow): Any = {
// 'lineSep' is a plan-wise option so we set a noncharacter, according to
// the unicode specification, which should not appear in Java's strings.
// See also SPARK-38955 and https://www.unicode.org/charts/PDF/UFFF0.pdf.
// scalastyle:off nonascii
val exprOptions = options ++ Map("lineSep" -> '\uFFFF'.toString)
// scalastyle:on nonascii
val parsedOptions = new CSVOptions(exprOptions, true, "UTC")
val parser = new CsvParser(parsedOptions.asParserSettings)
val row = parser.parseLine(csv.toString)
assert(row != null, "Parsed CSV record should not be null.")

val header = row.zipWithIndex.map { case (_, index) => s"_c$index" }
val startType: Array[DataType] = Array.fill[DataType](header.length)(NullType)
val inferSchema = new CSVInferSchema(parsedOptions)
val fieldTypes = inferSchema.inferRowType(startType, row)
val st = StructType(inferSchema.toStructFields(fieldTypes, header))
UTF8String.fromString(st.sql)
}

override def prettyName: String = "schema_of_csv"

override protected def withNewChildInternal(newChild: Expression): SchemaOfCsv =
copy(child = newChild)

@transient
private lazy val evaluator: SchemaOfCsvEvaluator = SchemaOfCsvEvaluator(options)

override def replacement: Expression = Invoke(
Literal.create(evaluator, ObjectType(classOf[SchemaOfCsvEvaluator])),
"evaluate",
dataType,
Seq(child),
Seq(child.dataType))
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ case class JsonToStructsEvaluator(
nullableSchema: DataType,
nameOfCorruptRecord: String,
timeZoneId: Option[String],
variantAllowDuplicateKeys: Boolean) extends Serializable {
variantAllowDuplicateKeys: Boolean) {

// This converts parsed rows to the desired output by the given schema.
@transient
Expand Down Expand Up @@ -117,7 +117,7 @@ case class JsonToStructsEvaluator(
case class StructsToJsonEvaluator(
options: Map[String, String],
inputSchema: DataType,
timeZoneId: Option[String]) extends Serializable {
timeZoneId: Option[String]) {

@transient
private lazy val writer = new CharArrayWriter()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
Project [schema_of_csv(1|abc, (sep,|)) AS schema_of_csv(1|abc)#0]
Project [invoke(SchemaOfCsvEvaluator(Map(sep -> |)).evaluate(1|abc)) AS schema_of_csv(1|abc)#0]
+- LocalRelation <empty>, [id#0L, a#0, b#0, d#0, e#0, f#0, g#0]

0 comments on commit 369c40c

Please sign in to comment.