Skip to content

Commit

Permalink
[SPARK-48000][SQL] Enable hash join support for all collations (Strin…
Browse files Browse the repository at this point in the history
…gType)

### What changes were proposed in this pull request?
Enable collation support for hash join on StringType. Note: support for complex types will be added separately.
- Logical plan is rewritten in analysis to replace non-binary strings with `CollationKey`
- `CollationKey` is a unary expression that transforms `StringType` to `BinaryType`
- Collation keys allow correct & efficient string comparison under specific collation rules

### Why are the changes needed?
Improve JOIN performance for collated strings.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?

- Unit tests for `CollationKey` in `CollationExpressionSuite`
- E2e SQL tests for `RewriteCollationJoin` in `CollationSuite`
- Various queries with JOIN in existing TPCDS collation test suite

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#46599 from uros-db/hash-join-str.

Authored-by: Uros Bojanic <[email protected]>
Signed-off-by: Wenchen Fan <[email protected]>
  • Loading branch information
uros-db authored and cloud-fan committed May 29, 2024
1 parent a86bca1 commit e6236af
Show file tree
Hide file tree
Showing 6 changed files with 264 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -817,4 +817,15 @@ public static UTF8String getCollationKey(UTF8String input, int collationId) {
}
}

public static byte[] getCollationKeyBytes(UTF8String input, int collationId) {
Collation collation = fetchCollation(collationId);
if (collation.supportsBinaryEquality) {
return input.getBytes();
} else if (collation.supportsLowercaseEquality) {
return input.toLowerCase().getBytes();
} else {
return collation.collator.getCollationKey(input.toString()).toByteArray();
}
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.analysis

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey, Equality}
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types.StringType

object RewriteCollationJoin extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case j @ Join(_, _, _, Some(condition), _) =>
val newCondition = condition transform {
case e @ Equality(l: AttributeReference, r: AttributeReference) =>
(l.dataType, r.dataType) match {
case (st: StringType, _: StringType)
if !CollationFactory.fetchCollation(st.collationId).supportsBinaryEquality =>
e.withNewChildren(Seq(CollationKey(l), CollationKey(r)))
case _ =>
e
}
}
if (!newCondition.fastEquals(condition)) {
j.copy(condition = Some(newCondition))
} else {
j
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.internal.types.StringTypeAnyCollation
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)
override def dataType: DataType = BinaryType

final lazy val collationId: Int = expr.dataType match {
case st: StringType =>
st.collationId
}

override def nullSafeEval(input: Any): Any =
CollationFactory.getCollationKeyBytes(input.asInstanceOf[UTF8String], collationId)

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
defineCodeGen(ctx, ev, c => s"CollationFactory.getCollationKeyBytes($c, $collationId)")
}

override protected def withNewChildInternal(newChild: Expression): Expression = {
copy(expr = newChild)
}

override def child: Expression = expr
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
test("validate default collation") {
Expand Down Expand Up @@ -163,6 +164,31 @@ class CollationExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}

test("CollationKey generates correct collation key for collated string") {
val testCases = Seq(
("", "UTF8_BINARY", UTF8String.fromString("").getBytes),
("aa", "UTF8_BINARY", UTF8String.fromString("aa").getBytes),
("AA", "UTF8_BINARY", UTF8String.fromString("AA").getBytes),
("aA", "UTF8_BINARY", UTF8String.fromString("aA").getBytes),
("", "UTF8_BINARY_LCASE", UTF8String.fromString("").getBytes),
("aa", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
("AA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
("aA", "UTF8_BINARY_LCASE", UTF8String.fromString("aa").getBytes),
("", "UNICODE", UTF8String.fromString("").getBytes),
("aa", "UNICODE", UTF8String.fromString("aa").getBytes),
("AA", "UNICODE", UTF8String.fromString("AA").getBytes),
("aA", "UNICODE", UTF8String.fromString("aA").getBytes),
("", "UNICODE_CI", Array[Byte](1, 0)),
("aa", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
("AA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0)),
("aA", "UNICODE_CI", Array[Byte](42, 42, 1, 6, 0))
)
for ((input, collation, expected) <- testCases) {
val str = Literal.create(input, StringType(collation))
checkEvaluation(CollationKey(str), expected)
}
}

test("collation name normalization in collation expression") {
Seq(
("en_USA", "en_USA"),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.apache.spark.sql.execution

import org.apache.spark.sql.ExperimentalMethods
import org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.optimizer._
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
Expand Down Expand Up @@ -92,7 +93,8 @@ class SparkOptimizer(
EliminateLimits,
ConstantFolding) :+
Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*) :+
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition)
Batch("Replace CTE with Repartition", Once, ReplaceCTERefWithRepartition) :+
Batch("RewriteCollationJoin", Once, RewriteCollationJoin)

override def nonExcludableRules: Seq[String] = super.nonExcludableRules :+
ExtractPythonUDFFromJoinCondition.ruleName :+
Expand Down
166 changes: 132 additions & 34 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
Expand All @@ -30,8 +30,8 @@ import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.execution.joins._
import org.apache.spark.sql.internal.{SqlApiConf, SQLConf}
import org.apache.spark.sql.internal.types.{AbstractMapType, StringTypeAnyCollation}
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

Expand Down Expand Up @@ -769,37 +769,6 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
})
}

test("hash based joins not allowed for non-binary collated strings") {
val in = (('a' to 'z') ++ ('A' to 'Z')).map(_.toString * 3).map(e => Row.apply(e, e))

val schema = StructType(StructField(
"col_non_binary",
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))) ::
StructField("col_binary", StringType) :: Nil)
val df1 = spark.createDataFrame(sparkContext.parallelize(in), schema)

