From 6dafa39296558765f6cd6313e069fe5a331a14bb Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Mon, 20 May 2024 11:05:03 -0400 Subject: [PATCH] added checkinputdatatype to not support complex types containing nonbinary collations added checkinputdatatype to not support complex types containing nonbinary collations added struct test stuff --- .../catalyst/expressions/aggregate/Mode.scala | 36 +++++- .../sql/CollationSQLExpressionsSuite.scala | 122 +++++++++++++++++- .../benchmark/CollationBenchmark.scala | 38 +++++- 3 files changed, 187 insertions(+), 9 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 2af6a44102aa6..f3a98ef749451 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} +import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType @@ -49,6 +49,33 @@ case class Mode( override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) + override def checkInputDataTypes(): TypeCheckResult = { + checkDataType(child.dataType) + } + + private def checkDataType(dataType: DataType, level1: Boolean = true): TypeCheckResult = { + dataType match { + case ArrayType(elementType, _) => + checkDataType(elementType, level1 = false) + case StructType(fields) => + combineTypeCheckResults(fields.map { field => + checkDataType(field.dataType, level1 = false) + }) + case dt: StringType if !level1 && + !CollationFactory.fetchCollation(dt.collationId).supportsBinaryEquality + => TypeCheckResult.TypeCheckFailure( + s"Input to function $prettyName was a complex type" + + s" with strings collated on non-binary collations," + + s" which is not yet supported.") + case _ => TypeCheckResult.TypeCheckSuccess + } + } + + private def combineTypeCheckResults(results: Array[TypeCheckResult]): TypeCheckResult = { + results.collect({ case f: TypeCheckResult.TypeCheckFailure => f }).headOption.getOrElse( + TypeCheckResult.TypeCheckSuccess) + } + override def prettyName: String = "mode" override def update( @@ -87,10 +114,9 @@ case class Mode( case (key, _) => key }(x => x)((x, y) => (x._1, x._2 + y._2)).values modeMap - case s: StructType => getBufferForStructType(buffer, s) +// case s: StructType => getBufferForStructType(buffer, s) case _ => buffer } - println(s"Buffer: ${buffer.size} => ${collationAwareBuffer.size}") reverseOpt.map { reverse => val defaultKeyOrdering = if (reverse) { PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse @@ -101,7 +127,7 @@ case class Mode( collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering) }.getOrElse(collationAwareBuffer.maxBy(_._2))._1 } - +/* private def getBufferForStructType( buffer: OpenHashMap[AnyRef, Long], s: StructType): Iterable[(AnyRef, Long)] = { @@ -125,7 +151,7 @@ case class Mode( } }(x => x)((x, y) => (x._1, x._2 + y._2)).values } - +*/ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 19c81872363f0..cabbadf0496d5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -1505,7 +1505,127 @@ class CollationSQLExpressionsSuite t.collationId + ", f2: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1) FROM ${tableName}" - checkAnswer(sql(query), Row(t.result)) + if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") { + // Cannot resolve "mode(i)" due to data type mismatch: + // Input to function mode was a complex type with strings collated on non-binary + // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; + val params = Seq(("sqlExpr", "\"mode(i)\""), + ("msg", "Input to function mode was a complex type with strings" + + " collated on non-binary collations, which is not yet supported."), + ("hint", "")).toMap + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = params, + queryContext = Array( + ExpectedContext(objectType = "", + objectName = "", + startIndex = 13, + stopIndex = 19, + fragment = "mode(i)") + ) + ) + } else { + checkAnswer(sql(query), Row(t.result)) + } + } + }) + } + + test("Support mode for string expression with collated strings in recursively nested struct") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach(t => { + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"named_struct('f1', " + + s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(i STRUCT, f3: INT>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}" + if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") { + // Cannot resolve "mode(i)" due to data type mismatch: + // Input to function mode was a complex type with strings collated on non-binary + // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; + val params = Seq(("sqlExpr", "\"mode(i)\""), + ("msg", "Input to function mode was a complex type with strings" + + " collated on non-binary collations, which is not yet supported."), + ("hint", "")).toMap + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = params, + queryContext = Array( + ExpectedContext(objectType = "", + objectName = "", + startIndex = 13, + stopIndex = 19, + fragment = "mode(i)") + ) + ) + } else { + checkAnswer(sql(query), Row(t.result)) + } + } + }) + } + + test("Support mode for string expression with collated strings in array complex type") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("utf8_binary_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach(t => { + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " + + s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i ARRAY>, f3: INT>>)" + + s" USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}" + if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") { + val params = Seq(("sqlExpr", "\"mode(i)\""), + ("msg", "Input to function mode was a complex type with strings" + + " collated on non-binary collations, which is not yet supported."), + ("hint", "")).toMap + checkError( + exception = intercept[AnalysisException] { + sql(query) + }, + errorClass = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", + parameters = params, + queryContext = Array( + ExpectedContext(objectType = "", + objectName = "", + startIndex = 35, + stopIndex = 41, + fragment = "mode(i)") + ) + ) + } else { + checkAnswer(sql(query), Row(t.result)) + } } }) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 7232b9e49150d..9bd5cdfabb4df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.benchmark import scala.concurrent.duration._ import org.apache.spark.benchmark.{Benchmark, BenchmarkBase} +import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.aggregate.Mode @@ -27,7 +28,7 @@ import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructT import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.OpenHashMap -abstract class CollationBenchmarkBase extends BenchmarkBase { +abstract class CollationBenchmarkBase extends BenchmarkBase with SqlBasedBenchmark { protected val collationTypes: Seq[String] = Seq("UTF8_BINARY_LCASE", "UNICODE", "UTF8_BINARY", "UNICODE_CI") @@ -248,6 +249,36 @@ abstract class CollationBenchmarkBase extends BenchmarkBase { benchmark.run() } + protected def generateDataframeInput(l: Long): DataFrame = { + spark.createDataFrame(generateSeqInput(l).map(_.toString).map(Tuple1.apply)).toDF("s1") + } + + /** + * Benchmark to measure performance of mode function on a DataFrame column with collation. + * This is necessary for when we attempt to switch to a different mode implementation. + * Replacing changes to eval() with changes to update() in the Mode class. + */ + def benchmarkModeOnDataFrame( + collationTypes: Seq[String], + dfUncollated: DataFrame): Unit = { + val benchmark = + new Benchmark( + s"collation e2e benchmarks - mode - ${dfUncollated.count()} elements", + dfUncollated.count(), + warmupTime = 4.seconds, + output = output) + collationTypes.foreach(collationType => { + val columnName = s"s_$collationType" + val dfCollated = dfUncollated.selectExpr( + s"collate(s1, '$collationType') as $columnName") + benchmark.addCase(s"mode df column with collation - $collationType") { _ => + dfCollated.selectExpr(s"mode($columnName)") + .noop() + } + } + ) + benchmark.run() + } } /** @@ -302,7 +333,8 @@ object CollationBenchmark extends CollationBenchmarkBase { benchmarkStartsWith(collationTypes, inputs) benchmarkEndsWith(collationTypes, inputs) benchmarkMode(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L)) - benchmarkModeStruct(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L)) + benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), generateBaseInputStringswithUniqueGroupNumber(10000L)) + benchmarkModeOnDataFrame(collationTypes, generateDataframeInput(10000L)) } } @@ -333,6 +365,6 @@ object CollationNonASCIIBenchmark extends CollationBenchmarkBase { benchmarkStartsWith(collationTypes, inputs) benchmarkEndsWith(collationTypes, inputs) benchmarkMode(collationTypes, inputs) - benchmarkModeStruct(collationTypes, inputs) + benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), inputs) } }