Add HiveHash support for nested types #11660

Draft
Wants to merge 1 commit into base: branch-24.12
integration_tests/src/main/python/datasourcev2_write_test.py (4 changes: 2 additions & 2 deletions)
@@ -18,7 +18,7 @@
 from data_gen import gen_df, decimal_gens, non_utc_allow
 from marks import *
 from spark_session import is_hive_available, is_spark_330_or_later, with_cpu_session, with_gpu_session
-from hive_parquet_write_test import _hive_bucket_gens, _hive_array_gens, _hive_struct_gens
+from hive_parquet_write_test import _hive_bucket_gens
 from hive_parquet_write_test import read_single_bucket

 _hive_write_conf = {
@@ -75,7 +75,7 @@ def write_hive_table(spark, out_table):
 @pytest.mark.skipif(not (is_hive_available() and is_spark_330_or_later()),
                     reason="Must have Hive on Spark 3.3+")
 @pytest.mark.parametrize('file_format', ['parquet', 'orc'])
-@pytest.mark.parametrize('gen', decimal_gens + _hive_array_gens + _hive_struct_gens)
+@pytest.mark.parametrize('gen', decimal_gens)
 def test_write_hive_bucketed_unsupported_types_fallback(spark_tmp_table_factory, file_format, gen):
     out_table = spark_tmp_table_factory.get()
integration_tests/src/main/python/hive_parquet_write_test.py (19 changes: 17 additions & 2 deletions)

@@ -25,11 +25,26 @@
 # "GpuInsertIntoHiveTable" for Parquet write.
 _write_to_hive_conf = {"spark.sql.hive.convertMetastoreParquet": False}

-_hive_bucket_gens = [
+_hive_bucket_basic_gens = [
     boolean_gen, byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen,
     DateGen(start=date(1590, 1, 1)), _restricted_timestamp()]

-_hive_basic_gens = _hive_bucket_gens + [
+_hive_bucket_basic_struct_gen = StructGen(
+    [['c'+str(ind), c_gen] for ind, c_gen in enumerate(_hive_bucket_basic_gens)])
+
+_hive_bucket_struct_gens = [
+    _hive_bucket_basic_struct_gen,
+    StructGen([['child0', byte_gen], ['child1', _hive_bucket_basic_struct_gen]]),
+    StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]
+
+_hive_bucket_array_gens = [ArrayGen(sub_gen) for sub_gen in _hive_bucket_basic_gens] + [
+    ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
+    ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
+    ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]
+
+_hive_bucket_gens = _hive_bucket_basic_gens + _hive_bucket_struct_gens + _hive_bucket_array_gens
+
+_hive_basic_gens = _hive_bucket_basic_gens + [
     DecimalGen(precision=19, scale=1, nullable=True),
     DecimalGen(precision=23, scale=5, nullable=True),
     DecimalGen(precision=36, scale=3, nullable=True)]
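For orientation, the deepest of the new generators above nests three levels once primitives are counted as one layer, well inside the 8-level GPU limit introduced in the GpuOverrides change below. A minimal Scala sketch of the corresponding Catalyst types, assuming only Spark's type API (the value names are illustrative, not part of the PR):

```scala
import org.apache.spark.sql.types._

// array<array<smallint>>, as produced by ArrayGen(ArrayGen(short_gen)):
// array (1) -> array (2) -> smallint (3) = depth 3.
val arrayOfArray: DataType = ArrayType(ArrayType(ShortType))

// array<struct<child0: tinyint, child1: string, child2: float>>, as produced
// by ArrayGen(StructGen([...])): array (1) -> struct (2) -> primitives (3).
val arrayOfStruct: DataType = ArrayType(StructType(Seq(
  StructField("child0", ByteType),
  StructField("child1", StringType),
  StructField("child2", FloatType))))
```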
@@ -3309,8 +3309,27 @@ object GpuOverrides extends Logging {
"hive hash operator",
ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
repeatingParamCheck = Some(RepeatingParamCheck("input",
TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all))),
(TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested() +
TypeSig.psNote(TypeEnum.ARRAY, "Nested levels exceeding 8 layers are not supported") +
TypeSig.psNote(TypeEnum.STRUCT, "Nested levels exceeding 8 layers are not supported"),
TypeSig.all))),
(a, conf, p, r) => new ExprMeta[HiveHash](a, conf, p, r) {
override def tagExprForGpu(): Unit = {
def getMaxNestedDepth(inputType: DataType): Int = {
inputType match {
case at: ArrayType => 1 + getMaxNestedDepth(at.elementType)
case st: StructType =>
1 + st.map(f => getMaxNestedDepth(f.dataType)).max
case _ => 1 // primitive types
}
}
val maxDepth = a.children.map(c => getMaxNestedDepth(c.dataType)).max
if (maxDepth > 8) {
willNotWorkOnGpu(s"GPU HiveHash supports 8 levels at most for " +
s"nested types, but got $maxDepth")
}
}

def convertToGpu(): GpuExpression =
GpuHiveHash(childExprs.map(_.convertToGpu()))
}),
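To make the new check concrete, here is a standalone sketch of the same depth rule, assuming only Spark's type API; `getMaxNestedDepth` mirrors the helper in the hunk above, while the object and value names are illustrative:

```scala
import org.apache.spark.sql.types._

object HiveHashDepthDemo {
  // Mirrors the tagExprForGpu helper: primitives count as one level and
  // every array or struct wrapper adds one more.
  def getMaxNestedDepth(inputType: DataType): Int = inputType match {
    case at: ArrayType  => 1 + getMaxNestedDepth(at.elementType)
    case st: StructType => 1 + st.map(f => getMaxNestedDepth(f.dataType)).max
    case _              => 1 // primitive types
  }

  def main(args: Array[String]): Unit = {
    // Eight ArrayType wrappers around an int give nine levels in total,
    // so an expression with this input would be tagged to stay on the CPU.
    val nineLevels = (1 to 8).foldLeft[DataType](IntegerType)((t, _) => ArrayType(t))
    println(getMaxNestedDepth(nineLevels)) // prints 9 (> 8)
  }
}
```

Inputs up to eight levels deep, such as the depth-3 generator types used in the integration tests, pass this check and stay on the GPU path.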