From 2268044150db0978099d08dd95a3e08e0100dd9e Mon Sep 17 00:00:00 2001 From: GideonPotok Date: Mon, 20 May 2024 18:34:39 -0400 Subject: [PATCH] what it could look like --- .../catalyst/expressions/aggregate/Mode.scala | 40 +++++++++++-------- .../benchmark/CollationBenchmark.scala | 3 +- 2 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index f3a98ef749451..49b31e042000d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -42,6 +42,9 @@ case class Mode( this(child, 0, 0, Some(reverse)) } + private lazy val binaryKeys: scala.collection.mutable.Map[String, String] = + scala.collection.mutable.Map.empty + // Returns null for empty inputs override def nullable: Boolean = true @@ -83,8 +86,24 @@ case class Mode( input: InternalRow): OpenHashMap[AnyRef, Long] = { val key = child.eval(input) + val keyNew = child.dataType match { + case c: StringType if + !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality => + val collationId = c.collationId + key match { + case (key: String, _) => + CollationFactory.getCollationKey(UTF8String.fromString(key), collationId) + case (key: UTF8String, _) => + CollationFactory.getCollationKey(key, collationId) + case (key, _) => key + } + case _ => key + } if (key != null) { - buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L) + buffer.changeValue(InternalRow.copyValue(keyNew).asInstanceOf[AnyRef], 1L, _ + 1L) + if(key != keyNew && !binaryKeys.contains(keyNew.toString)) { + binaryKeys.put(keyNew.toString, key.toString) + } } buffer } @@ -102,21 +121,7 @@ case class Mode( if (buffer.isEmpty) { return null } - val collationAwareBuffer = child.dataType match { - case c: StringType if - !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality => - val collationId = c.collationId - val modeMap = buffer.toSeq.groupMapReduce { - case (key: String, _) => - CollationFactory.getCollationKey(UTF8String.fromString(key), collationId) - case (key: UTF8String, _) => - CollationFactory.getCollationKey(key, collationId) - case (key, _) => key - }(x => x)((x, y) => (x._1, x._2 + y._2)).values - modeMap -// case s: StructType => getBufferForStructType(buffer, s) - case _ => buffer - } + val collationAwareBuffer = buffer reverseOpt.map { reverse => val defaultKeyOrdering = if (reverse) { PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse @@ -124,7 +129,8 @@ case class Mode( PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]] } val ordering = Ordering.Tuple2(Ordering.Long, defaultKeyOrdering) - collationAwareBuffer.maxBy { case (key, count) => (count, key) }(ordering) + collationAwareBuffer.maxBy { case (key, count) => (count, + binaryKeys.getOrElse(key.toString, key)) }(ordering) }.getOrElse(collationAwareBuffer.maxBy(_._2))._1 } /* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala index 9bd5cdfabb4df..2c3228fdc59ec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/CollationBenchmark.scala @@ -333,7 +333,8 @@ object CollationBenchmark extends CollationBenchmarkBase { benchmarkStartsWith(collationTypes, inputs) benchmarkEndsWith(collationTypes, inputs) benchmarkMode(collationTypes, generateBaseInputStringswithUniqueGroupNumber(10000L)) - benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), generateBaseInputStringswithUniqueGroupNumber(10000L)) + benchmarkModeStruct(collationTypes.filter(c => c == "UNICODE" || c == "UTF8_BINARY"), + generateBaseInputStringswithUniqueGroupNumber(10000L)) benchmarkModeOnDataFrame(collationTypes, generateDataframeInput(10000L)) } }