Skip to content

Commit

Permalink
Update sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
Browse files Browse the repository at this point in the history

Co-authored-by: Uros Bojanic <[email protected]>
  • Loading branch information
GideonPotok and uros-db committed May 24, 2024
1 parent 0bab248 commit f054589
Show file tree
Hide file tree
Showing 3 changed files with 4 additions and 25 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ case class Mode(
defaultCheck
} else if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
child.dataType.isInstanceOf[StringType]) {
defaultCheck
TypeCheckResult.TypeCheckSuccess
} else {
TypeCheckResult.TypeCheckFailure(
"The input to the function 'mode' was a complex type with non-binary collated fields," +
Expand Down Expand Up @@ -95,11 +95,7 @@ case class Mode(
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
case (key: String, _) =>
CollationFactory.getCollationKey(UTF8String.fromString(key), collationId)
case (key: UTF8String, _) =>
CollationFactory.getCollationKey(key, collationId)
case (key, _) => key
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _ => buffer
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1670,21 +1670,11 @@ class CollationSQLExpressionsSuite
}

test("Support Mode.eval(buffer)") {
// Test fixture for the String-keyed variant of the Mode.eval(buffer) check.
//   collationId  - collation name passed to StringType(...) for the Mode child literal.
//   bufferValues - key -> Long entries copied one by one into the OpenHashMap buffer.
//   result       - expected outcome of Mode.eval; the assertion lower-cases both sides
//                  before comparing, so the check is case-insensitive.
case class ModeTestCase[R](
collationId: String,
bufferValues: Map[String, Long],
result: R)
// Test fixture for the UTF8String-keyed variant of the Mode.eval(buffer) check.
//   collationId  - collation name passed to StringType(...) for the Mode child literal.
//   bufferValues - UTF8String key -> Long entries used to populate the aggregation buffer.
//   result       - expected outcome for this collation.
// NOTE(review): the assertion that consumes `result` is outside the visible hunk;
// presumably it mirrors the case-insensitive comparison used for ModeTestCase - confirm.
case class UTF8StringModeTestCase[R](
collationId: String,
bufferValues: Map[UTF8String, Long],
result: R)

val bufferValues = Map("a" -> 5L, "b" -> 4L, "B" -> 3L, "d" -> 2L, "e" -> 1L)
val testCasesStrings = Seq(ModeTestCase("utf8_binary", bufferValues, "a"),
ModeTestCase("utf8_binary_lcase", bufferValues, "b"),
ModeTestCase("unicode_ci", bufferValues, "b"),
ModeTestCase("unicode", bufferValues, "a"))

val bufferValuesUTF8String = Map(
UTF8String.fromString("a") -> 5L,
UTF8String.fromString("b") -> 4L,
Expand All @@ -1698,13 +1688,6 @@ class CollationSQLExpressionsSuite
UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"),
UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a"))

testCasesStrings.foreach(t => {
val buffer = new OpenHashMap[AnyRef, Long](5)
val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId)))
t.bufferValues.foreach { case (k, v) => buffer.update(k, v) }
assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase())
})

testCasesUTF8String.foreach(t => {
val buffer = new OpenHashMap[AnyRef, Long](5)
val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase with SqlBasedBenchma
collationTypes.foreach { collationType => {
val buffer = new OpenHashMap[AnyRef, Long](value.size)
value.foreach(v => {
buffer.update(v.toString, (v.hashCode() % 1000).toLong)
buffer.update(v, (v.hashCode() % 1000).toLong)
})
val modeCurrent = Mode(child =
Literal.create("some_column_name", StringType(collationType)))
Expand All @@ -230,7 +230,7 @@ abstract class CollationBenchmarkBase extends BenchmarkBase with SqlBasedBenchma
val buffer = new OpenHashMap[AnyRef, Long](value.size)
value.foreach(v => {
buffer.update(InternalRow.fromSeq(
Seq(v.toString, UTF8String.fromString(v.toString), 3)),
Seq(v, v, 3)),
(v.hashCode() % 1000).toLong)
})
val st = StructType(Seq(
Expand Down

0 comments on commit f054589

Please sign in to comment.