diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
index fce12510afaf5..f7e6f76199cee 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java
@@ -817,4 +817,15 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) {
     }
   }
 
+  public static byte[] getCollationKeyBytes(UTF8String input, int collationId) {
+    Collation collation = fetchCollation(collationId);
+    if (collation.supportsBinaryEquality) {
+      return input.getBytes();
+    } else if (collation.supportsLowercaseEquality) {
+      return input.toLowerCase().getBytes();
+    } else {
+      return collation.collator.getCollationKey(input.toString()).toByteArray();
+    }
+  }
+
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteCollationJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteCollationJoin.scala
new file mode 100644
index 0000000000000..fd443fd19a1fe
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteCollationJoin.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.analysis
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey, Equality}
+import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.types.StringType
+
+object RewriteCollationJoin extends Rule[LogicalPlan] {
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case j @ Join(_, _, _, Some(condition), _) =>
+      val newCondition = condition transform {
+        case e @ Equality(l: AttributeReference, r: AttributeReference) =>
+          (l.dataType, r.dataType) match {
+            case (st: StringType, _: StringType)
+                if !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality =>
+              e.withNewChildren(Seq(CollationKey(l), CollationKey(r)))
+            case _ =>
+              e
+          }
+      }
+      if (!newCondition.fastEquals(condition)) {
+        j.copy(condition = Some(newCondition))
+      } else {
+        j
+      }
+  }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
new file mode 100644
index 0000000000000..6e400d026e0ee
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.util.CollationFactory
+import org.apache.spark.sql.internal.types.StringTypeAnyCollation
+import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
+
+case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
+  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
+  override def dataType: DataType = BinaryType
+
+  final lazy val collationId: Int = expr.dataType match {
+    case st: StringType =>
+      st.collationId
+  }
+
+  override def nullSafeEval(input: Any): Any =
+    CollationFactory.getCollationKeyBytes(input.asInstanceOf[UTF8String], collationId)
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    defineCodeGen(ctx, ev, c => s"CollationFactory.getCollationKeyBytes($c, $collationId)")
+  }
+
+  override protected def withNewChildInternal(newChild: Expression): Expression = {
+    copy(expr = newChild)
+  }
+
+  override def child: Expression = expr
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
index c3495a0c112c3..f278b8e5899d8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollationExpressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.{SparkException, SparkFunSuite}
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.types._
+import org.apache.spark.unsafe.types.UTF8String
 
 class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
   test("validate default collation") {
@@ -163,6 +164,31 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
     }
   }
 
