forked from apache/spark
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[SPARK-48000][SQL] Enable hash join support for all collations (StringType)
### What changes were proposed in this pull request? Enable collation support for hash join on StringType. Note: support for complex types will be added separately. - Logical plan is rewritten in analysis to replace non-binary strings with `CollationKey` - `CollationKey` is a unary expression that transforms `StringType` to `BinaryType` - Collation keys allow correct & efficient string comparison under specific collation rules ### Why are the changes needed? Improve JOIN performance for collated strings. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? - Unit tests for `CollationKey` in `CollationExpressionSuite` - E2e SQL tests for `RewriteCollationJoin` in `CollationSuite` - Various queries with JOIN in existing TPCDS collation test suite ### Was this patch authored or co-authored using generative AI tooling? No. Closes apache#46599 from uros-db/hash-join-str. Authored-by: Uros Bojanic <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
- Loading branch information
Showing
6 changed files
with
264 additions
and
35 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
45 changes: 45 additions & 0 deletions
45
...catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RewriteCollationJoin.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.analysis | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, CollationKey, Equality} | ||
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan} | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.catalyst.util.CollationFactory | ||
import org.apache.spark.sql.types.StringType | ||
|
||
/**
 * Rewrites the conditions of [[Join]] nodes so that equality comparisons between two
 * string attributes whose collation does not support binary equality are performed on
 * [[CollationKey]]s instead of on the raw strings.
 *
 * Only equalities whose two operands are both plain [[AttributeReference]]s are rewritten;
 * every other part of the join condition is left untouched. Joins without a condition are
 * not matched at all.
 */
object RewriteCollationJoin extends Rule[LogicalPlan] {

  /**
   * Applies the rewrite to every conditional join in `plan`.
   *
   * @param plan the logical plan to rewrite
   * @return the plan with affected join conditions replaced; a join whose condition is
   *         unaffected is returned as the same node
   */
  def apply(plan: LogicalPlan): LogicalPlan = plan.transform {
    case join @ Join(_, _, _, Some(joinCondition), _) =>
      val rewritten = joinCondition.transform {
        case equality @ Equality(lhs: AttributeReference, rhs: AttributeReference)
            if needsCollationKey(lhs, rhs) =>
          equality.withNewChildren(Seq(CollationKey(lhs), CollationKey(rhs)))
      }
      // Keep the original node when nothing changed, so no needless copy is created.
      if (rewritten.fastEquals(joinCondition)) join else join.copy(condition = Some(rewritten))
  }

  /**
   * Returns true when both attributes are strings and the collation of the left-hand one
   * does not support binary equality. Only the left-hand collation is consulted.
   */
  private def needsCollationKey(lhs: AttributeReference, rhs: AttributeReference): Boolean =
    (lhs.dataType, rhs.dataType) match {
      case (leftType: StringType, _: StringType) =>
        !CollationFactory.fetchCollation(leftType.collationId).supportsBinaryEquality
      case _ =>
        false
    }
}
47 changes: 47 additions & 0 deletions
47
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/CollationKey.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.expressions | ||
|
||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} | ||
import org.apache.spark.sql.catalyst.util.CollationFactory | ||
import org.apache.spark.sql.internal.types.StringTypeAnyCollation | ||
import org.apache.spark.sql.types._ | ||
import org.apache.spark.unsafe.types.UTF8String | ||
|
||
/**
 * Unary expression that turns a string value (of any collation) into a binary value by
 * delegating to `CollationFactory.getCollationKeyBytes` with the collation id of the
 * child's string type.
 *
 * @param expr the child expression; expected to produce a string of any collation
 */
case class CollationKey(expr: Expression) extends UnaryExpression with ExpectsInputTypes {

  override def child: Expression = expr

  override def dataType: DataType = BinaryType

  override def inputTypes: Seq[AbstractDataType] = Seq(StringTypeAnyCollation)

  /**
   * Collation id taken from the child's data type, computed on first access.
   * The match only handles [[StringType]]; no other child type has a case here.
   */
  final lazy val collationId: Int = expr.dataType match {
    case stringType: StringType => stringType.collationId
  }

  /** Interpreted evaluation: computes the collation key bytes for a non-null input. */
  override def nullSafeEval(input: Any): Any = {
    val str = input.asInstanceOf[UTF8String]
    CollationFactory.getCollationKeyBytes(str, collationId)
  }

  /** Code generation: emits the same `CollationFactory` call as the interpreted path. */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, childCode =>
      s"CollationFactory.getCollationKeyBytes($childCode, $collationId)")

  override protected def withNewChildInternal(newChild: Expression): Expression =
    copy(expr = newChild)
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters