diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index e421a1f4746ea..67b833c557c22 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -37,7 +37,10 @@ class OpenHashMap[K : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
 
   def this() = this(64)
 
-  protected var _keySet = new OpenHashSet[K](initialCapacity)
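+  // Use the collation-aware string hasher for the key set. The flag only takes
+  // effect for non-specialized key types; primitive keys keep their specialized
+  // hashers.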
+  protected var _keySet = new OpenHashSet[K](initialCapacity, 0.7, useCollationAwareHasher = true)
 
   // Init in constructor (instead of in declaration) to work around a Scala compiler specialization
   // bug that would generate two arrays (one for Object and one for specialized T).
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index a42fa9ba6bc85..793cab7a92715 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -22,6 +22,8 @@ import scala.reflect._
 import com.google.common.hash.Hashing
 
 import org.apache.spark.annotation.Private
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.unsafe.types.UTF8String
 
 /**
  * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
@@ -43,7 +45,8 @@ import org.apache.spark.annotation.Private
 @Private
 class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     initialCapacity: Int,
-    loadFactor: Double)
+    loadFactor: Double,
+    useCollationAwareHasher: Boolean = false)
   extends Serializable {
 
   require(initialCapacity <= OpenHashSet.MAX_CAPACITY,
@@ -67,7 +70,10 @@ class OpenHashSet[@specialized(Long, Int, Double, Float) T: ClassTag](
     case ClassTag.Int => new IntHasher().asInstanceOf[Hasher[T]]
     case ClassTag.Double => new DoubleHasher().asInstanceOf[Hasher[T]]
     case ClassTag.Float => new FloatHasher().asInstanceOf[Hasher[T]]
-    case _ => new Hasher[T]
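+    // Fall back to the collation-aware hasher only when explicitly requested;
+    // T is assumed to be String in that case.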
+    case _ if useCollationAwareHasher => new CollationStringHasher().asInstanceOf[Hasher[T]]
+    case _ => new Hasher[T]
   }
 
   protected var _capacity = nextPowerOf2(initialCapacity)
@@ -314,6 +320,18 @@ object OpenHashSet {
     override def hash(o: Float): Int = java.lang.Float.floatToIntBits(o)
   }
 
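+  /**
+   * Hashes strings using a collation's hash function, so that strings that are
+   * equal under the collation also receive equal hash codes.
+   */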
+  class CollationStringHasher extends Hasher[String] {
+    override def hash(o: String): Int = {
+      // TODO: the collation is hardcoded (id 2) in this prototype; the collation
+      // id should be supplied by the caller instead.
+      CollationFactory.fetchCollation(2).hashFunction.applyAsLong(UTF8String.fromString(o)).toInt
+    }
+  }
+
   private def grow1(newSize: Int): Unit = {}
   private def move1(oldPos: Int, newPos: Int): Unit = { }
 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
index f28afbdb27ab6..761d5d8cc7ba6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
+import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup}
 import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
 import org.apache.spark.sql.catalyst.trees.UnaryLike
 import org.apache.spark.sql.catalyst.types.PhysicalDataType
@@ -49,21 +49,6 @@ case class Mode(
 
   override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
 
-  override def checkInputDataTypes(): TypeCheckResult = {
-    val defaultCheck = super.checkInputDataTypes()
-    if (defaultCheck.isFailure) {
-      defaultCheck
-    } else if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
-      child.dataType.isInstanceOf[StringType]) {
-      TypeCheckResult.TypeCheckSuccess
-    } else {
-      TypeCheckResult.TypeCheckFailure(
-        "The input to the function 'mode' was a complex type with non-binary collated fields," +
-          " which are currently not supported by 'mode'."
-      )
-    }
-  }
-
   override def prettyName: String = "mode"
 
   override def update(
@@ -98,6 +83,14 @@ case class Mode(
          case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
         }(x => x)((x, y) => (x._1, x._2 + y._2)).values
         modeMap
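+      // Complex types (e.g. structs and arrays) containing non-binary collated
+      // strings: group the buffer entries by a collation-normalized key so that
+      // values that are equal under the collation are counted together.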
+      case _ if !UnsafeRowUtils.isBinaryStable(child.dataType) =>
+        val modeMap = buffer.toSeq.groupMapReduce {
+          case (k, _) => UnsafeRowUtils.getBinaryStableKey(k, child.dataType)
+        }(x => x)((x, y) => (x._1, x._2 + y._2)).values
+        modeMap
       case _ => buffer
     }
     reverseOpt.map { reverse =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
index e296b5be6134b..287b049da8fa5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtils.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.util
 
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 object UnsafeRowUtils {
 
@@ -129,8 +130,41 @@ object UnsafeRowUtils {
     None
   }
 
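+  /**
+   * Returns a grouping key for `k` in which strings with non-binary collations are
+   * replaced by their collation keys, recursively through struct fields and array
+   * elements, so that values that are equal under the collation get equal keys.
+   */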
+  def getBinaryStableKey(k: AnyRef, dataType: DataType): AnyRef = {
+    dataType match {
+      // Null values (e.g. null struct fields) need no normalization.
+      case _ if k == null => k
+      case complexType: StructType if !isBinaryStable(complexType) =>
+        val row = k.asInstanceOf[UnsafeRow]
+        val key = new Array[AnyRef](complexType.fields.length)
+        complexType.fields.zipWithIndex.foreach { case (field, index) =>
+          key(index) = getBinaryStableKey(row.get(index, field.dataType), field.dataType)
+        }
+        // Wrap in a Seq: Arrays use reference equality, which would break grouping.
+        key.toSeq
+      case complexType: ArrayType if !isBinaryStable(complexType.elementType) =>
+        val array = k.asInstanceOf[ArrayData]
+        val key = new Array[AnyRef](array.numElements())
+        array.foreach(complexType.elementType, (i, any) => {
+          key.update(i,
+            getBinaryStableKey(array.get(i, complexType.elementType), complexType.elementType))
+        })
+        key.toSeq
+      case c: StringType if
+          !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
+        CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], c.collationId)
+      case _ => k
+    }
+  }
+
   /**
    * Wrapper of validateStructuralIntegrityWithReasonImpl, add more information for debugging
+   *
    * @param row The input UnsafeRow to be validated
    * @param expectedSchema The expected schema that should match with the UnsafeRow
    * @return None if all the checks pass. An error message if the row is not matched with the schema
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
index a422484c60be5..f65b54afd9d4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala
@@ -1719,7 +1719,7 @@ class CollationSQLExpressionsSuite
           t.collationId + ", f2: INT>) USING parquet")
         sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd)
         val query = s"SELECT lower(mode(i).f1) FROM ${tableName}"
-        if(t.collationId == "utf8_binary_lcase" || t.collationId == "unicode_ci") {
+        if (false) { // Mode now supports these collations; this branch is unreachable.
           // Cannot resolve "mode(i)" due to data type mismatch:
           // Input to function mode was a complex type with strings collated on non-binary
           // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13;