diff --git a/integration_tests/src/main/python/datasourcev2_write_test.py b/integration_tests/src/main/python/datasourcev2_write_test.py
index 1f4bc133d2a..0d1091e97cd 100644
--- a/integration_tests/src/main/python/datasourcev2_write_test.py
+++ b/integration_tests/src/main/python/datasourcev2_write_test.py
@@ -18,7 +18,7 @@
 from data_gen import gen_df, decimal_gens, non_utc_allow
 from marks import *
 from spark_session import is_hive_available, is_spark_330_or_later, with_cpu_session, with_gpu_session
-from hive_parquet_write_test import _hive_bucket_gens, _hive_array_gens, _hive_struct_gens
+from hive_parquet_write_test import _hive_bucket_gens
 from hive_parquet_write_test import read_single_bucket
 
 _hive_write_conf = {
@@ -75,7 +75,7 @@ def write_hive_table(spark, out_table):
 
 @pytest.mark.skipif(not (is_hive_available() and is_spark_330_or_later()),
                     reason="Must have Hive on Spark 3.3+")
 @pytest.mark.parametrize('file_format', ['parquet', 'orc'])
-@pytest.mark.parametrize('gen', decimal_gens + _hive_array_gens + _hive_struct_gens)
+@pytest.mark.parametrize('gen', decimal_gens)
 def test_write_hive_bucketed_unsupported_types_fallback(spark_tmp_table_factory, file_format, gen):
     out_table = spark_tmp_table_factory.get()
diff --git a/integration_tests/src/main/python/hive_parquet_write_test.py b/integration_tests/src/main/python/hive_parquet_write_test.py
index e66b889a986..d0db8b4dcc0 100644
--- a/integration_tests/src/main/python/hive_parquet_write_test.py
+++ b/integration_tests/src/main/python/hive_parquet_write_test.py
@@ -25,11 +25,26 @@
 # "GpuInsertIntoHiveTable" for Parquet write.
 _write_to_hive_conf = {"spark.sql.hive.convertMetastoreParquet": False}
 
-_hive_bucket_gens = [
+_hive_bucket_basic_gens = [
     boolean_gen, byte_gen, short_gen, int_gen, long_gen, string_gen, float_gen, double_gen,
     DateGen(start=date(1590, 1, 1)), _restricted_timestamp()]
 
-_hive_basic_gens = _hive_bucket_gens + [
+_hive_bucket_basic_struct_gen = StructGen(
+    [['c'+str(ind), c_gen] for ind, c_gen in enumerate(_hive_bucket_basic_gens)])
+
+_hive_bucket_struct_gens = [
+    _hive_bucket_basic_struct_gen,
+    StructGen([['child0', byte_gen], ['child1', _hive_bucket_basic_struct_gen]]),
+    StructGen([['child0', ArrayGen(short_gen)], ['child1', double_gen]])]
+
+_hive_bucket_array_gens = [ArrayGen(sub_gen) for sub_gen in _hive_bucket_basic_gens] + [
+    ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10),
+    ArrayGen(ArrayGen(string_gen, max_length=10), max_length=10),
+    ArrayGen(StructGen([['child0', byte_gen], ['child1', string_gen], ['child2', float_gen]]))]
+
+_hive_bucket_gens = _hive_bucket_basic_gens + _hive_bucket_struct_gens + _hive_bucket_array_gens
+
+_hive_basic_gens = _hive_bucket_basic_gens + [
     DecimalGen(precision=19, scale=1, nullable=True),
     DecimalGen(precision=23, scale=5, nullable=True),
     DecimalGen(precision=36, scale=3, nullable=True)]
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
index 7a01329fef1..f9468f6f6a3 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/GpuOverrides.scala
@@ -3309,8 +3309,27 @@ object GpuOverrides extends Logging {
       "hive hash operator",
       ExprChecks.projectOnly(TypeSig.INT, TypeSig.INT,
         repeatingParamCheck = Some(RepeatingParamCheck("input",
-          TypeSig.commonCudfTypes + TypeSig.NULL, TypeSig.all))),
+          (TypeSig.commonCudfTypes + TypeSig.NULL + TypeSig.STRUCT + TypeSig.ARRAY).nested() +
+            TypeSig.psNote(TypeEnum.ARRAY, "Nested levels exceeding 8 layers are not supported") +
+            TypeSig.psNote(TypeEnum.STRUCT, "Nested levels exceeding 8 layers are not supported"),
+          TypeSig.all))),
       (a, conf, p, r) => new ExprMeta[HiveHash](a, conf, p, r) {
+        override def tagExprForGpu(): Unit = {
+          def getMaxNestedDepth(inputType: DataType): Int = {
+            inputType match {
+              case at: ArrayType => 1 + getMaxNestedDepth(at.elementType)
+              case st: StructType =>
+                1 + st.map(f => getMaxNestedDepth(f.dataType)).max
+              case _ => 1 // primitive types
+            }
+          }
+          val maxDepth = a.children.map(c => getMaxNestedDepth(c.dataType)).max
+          if (maxDepth > 8) {
+            willNotWorkOnGpu(s"GPU HiveHash supports 8 levels at most for " +
+              s"nested types, but got $maxDepth")
+          }
+        }
+
         def convertToGpu(): GpuExpression = GpuHiveHash(childExprs.map(_.convertToGpu()))
       }),
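
Note on the depth check: below is a minimal, self-contained sketch of the `getMaxNestedDepth` recursion added in `tagExprForGpu` above, so the counting convention can be sanity-checked outside the plugin. The `NestedDepthDemo` object and its `main` are hypothetical scaffolding for illustration, not part of the patch; only Spark's `DataType` classes are assumed. Primitives count as level 1 and each array or struct wrapper adds one, so the ninth wrapper is the first to trip the `maxDepth > 8` fallback.

```scala
import org.apache.spark.sql.types._

// Hypothetical demo object, not part of the patch.
object NestedDepthDemo {
  // Same recursion as the tagExprForGpu override in the patch: primitives
  // are depth 1, and each ArrayType/StructType wrapper adds one level.
  def getMaxNestedDepth(inputType: DataType): Int = inputType match {
    case at: ArrayType => 1 + getMaxNestedDepth(at.elementType)
    case st: StructType => 1 + st.map(f => getMaxNestedDepth(f.dataType)).max
    case _ => 1
  }

  def main(args: Array[String]): Unit = {
    println(getMaxNestedDepth(IntegerType))                       // 1: int
    println(getMaxNestedDepth(ArrayType(IntegerType)))            // 2: array<int>
    println(getMaxNestedDepth(ArrayType(ArrayType(IntegerType)))) // 3: array<array<int>>

    // struct<a: int, b: array<string>> -> 1 + max(1, 2) = 3
    val mixed = StructType(Seq(
      StructField("a", IntegerType),
      StructField("b", ArrayType(StringType))))
    println(getMaxNestedDepth(mixed))                             // 3

    // Eight array wrappers around int give depth 9, exceeding the limit,
    // so HiveHash over such a column would be tagged to fall back to the
    // CPU with "GPU HiveHash supports 8 levels at most for nested types".
    val tooDeep = (1 to 8).foldLeft(IntegerType: DataType)((t, _) => ArrayType(t))
    println(getMaxNestedDepth(tooDeep))                           // 9
  }
}
```

Under this convention the deepest generators added to `_hive_bucket_struct_gens` and `_hive_bucket_array_gens` (e.g. `ArrayGen(ArrayGen(short_gen, max_length=10), max_length=10)`) sit at depth 3, well inside the supported range, which is why they move out of the unsupported-types fallback test in `datasourcev2_write_test.py`.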