diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index e421a1f4746ea..67b833c557c22 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -37,7 +37,7 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
 
   def this() = this(64)
 
-  protected var _keySet = new OpenHashSet[K](initialCapacity)
+  protected var _keySet = new OpenHashSet[K](initialCapacity, 0.7, Some(true))
 
   // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
   // bug that would generate two arrays (one for Object and one for specialized T).
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index a42fa9ba6bc85..793cab7a92715 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -22,6 +22,8 @@ import scala.reflect._
 import com.google.common.hash.Hashing
 
 import org.apache.spark.annotation.Private
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
@@ -43,7 +45,9 @@ import org.apache.spark.annotation.Private
 @Private
 class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     initialCapacity: Int,
-    loadFactor: Double)
+    loadFactor: Double,
+    specialPassedInHasher: Option[Boolean] = None
+  )
   extends Serializable {
 
   require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
@@ -67,7 +71,9 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     case ClassTag.Int => new IntHasher().asInstanceOf[Hasher[T]]
     case ClassTag.Double => new DoubleHasher().asInstanceOf[Hasher[T]]
     case ClassTag.Float => new FloatHasher().asInstanceOf[Hasher[T]]
-    case _ => new Hasher[T]
+    case _ => specialPassedInHasher.map(_ =>
+      new CollationStringHasher()).getOrElse(
+      new Hasher[T]).asInstanceOf[Hasher[T]]
   }
 
   protected var _capacity = nextPowerOf2(initialCapacity)
@@ -314,6 +320,15 @@ object OpenHashSet {
     override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
   }
 
+  class CollationStringHasher extends Hasher[String] {
+    override def hash(o: String): Int = {
+      // scalastyle:off println
+      println("Hash function for collation string is called")
+      // scalastyle:on println
+      CollationFactory.fetchCollation(2).hashFunction.applyAsLong(UTF8String.fromString(o)).toInt
+    }
+  }
+
   private def grow1(newSize: Int): Unit = {}
   private def move1(oldPos: Int, newPos: Int): Unit = { }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index f28afbdb27ab6..761d5d8cc7ba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
@@ -49,21 +49,6 @@ case class Mode(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    val defaultCheck = super.checkInputDataTypes()
-    if (defaultCheck.isFailure) {
-      defaultCheck
-    } else if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
-      child.dataType.isInstanceOf[StringType]) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        "The input to the function 'mode' was a complex type with non-binary collated fields," +
-        " which are currently not supported by 'mode'."
-      )
-    }
-  }
-
   override def prettyName: String = "mode"
 
   override def update(
@@ -98,6 +83,11 @@ case class Mode(
           case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
         }(x => x)((x, y) => (x._1, x._2 + y._2)).values
         modeMap
+      case complexType if !UnsafeRowUtils.isBinaryStable(child.dataType) =>
+        val modeMap = buffer.toSeq.groupMapReduce {
+          case (k, _) => UnsafeRowUtils.getBinaryStableKey(k, child.dataType)
+        }(x => x)((x, y) => (x._1, x._2 + y._2)).values
+        modeMap
       case _ => buffer
     }
     reverseOpt.map { reverse =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index e296b5be6134b..287b049da8fa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 object UnsafeRowUtils {
 
@@ -129,8 +130,34 @@
     None
   }
 
+  def getBinaryStableKey(k: AnyRef, dataType: DataType): AnyRef = {
+    dataType match {
+      case complexType: StructType if !isBinaryStable(complexType) =>
+        val row = k.asInstanceOf[UnsafeRow]
+        val key = new Array[AnyRef](complexType.fields.length)
+        complexType.fields.zipWithIndex.foreach { case (field, index) =>
+          key(index) = getBinaryStableKey(row.get(index, field.dataType), field.dataType)
+        }
+        key
+      case complexType: ArrayType if !isBinaryStable(complexType.elementType) =>
+        val array = k.asInstanceOf[ArrayData]
+        val key = new Array[AnyRef](array.numElements())
+        array.foreach(complexType.elementType, (i, any) => {
+          key.update(i,
+            getBinaryStableKey(array.get(i, complexType.elementType), complexType.elementType))
+        })
+        key
+      case c: StringType if
+          !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
+        val collationId = c.collationId
+        CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
+      case _ => k
+    }
+  }
+
   /**
    * Wrapper of validateStructuralIntegrityWithReasonImpl, add more information for debugging
+   *
    * @param row The input UnsafeRow to be validated
    * @param expectedSchema The expected schema that should match with the UnsafeRow
    * @return None if all the checks pass. An error message if the row is not matched with the schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index a422484c60be5..f65b54afd9d4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -1719,7 +1719,7 @@ class CollationSQLExpressionsSuite
            t.collationId + ", f2: INT>) USING parquet")
          sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
          val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
-         if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") {
+         if(t.collationId == "") { // || t.collationId == "unicode_ci") {
            // Cannot resolve "mode(i)" due to data type mismatch:
            // Input to function mode was a complex type with strings collated on non-binary
            // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;
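
Note on the Mode.scala change above: grouping the aggregation buffer by CollationFactory.getCollationKey works because every string that compares equal under a collation produces the same key, so their counts are merged before the most frequent value is picked. Below is a rough, self-contained sketch of that idea only; the UTF8_BINARY_LCASE collation name, the object name, and the sample values are illustrative assumptions and are not part of the patch.

    import org.apache.spark.sql.catalyst.util.CollationFactory
    import org.apache.spark.unsafe.types.UTF8String

    object CollationKeySketch {
      def main(args: Array[String]): Unit = {
        // Illustrative collation and buffer contents (not taken from the patch).
        val collationId = CollationFactory.collationNameToId("UTF8_BINARY_LCASE")
        val buffer = Seq(
          UTF8String.fromString("Spark") -> 2L,
          UTF8String.fromString("SPARK") -> 3L,
          UTF8String.fromString("other") -> 1L)

        // Same shape as Mode.eval: group by the collation key so that "Spark" and
        // "SPARK" fall into one group, then sum their counts (2 + 3 = 5).
        val merged = buffer.groupMapReduce {
          case (k, _) => CollationFactory.getCollationKey(k, collationId)
        }(x => x)((x, y) => (x._1, x._2 + y._2)).values

        // Expect one entry with count 5 and one with count 1.
        merged.foreach(println)
      }
    }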