diff --git a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
index cb4fffc1a3..350aaf7ade 100644
--- a/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
+++ b/spark/src/main/scala/org/apache/comet/serde/QueryPlanSerde.scala
@@ -2176,35 +2176,9 @@ object QueryPlanSerde extends Logging with ShimQueryPlanSerde with CometExprShim
           None
         }
 
-      case Murmur3Hash(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(IntegerType).get)
-          .setIntVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
-
-      case XxHash64(children, seed) =>
-        val firstUnSupportedInput = children.find(c => !supportedDataType(c.dataType))
-        if (firstUnSupportedInput.isDefined) {
-          withInfo(expr, s"Unsupported datatype ${firstUnSupportedInput.get.dataType}")
-          return None
-        }
-        val exprs = children.map(exprToProtoInternal(_, inputs, binding))
-        val seedBuilder = ExprOuterClass.Literal
-          .newBuilder()
-          .setDatatype(serializeDataType(LongType).get)
-          .setLongVal(seed)
-        val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
-        // the seed is put at the end of the arguments
-        scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+      case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)
+
+      case _: XxHash64 => CometXxHash64.convert(expr, inputs, binding)
 
       case Sha2(left, numBits) =>
         if (!numBits.foldable) {
diff --git a/spark/src/main/scala/org/apache/comet/serde/hash.scala b/spark/src/main/scala/org/apache/comet/serde/hash.scala
new file mode 100644
index 0000000000..226c4bab05
--- /dev/null
+++ b/spark/src/main/scala/org/apache/comet/serde/hash.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet.serde
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Murmur3Hash, XxHash64}
+import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}
+
+import org.apache.comet.CometSparkSessionExtensions.withInfo
+import org.apache.comet.serde.QueryPlanSerde.{exprToProtoInternal, scalarExprToProtoWithReturnType, serializeDataType, supportedDataType}
+
+object CometXxHash64 extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (!HashUtils.isSupportedType(expr)) {
+      return None
+    }
+    val hash = expr.asInstanceOf[XxHash64]
+    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
+    val seedBuilder = ExprOuterClass.Literal
+      .newBuilder()
+      .setDatatype(serializeDataType(LongType).get)
+      .setLongVal(hash.seed)
+    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+    // the seed is put at the end of the arguments
+    scalarExprToProtoWithReturnType("xxhash64", LongType, exprs :+ seedExpr: _*)
+  }
+}
+
+object CometMurmur3Hash extends CometExpressionSerde {
+  override def convert(
+      expr: Expression,
+      inputs: Seq[Attribute],
+      binding: Boolean): Option[ExprOuterClass.Expr] = {
+    if (!HashUtils.isSupportedType(expr)) {
+      return None
+    }
+    val hash = expr.asInstanceOf[Murmur3Hash]
+    val exprs = hash.children.map(exprToProtoInternal(_, inputs, binding))
+    val seedBuilder = ExprOuterClass.Literal
+      .newBuilder()
+      .setDatatype(serializeDataType(IntegerType).get)
+      .setIntVal(hash.seed)
+    val seedExpr = Some(ExprOuterClass.Expr.newBuilder().setLiteral(seedBuilder).build())
+    // the seed is put at the end of the arguments
+    scalarExprToProtoWithReturnType("murmur3_hash", IntegerType, exprs :+ seedExpr: _*)
+  }
+}
+
+private object HashUtils {
+  def isSupportedType(expr: Expression): Boolean = {
+    for (child <- expr.children) {
+      child.dataType match {
+        case dt: DecimalType if dt.precision > 18 =>
+          // Spark converts decimals with precision > 18 into
+          // Java BigDecimal before hashing
+          withInfo(expr, s"Unsupported datatype: $dt (precision > 18)")
+          return false
+        case dt if !supportedDataType(dt) =>
+          withInfo(expr, s"Unsupported datatype $dt")
+          return false
+        case _ =>
+      }
+    }
+    true
+  }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index f82101b3ac..f226afeaab 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1929,19 +1929,45 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
-  test("hash functions with decimal input") {
-    withTable("t1", "t2") {
-      // Apache Spark: if it's a small decimal, i.e. precision <= 18, turn it into long and hash it.
-      // Else, turn it into bytes and hash it.
-      sql("create table t1(c1 decimal(18, 2)) using parquet")
-      sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
-      checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t1 order by c1")
-
-      // TODO: comet hash function is not compatible with spark for decimal with precision greater than 18.
-      // https://github.com/apache/datafusion-comet/issues/1294
-//      sql("create table t2(c1 decimal(20, 2)) using parquet")
-//      sql("insert into t2 values(1.23), (-1.23), (0.0), (null)")
-//      checkSparkAnswerAndOperator("select c1, hash(c1), xxhash64(c1) from t2 order by c1")
+  test("hash function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, hash(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, hash(c1) from t1 order by c1")
+        }
+      }
+    }
+  }
+
+  test("xxhash64 function with decimal input") {
+    val testPrecisionScales: Seq[(Int, Int)] = Seq(
+      (1, 0),
+      (17, 2),
+      (18, 2),
+      (19, 2),
+      (DecimalType.MAX_PRECISION, DecimalType.MAX_SCALE - 1))
+    for ((p, s) <- testPrecisionScales) {
+      withTable("t1") {
+        sql(s"create table t1(c1 decimal($p, $s)) using parquet")
+        sql("insert into t1 values(1.23), (-1.23), (0.0), (null)")
+        if (p <= 18) {
+          checkSparkAnswerAndOperator("select c1, xxhash64(c1) from t1 order by c1")
+        } else {
+          // not supported natively yet
+          checkSparkAnswer("select c1, xxhash64(c1) from t1 order by c1")
+        }
+      }
     }
   }
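Note: the two new objects in hash.scala plug into the dispatch in QueryPlanSerde through the existing `CometExpressionSerde` interface, which is not shown in this diff. A minimal sketch of the shape implied by the `convert` overrides above (the actual trait in the Comet code base may differ in documentation and location) is:

```scala
package org.apache.comet.serde

import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}

// Sketch only: inferred from the overrides in hash.scala, not copied from the repo.
trait CometExpressionSerde {

  /**
   * Convert a Spark Catalyst expression into a protobuf expression that the
   * native engine can execute. Implementations are expected to call withInfo
   * and return None when the expression (or one of its inputs) is unsupported,
   * so that the operator falls back to Spark.
   */
  def convert(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr]
}
```

With that shape, QueryPlanSerde only needs the one-line `case _: Murmur3Hash => CometMurmur3Hash.convert(expr, inputs, binding)` style dispatch shown above, and per-expression validation such as the precision-18 decimal check lives next to the serialization logic in hash.scala.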