Skip to content

Commit

Permalink
latest review
Browse files Browse the repository at this point in the history
  • Loading branch information
GideonPotok committed May 21, 2024
1 parent 6dafa39 commit 51f397c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 57 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResul
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData}
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType, StructType}
import org.apache.spark.unsafe.types.UTF8String
Expand Down Expand Up @@ -50,32 +50,21 @@ case class Mode(
// Expected input types for this expression: a single entry, AnyDataType.
override def inputTypes: Seq[AbstractDataType] = AnyDataType :: Nil

override def checkInputDataTypes(): TypeCheckResult = {
checkDataType(child.dataType)
}

private def checkDataType(dataType: DataType, level1: Boolean = true): TypeCheckResult = {
dataType match {
case ArrayType(elementType, _) =>
checkDataType(elementType, level1 = false)
case StructType(fields) =>
combineTypeCheckResults(fields.map { field =>
checkDataType(field.dataType, level1 = false)
})
case dt: StringType if !level1 &&
!CollationFactory.fetchCollation(dt.collationId).supportsBinaryEquality
=> TypeCheckResult.TypeCheckFailure(
s"Input to function $prettyName was a complex type" +
s" with strings collated on non-binary collations," +
s" which is not yet supported.")
case _ => TypeCheckResult.TypeCheckSuccess
val defaultCheck = super.checkInputDataTypes()
if (defaultCheck.isFailure) {
defaultCheck
} else {
child.dataType match {
case _: StructType | _: ArrayType if !UnsafeRowUtils.isBinaryStable(child.dataType) =>
TypeCheckResult.TypeCheckFailure(
s"Input to function mode was a complex type" +
s" with non-binary collated fields," +
s" which is not yet supported by mode.")
case _ => TypeCheckResult.TypeCheckSuccess
}
}
}

/**
 * Collapses several type-check results into one.
 *
 * @param results the individual results, in the order they were produced
 * @return the first `TypeCheckFailure` found in `results`, or `TypeCheckSuccess`
 *         when none of the results is a failure (including when `results` is empty)
 */
private def combineTypeCheckResults(results: Array[TypeCheckResult]): TypeCheckResult = {
  val firstFailure = results.collectFirst {
    case failure: TypeCheckResult.TypeCheckFailure => failure
  }
  firstFailure.getOrElse(TypeCheckResult.TypeCheckSuccess)
}

// Name used for this function in user-facing text; always the literal "mode".
override def prettyName: String = "mode"

override def update(
Expand Down Expand Up @@ -114,7 +103,6 @@ case class Mode(
case (key, _) => key
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
// case s: StructType => getBufferForStructType(buffer, s)
case _ => buffer
}
reverseOpt.map { reverse =>
Expand All @@ -127,31 +115,7 @@ case class Mode(
collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering)
}.getOrElse(collationAwareBuffer.maxBy(_._2))._1
}
/*
private def getBufferForStructType(
buffer: OpenHashMap[AnyRef, Long],
s: StructType): Iterable[(AnyRef, Long)] = {
val fIsNonBinaryString = s.fields.map(f => (f, f.dataType)).map {
case (f, t: StringType) if !t.supportsBinaryEquality => (f.name, true)
case (f, t) => (f.name, false)
}.toMap
val fCollationIDs = s.fields.collect {
case f if fIsNonBinaryString(f.name) =>
(f.name, f.dataType.asInstanceOf[StringType].collationId)
}.toMap
buffer.groupMapReduce {
case (key: InternalRow, count) =>
key.toSeq(s).zip(s.fields).map {
case (k: String, field) if fIsNonBinaryString(field.name) =>
CollationFactory.getCollationKey(UTF8String.fromString(k), fCollationIDs(field.name))
case (k: UTF8String, field) if fIsNonBinaryString(field.name) =>
CollationFactory.getCollationKey(k, fCollationIDs(field.name))
case (k, _) => k
}
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
}
*/

/**
 * Returns a copy of this `Mode` whose `mutableAggBufferOffset` is set to the given value;
 * every other constructor field is carried over unchanged by `copy`.
 *
 * @param newMutableAggBufferOffset the offset to store in the returned copy
 */
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Mode = {
  val updatedOffset = newMutableAggBufferOffset
  copy(mutableAggBufferOffset = updatedOffset)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1510,8 +1510,8 @@ class CollationSQLExpressionsSuite
// Input to function mode was a complex type with non-binary collated fields,
// which is not yet supported by mode.. SQLSTATE: 42K09; line 1 pos 13;
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "Input to function mode was a complex type with strings" +
" collated on non-binary collations, which is not yet supported."),
("msg", "Input to function mode was a complex type with non-binary collated fields" +
", which is not yet supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
Expand Down Expand Up @@ -1559,8 +1559,8 @@ class CollationSQLExpressionsSuite
// Input to function mode was a complex type with non-binary collated fields,
// which is not yet supported by mode.. SQLSTATE: 42K09; line 1 pos 13;
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "Input to function mode was a complex type with strings" +
" collated on non-binary collations, which is not yet supported."),
("msg", "Input to function mode was a complex type with non-binary collated fields" +
", which is not yet supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
Expand Down Expand Up @@ -1606,8 +1606,8 @@ class CollationSQLExpressionsSuite
val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}"
if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") {
val params = Seq(("sqlExpr", "\"mode(i)\""),
("msg", "Input to function mode was a complex type with strings" +
" collated on non-binary collations, which is not yet supported."),
("msg", "Input to function mode was a complex type with non-binary collated fields" +
", which is not yet supported by mode."),
("hint", "")).toMap
checkError(
exception = intercept[AnalysisException] {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ object CollationBenchmark extends CollationBenchmarkBase {
benchmarkStartsWith(collationTypes, inputs)
benchmarkEndsWith(collationTypes, inputs)
benchmarkMode(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"),
generateBaseInputStringswithUniqueGroupNumber(10000L))
benchmarkModeOnDataFrame(collationTypes, generateDataframeInput(10000L))
}
}
Expand Down

0 comments on commit 51f397c

Please sign in to comment.