Skip to content

Commit

Permalink
added checkinputdatatype to not support complex types containing nonb…
Browse files Browse the repository at this point in the history
…inary collations

Added checkInputDataTypes to Mode so that complex types containing non-binary collations are rejected

Added tests for struct inputs
  • Loading branch information
GideonPotok committed May 20, 2024
1 parent 2e6d707 commit 6dafa39
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
Expand Down Expand Up @@ -49,6 +49,33 @@ case class Mode(

// Declares a single argument slot typed as AnyDataType.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

/**
 * Validates the input type of `mode`.
 *
 * The inherited check runs first so that overriding this method does not silently drop it;
 * only when it passes is the collation-specific restriction from `checkDataType` applied
 * to the child's data type.
 */
override def checkInputDataTypes(): TypeCheckResult = {
  val inheritedCheck = super.checkInputDataTypes()
  if (inheritedCheck.isFailure) {
    inheritedCheck
  } else {
    checkDataType(child.dataType)
  }
}

/**
 * Recursively checks that `dataType` does not nest, inside a complex type, a string whose
 * collation does not support binary equality.
 *
 * @param dataType the type to inspect
 * @param level1 true only for the outermost call, i.e. when `dataType` is the type of the
 *               argument itself; a string at that level is accepted whatever its collation
 * @return a TypeCheckFailure for the first offending nested string, TypeCheckSuccess otherwise
 */
private def checkDataType(dataType: DataType, level1: Boolean = true): TypeCheckResult = {
  dataType match {
    case ArrayType(elementType, _) =>
      checkDataType(elementType, level1 = false)
    case StructType(fields) =>
      combineTypeCheckResults(fields.map { field =>
        checkDataType(field.dataType, level1 = false)
      })
    // Maps are complex types too: without this case a collated string used as a map key or
    // value would fall through to the catch-all below and be accepted.
    case org.apache.spark.sql.types.MapType(keyType, valueType, _) =>
      combineTypeCheckResults(Array(
        checkDataType(keyType, level1 = false),
        checkDataType(valueType, level1 = false)))
    case dt: StringType if !level1 &&
        !CollationFactory.fetchCollation(dt.collationId).supportsBinaryEquality =>
      TypeCheckResult.TypeCheckFailure(
        s"Input to function $prettyName was a complex type" +
          " with strings collated on non-binary collations," +
          " which is not yet supported.")
    case _ => TypeCheckResult.TypeCheckSuccess
  }
}

/**
 * Collapses several type-check results into one.
 *
 * @param results the individual results, in the order they should be considered
 * @return the first TypeCheckFailure found in `results`, or TypeCheckSuccess if there is none
 */
private def combineTypeCheckResults(results: Array[TypeCheckResult]): TypeCheckResult = {
  val firstFailure = results.collectFirst {
    case failure: TypeCheckResult.TypeCheckFailure => failure
  }
  firstFailure.getOrElse(TypeCheckResult.TypeCheckSuccess)
}

// Function name as interpolated into messages, e.g. the type-check failure text in this class.
override def prettyName: String = "mode"

override def update(
Expand Down Expand Up @@ -87,10 +114,9 @@ case class Mode(
case (key, _) => key
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case s: StructType => getBufferForStructType(buffer, s)
// case s: StructType => getBufferForStructType(buffer, s)
case _ => buffer
}
println(s"Buffer: ${buffer.size} => ${collationAwareBuffer.size}")
reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
Expand All @@ -101,7 +127,7 @@ case class Mode(
collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
}.getOrElse(collationAwareBuffer.maxBy(_._2))._1
}

/*
private def getBufferForStructType(
buffer: OpenHashMap[AnyRef, Long],
s: StructType): Iterable[(AnyRef, Long)] = {
Expand All @@ -125,7 +151,7 @@ case class Mode(
}
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
}

*/
/** Returns a copy of this Mode whose `mutableAggBufferOffset` is set to the given value. */
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1505,7 +1505,127 @@ class CollationSQLExpressionsSuite
t.collationId + ", f2: INT>) USING parquet")
sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
checkAnswer(sql(query), Row(t.result))
if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") {
// Cannot resolve "mode(i)" due to data type mismatch:
// Input to function mode was a complex type with strings collated on non-binary
// collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "Input to function mode was a complex type with strings" +
" collated on non-binary collations, which is not yet supported."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
sql(query)
},
errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
parameters = params,
queryContext = Array(
ExpectedContext(objectType = "",
objectName = "",
startIndex = 13,
stopIndex = 19,
fragment = "mode(i)")
)
)
} else {
checkAnswer(sql(query), Row(t.result))
}
}
})
}

test("Support mode for string expression with collated strings in recursively nested struct") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("utf8_binary_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
  )
  // For these collations the query is expected to fail analysis instead of returning `result`.
  val unsupportedCollations = Set("utf8_binary_lcase", "unicode_ci")
  testCases.foreach { t =>
    // `1L to numRepeats` yields exactly `numRepeats` rows per value; starting the range at 0
    // would insert one extra row for every value.
    val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
      (1L to numRepeats).map(_ => s"named_struct('f1', " +
        s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",")
    }.mkString(",")

    val tableName = s"t_${t.collationId}_mode_nested_struct"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(i STRUCT<f1: STRUCT<f2: STRING COLLATE " +
        t.collationId + ">, f3: INT>) USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}"
      if (unsupportedCollations.contains(t.collationId)) {
        // Cannot resolve "mode(i)" due to data type mismatch:
        // Input to function mode was a complex type with strings collated on non-binary
        // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
        val params = Seq(("sqlExpr", "\"mode(i)\""),
          ("msg", "Input to function mode was a complex type with strings" +
            " collated on non-binary collations, which is not yet supported."),
          ("hint", "")).toMap
        checkError(
          exception = intercept[AnalysisException] {
            sql(query)
          },
          errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
          parameters = params,
          queryContext = Array(
            ExpectedContext(objectType = "",
              objectName = "",
              startIndex = 13,
              stopIndex = 19,
              fragment = "mode(i)")
          )
        )
      } else {
        checkAnswer(sql(query), Row(t.result))
      }
    }
  }
}