// Binary collations are allowed to use hash join.
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_binary") === df1("col_binary"))
.queryExecution.executedPlan) {
case _: BroadcastHashJoinExec => ()
}.nonEmpty)

// Even with hint broadcast, hash join is not used for non-binary collated strings.
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution.executedPlan) {
case _: BroadcastHashJoinExec => ()
}.isEmpty)

// Instead they will default to sort merge join.
assert(collectFirst(
df1.hint("broadcast").join(df1, df1("col_non_binary") === df1("col_non_binary"))
.queryExecution.executedPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty)
}

test("Generated column expressions using collations - errors out") {
checkError(
exception = intercept[AnalysisException] {
Expand Down Expand Up @@ -1030,6 +999,135 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("hash join should be used for collated strings") {
val t1 = "T_1"
val t2 = "T_2"

case class HashJoinTestCase[R](collation: String, result: R)
val testCases = Seq(
HashJoinTestCase("UTF8_BINARY", Seq(Row("aa", 1, "aa", 2))),
HashJoinTestCase("UTF8_BINARY_LCASE", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2))),
HashJoinTestCase("UNICODE", Seq(Row("aa", 1, "aa", 2))),
HashJoinTestCase("UNICODE_CI", Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))
)

testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x STRING COLLATE ${t.collation}, i int) USING PARQUET")
sql(s"INSERT INTO $t1 VALUES ('aa', 1)")

sql(s"CREATE TABLE $t2 (y STRING COLLATE ${t.collation}, j int) USING PARQUET")
sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")

val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that hash join is used instead of sort merge join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.nonEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.isEmpty
)

// if collation doesn't support binary equality, collation key should be injected
if (!CollationFactory.fetchCollation(t.collation).supportsBinaryEquality) {
assert(collectFirst(queryPlan) {
case b: HashJoin => b.leftKeys.head
}.head.isInstanceOf[CollationKey])
}
}
})
}

test("rewrite with collationkey should be an excludable rule") {
val t1 = "T_1"
val t2 = "T_2"
val collation = "UTF8_BINARY_LCASE"
val collationRewriteJoinRule = "org.apache.spark.sql.catalyst.analysis.RewriteCollationJoin"
withTable(t1, t2) {
withSQLConf(SQLConf.OPTIMIZER_EXCLUDED_RULES.key -> collationRewriteJoinRule) {
sql(s"CREATE TABLE $t1 (x STRING COLLATE $collation, i int) USING PARQUET")
sql(s"INSERT INTO $t1 VALUES ('aa', 1)")

sql(s"CREATE TABLE $t2 (y STRING COLLATE $collation, j int) USING PARQUET")
sql(s"INSERT INTO $t2 VALUES ('AA', 2), ('aa', 2)")

val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.y")
checkAnswer(df, Seq(Row("aa", 1, "AA", 2), Row("aa", 1, "aa", 2)))

val queryPlan = df.queryExecution.executedPlan

// confirm that shuffle join is used instead of hash join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.isEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.nonEmpty
)
}
}
}

test("rewrite with collationkey shouldn't disrupt multiple join conditions") {
val t1 = "T_1"
val t2 = "T_2"

case class HashMultiJoinTestCase[R](
type1: String,
type2: String,
data1: String,
data2: String,
result: R
)
val testCases = Seq(
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "INT",
"'a', 0, 1", "'a', 0, 1", Row("a", 0, 1, "a", 0, 1)),
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY",
"'a', 'a', 1", "'a', 'a', 1", Row("a", "a", 1, "a", "a", 1)),
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY", "STRING COLLATE UTF8_BINARY_LCASE",
"'a', 'a', 1", "'a', 'A', 1", Row("a", "a", 1, "a", "A", 1)),
HashMultiJoinTestCase("STRING COLLATE UTF8_BINARY_LCASE", "STRING COLLATE UNICODE_CI",
"'a', 'a', 1", "'A', 'A', 1", Row("a", "a", 1, "A", "A", 1))
)

testCases.foreach(t => {
withTable(t1, t2) {
sql(s"CREATE TABLE $t1 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
sql(s"INSERT INTO $t1 VALUES (${t.data1})")
sql(s"CREATE TABLE $t2 (x ${t.type1}, y ${t.type2}, i int) USING PARQUET")
sql(s"INSERT INTO $t2 VALUES (${t.data2})")

val df = sql(s"SELECT * FROM $t1 JOIN $t2 ON $t1.x = $t2.x AND $t1.y = $t2.y")
checkAnswer(df, t.result)

val queryPlan = df.queryExecution.executedPlan

// confirm that hash join is used instead of sort merge join
assert(
collectFirst(queryPlan) {
case _: HashJoin => ()
}.nonEmpty
)
assert(
collectFirst(queryPlan) {
case _: SortMergeJoinExec => ()
}.isEmpty
)
}
})
}

test("hll sketch aggregate should respect collation") {
case class HllSketchAggTestCase[R](c: String, result: R)
val testCases = Seq(
Expand Down

0 comments on commit e6236af

Please sign in to comment.