diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index d5f1bf8c15d0a..7bbccbb3fe932 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.expressions.codegen.Block._
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project, Union}
 import org.apache.spark.sql.catalyst.trees.TreePattern._
-import org.apache.spark.sql.catalyst.util.{CollationAwareUTF8String, TypeUtils}
+import org.apache.spark.sql.catalyst.util.{CollationAwareUTF8String, CollationFactory, TypeUtils}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
@@ -657,7 +657,14 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with
   }
 
   @transient lazy val set: Set[Any] = child.dataType match {
-    case st: StringType if st.supportsLowercaseEquality => new InSet.LCaseSet(hset)
+    case st: StringType =>
+      if (st.supportsBinaryEquality) {
+        hset
+      } else if (st.supportsLowercaseEquality) {
+        new InSet.LCaseSet(hset)
+      } else {
+        new InSet.CollationSet(hset, st.collationId)
+      }
     case t: AtomicType if !t.isInstanceOf[BinaryType] => hset
     case _: NullType => hset
     case _ =>
@@ -785,6 +792,19 @@ object InSet {
       strSet.contains(CollationAwareUTF8String.lowerCaseCodePoints(elem.asInstanceOf[UTF8String]))
     }
   }
+  class CollationSet(inputSet: Set[Any], collationId: Int) extends immutable.Set[Any] {
+    private val collation = CollationFactory.fetchCollation(collationId)
+    override def incl(elem: Any): Set[Any] = inputSet.incl(elem)
+    override def excl(elem: Any): Set[Any] = inputSet.excl(elem)
+    override def iterator: Iterator[Any] = inputSet.iterator
+
+    override def contains(elem: Any): Boolean = {
+      assert(elem != null, "InSet guarantees non-null input")
+      inputSet.exists { p =>
+        collation.equalsFunction(p.asInstanceOf[UTF8String], elem.asInstanceOf[UTF8String])
+      }
+    }
+  }
 }
 
 @ExpressionDescription(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
index 07162e2eea554..e421c071eec21 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -231,10 +231,15 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
       ("aBc", "UTF8_LCASE", Set("b", "aa", "xyz")) -> false,
       ("aBc", "UTF8_LCASE", Set("b", "AbC", null)) -> true,
       (null, "UTF8_LCASE", Set("b", "AbC", null)) -> null,
+      (" aa", "UTF8_BINARY_RTRIM", Set(" aa")) -> true,
+      (" aa ", "UTF8_BINARY_RTRIM", Set(" aa")) -> true,
+      ("a ", "UTF8_BINARY_RTRIM", Set()) -> false,
+      ("a ", "UTF8_BINARY_RTRIM", Set("a", "b", null)) -> true,
+      (null, "UTF8_BINARY_RTRIM", Set("1", "2")) -> null
     ).foreach { case ((elem, collation, inputSet), result) =>
-      val hset = inputSet.map(UTF8String.fromString).asInstanceOf[Set[Any]]
+      val iset = inputSet.map(UTF8String.fromString).asInstanceOf[Set[Any]]
       checkEvaluation(
-        InSet(Literal.create(elem, StringType(collation)), hset),
+        InSet(Literal.create(elem, StringType(collation)), iset),
         result
       )
     }
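Reviewer note: below is a minimal standalone sketch of the membership semantics the new InSet.CollationSet is meant to provide for collations that are neither binary nor lowercase-only. It is not part of the change; it assumes spark-catalyst and spark-unsafe are on the classpath, and the object name CollationSetSemanticsExample plus the use of the string-based CollationFactory.fetchCollation overload are illustrative assumptions.

import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.unsafe.types.UTF8String

object CollationSetSemanticsExample {
  def main(args: Array[String]): Unit = {
    // UTF8_BINARY_RTRIM ignores trailing spaces when comparing strings.
    // Assumption: the String overload of fetchCollation resolves the name.
    val collation = CollationFactory.fetchCollation("UTF8_BINARY_RTRIM")

    // The literal set, converted to UTF8String the same way the test does.
    val hset = Set("a", "b").map(UTF8String.fromString).asInstanceOf[Set[Any]]
    val probe = UTF8String.fromString("a   ")

    // Same linear scan as CollationSet.contains: instead of Set.contains,
    // compare the probe against each element with the collation's equality function.
    val found = hset.exists { p =>
      collation.equalsFunction(p.asInstanceOf[UTF8String], probe)
    }
    println(found) // expected: true, since "a   " equals "a" under RTRIM semantics
  }
}

This mirrors why the test case ("a ", "UTF8_BINARY_RTRIM", Set("a", "b", null)) evaluates to true: a plain hash-set lookup on the original strings would miss, so the collation-aware scan is required.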