test("Support mode for string expression with collated strings in array complex type") {
  case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R)
  val testCases = Seq(
    ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("utf8_binary_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"),
    ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"),
    ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b")
  )
  // For these collations the query is expected to fail analysis instead of returning `result`.
  val unsupportedCollations = Set("utf8_binary_lcase", "unicode_ci")
  testCases.foreach { t =>
    // `1L to numRepeats` yields exactly `numRepeats` rows per value; starting the range at 0
    // would insert one extra row for every value.
    val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) =>
      (1L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " +
        s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",")
    }.mkString(",")

    // Named after the array shape under test rather than reusing the nested-struct table name.
    val tableName = s"t_${t.collationId}_mode_nested_array"
    withTable(tableName) {
      sql(s"CREATE TABLE ${tableName}(" +
        s"i ARRAY<STRUCT<s1: STRUCT<a2: ARRAY<STRING COLLATE ${t.collationId}>>, f3: INT>>)" +
        s" USING parquet")
      sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
      val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
      if (unsupportedCollations.contains(t.collationId)) {
        val params = Seq(("sqlExpr", "\"mode(i)\""),
          ("msg", "Input to function mode was a complex type with strings" +
            " collated on non-binary collations, which is not yet supported."),
          ("hint", "")).toMap
        checkError(
          exception = intercept[AnalysisException] {
            sql(query)
          },
          errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT",
          parameters = params,
          queryContext = Array(
            ExpectedContext(objectType = "",
              objectName = "",
              startIndex = 35,
              stopIndex = 41,
              fragment = "mode(i)")
          )
        )
      } else {
        checkAnswer(sql(query), Row(t.result))
      }
    }
  }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.benchmark
import scala.concurrent.duration._

import org.apache.spark.benchmark.{Benchmark, BenchmarkBase}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.aggregate.Mode
Expand All @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

abstract class CollationBenchmarkBase extends BenchmarkBase {
abstract class CollationBenchmarkBase extends BenchmarkBase with SqlBasedBenchmark {
protected val collationTypes: Seq[String] =
Seq("UTF8_BINARY_LCASE", "UNICODE", "UTF8_BINARY", "UNICODE_CI")

Expand Down Expand Up @@ -248,6 +249,36 @@ abstract class CollationBenchmarkBase extends BenchmarkBase {

benchmark.run()
}
/**
 * Builds a single-column DataFrame from `generateSeqInput(l)`.
 *
 * @param l forwarded unchanged to `generateSeqInput`
 * @return a DataFrame with one column, `s1`, holding the `toString` of each generated value
 */
protected def generateDataframeInput(l: Long): DataFrame = {
  val rows = generateSeqInput(l).map(value => Tuple1(value.toString))
  spark.createDataFrame(rows).toDF("s1")
}

/**
 * End-to-end benchmark of the `mode` aggregate on a DataFrame column, with one case per
 * collation.
 *
 * It serves as a baseline when the Mode implementation is changed, for example when
 * collation handling moves from eval() to update().
 *
 * @param collationTypes collation names; each is applied to the input column with `collate`
 * @param dfUncollated input DataFrame; it must have a column named `s1`
 */
def benchmarkModeOnDataFrame(
    collationTypes: Seq[String],
    dfUncollated: DataFrame): Unit = {
  // Evaluate count() once and reuse it for both the benchmark title and the row cardinality.
  val numRows = dfUncollated.count()
  val benchmark =
    new Benchmark(
      s"collation e2e benchmarks - mode - $numRows elements",
      numRows,
      warmupTime = 4.seconds,
      output = output)
  collationTypes.foreach { collationType =>
    val columnName = s"s_$collationType"
    val dfCollated = dfUncollated.selectExpr(
      s"collate(s1, '$collationType') as $columnName")
    benchmark.addCase(s"mode df column with collation - $collationType") { _ =>
      dfCollated.selectExpr(s"mode($columnName)")
        .noop()
    }
  }
  benchmark.run()
}
}

/**
Expand Down Expand Up @@ -302,7 +333,8 @@ object CollationBenchmark extends CollationBenchmarkBase {
benchmarkStartsWith(collationTypes, inputs)
benchmarkEndsWith(collationTypes, inputs)
benchmarkMode(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeStruct(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeOnDataFrame(collationTypes, generateDataframeInput(10000L))
}
}

Expand Down Expand Up @@ -333,6 +365,6 @@ object CollationNonASCIIBenchmark extends CollationBenchmarkBase {
benchmarkStartsWith(collationTypes, inputs)
benchmarkEndsWith(collationTypes, inputs)
benchmarkMode(collationTypes, inputs)
benchmarkModeStruct(collationTypes, inputs)
benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), inputs)
}
}

0 comments on commit 6dafa39

Please sign in to comment.