+  test("CollationKey generates correct collation key for collated string") {
+    val testCases = Seq(
+      ("", "UTF8_BINARY", UTF8String.fromString("").getBytes),
+      ("aa", "UTF8_BINARY", UTF8String.fromString("aa").getBytes),
+      ("AA", "UTF8_BINARY", UTF8String.fromString("AA").getBytes),
+      ("aA", "UTF8_BINARY", UTF8String.fromString("aA").getBytes),
+      ("", "UTF8_BINARY_LCASE", UTF8String.fromString("").getBytes),
+      ("aa", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
+      ("AA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
+      ("aA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
+      ("", "UNICODE", UTF8String.fromString("").getBytes),
+      ("aa", "UNICODE", UTF8String.fromString("aa").getBytes),
+      ("AA", "UNICODE", UTF8String.fromString("AA").getBytes),
+      ("aA", "UNICODE", UTF8String.fromString("aA").getBytes),
+      ("", "UNICODE_CI", Array[Byte](1, 0)),
+      ("aa", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
+      ("AA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
+      ("aA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0))
+    )
+    for ((input, collation, expected) <- testCases) {
+      val str = Literal.create(input, StringType(collation))
+      checkEvaluation(CollationKey(str), expected)
+    }
+  }
+
   test("collation name normalization in collation expression") {
     Seq(
       ("en_USA", "en_USA"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 6173703ef3cd9..3382a1161ddba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution
 
 import org.apache.spark.sql.ExperimentalMethods
+import org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer._
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -92,7 +93,8 @@ class SparkOptimizer(
       EliminateLimits,
       ConstantFolding) :+
     Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+
-    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
+    Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) :+
+    Batch("RewriteCollationJoin", Once, RewriteCollationJoin)
 
   override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
     ExtractPythonUDFFromJoinCondition.ruleName :+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
index 4f8587395b3e6..9b3bfe1c77b3c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
@@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava
 
 import org.apache.spark.SparkException
 import org.apache.spark.sql.catalyst.ExtendedAnalysisException
-import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.util.CollationFactory
 import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
 import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
@@ -30,8 +30,8 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
 import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
-import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
-import org.apache.spark.sql.internal.SqlApiConf
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
 import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
 import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
 
@@ -769,37 +769,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
     })
   }
 
-  test("hash based joins not allowed for non-binary collated strings") {
-    val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))
-
-    val schema = StructType(StructField(
-      "col_non_binary",
-      StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
-      StructField("col_binary", StringType) :: Nil)
-    val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)
-
-    // Binary collations are allowed to use hash join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.nonEmpty)
-
-    // Even with hint broadcast, hash join is not used for non-binary collated strings.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: BroadcastHashJoinExec => ()
-    }.isEmpty)
-
-    // Instead they will default to sort merge join.
-    assert(collectFirst(
-      df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
-        .queryExecution.executedPlan) {
-      case _: SortMergeJoinExec => ()
-    }.nonEmpty)
-  }
-
   test("Generated column expressions using collations - errors out") {
     checkError(
       exception = intercept[AnalysisException] {
@@ -1030,6 +999,135 @@
     }
   }
 
+  test("hash join should be used for collated strings") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
+    case class HashJoinTestCase[R](collation: String, result: R)
+    val testCases = Seq(
+      HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
+      HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.nonEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.isEmpty
+        )
+
+        // if collation doesn't support binary equality, collation key should be injected
+        if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
+          assert(collectFirst(queryPlan) {
+            case b: HashJoin => b.leftKeys.head
+          }.head.isInstanceOf[CollationKey])
+        }
+      }
+    })
+  }
+
+  test("rewrite with collationkey should be an excludable rule") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+    val collation = "UTF8_BINARY_LCASE"
+    val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
+    withTable(t1, t2) {
+      withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
+        sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES ('aa', 1)")
+
+        sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
+        checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that sort merge join is used instead of hash join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.isEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.nonEmpty
+        )
+      }
+    }
+  }
+
+  test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
+    val t1 = "T_1"
+    val t2 = "T_2"
+
+    case class HashMultiJoinTestCase[R](
+      type1: String,
+      type2: String,
+      data1: String,
+      data2: String,
+      result: R
+    )
+    val testCases = Seq(
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "INT",
+        "'a', 0, 1", "'a', 0, 1", Row("a", 0, 1, "a", 0, 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY",
+        "'a', 'a', 1", "'a', 'a', 1", Row("a", "a", 1, "a", "a", 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY_LCASE",
+        "'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
+      HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY_LCASE", "STRING COLLATE UNICODE_CI",
+        "'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
+    )
+
+    testCases.foreach(t => {
+      withTable(t1, t2) {
+        sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t1 VALUES (${t.data1})")
+        sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
+        sql(s"INSERT INTO $t2 VALUES (${t.data2})")
+
+        val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
+        checkAnswer(df, t.result)
+
+        val queryPlan = df.queryExecution.executedPlan
+
+        // confirm that hash join is used instead of sort merge join
+        assert(
+          collectFirst(queryPlan) {
+            case _: HashJoin => ()
+          }.nonEmpty
+        )
+        assert(
+          collectFirst(queryPlan) {
+            case _: SortMergeJoinExec => ()
+          }.isEmpty
+        )
+      }
+    })
+  }
+
   test("hll sketch aggregate should respect collation") {
     case class HllSketchAggTestCase[R](c: String, result: R)
     val testCases = Seq(