From c9d4685c440d1fc0f2369654bc5076ea29dbb55c Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Mon, 14 Aug 2023 23:59:01 +0800 Subject: [PATCH] GH-37028: [C++] Add support for duration types to if_else functions (#37064) ### Rationale for this change Support for duration types is missing in if else functions, including if_else, coalesce, choose and case_when. ### What changes are included in this PR? Add support for duration types to these functions. ### Are these changes tested? Yes. ### Are there any user-facing changes? No. * Closes: #37028 Authored-by: Jin Shang Signed-off-by: Antoine Pitrou --- cpp/src/arrow/compute/kernels/codegen_internal.cc | 10 +++++++++- cpp/src/arrow/compute/kernels/scalar_if_else.cc | 4 ++++ .../arrow/compute/kernels/scalar_if_else_test.cc | 15 +++++++++++++-- cpp/src/arrow/compute/kernels/test_util.h | 5 ++++- 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/codegen_internal.cc b/cpp/src/arrow/compute/kernels/codegen_internal.cc index e0156caecfa5d..8e2669bd3dfb9 100644 --- a/cpp/src/arrow/compute/kernels/codegen_internal.cc +++ b/cpp/src/arrow/compute/kernels/codegen_internal.cc @@ -250,7 +250,7 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) { const std::string* timezone = nullptr; bool saw_date32 = false; bool saw_date64 = false; - + bool saw_duration = false; const TypeHolder* end = begin + count; for (auto it = begin; it != end; it++) { auto id = it->type->id(); @@ -271,6 +271,12 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) { finest_unit = std::max(finest_unit, ty.unit()); continue; } + case Type::DURATION: { + const auto& ty = checked_cast(*it->type); + finest_unit = std::max(finest_unit, ty.unit()); + saw_duration = true; + continue; + } default: return TypeHolder(nullptr); } @@ -283,6 +289,8 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) { return date64(); } else if (saw_date32) { return date32(); + } else if (saw_duration) { + return duration(finest_unit); } return TypeHolder(nullptr); } diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else.cc b/cpp/src/arrow/compute/kernels/scalar_if_else.cc index 0dd176b5d4099..6b4b2339e4afe 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else.cc @@ -2798,6 +2798,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveIfElseKernels(func, NumericTypes()); AddPrimitiveIfElseKernels(func, TemporalTypes()); AddPrimitiveIfElseKernels(func, IntervalTypes()); + AddPrimitiveIfElseKernels(func, DurationTypes()); AddPrimitiveIfElseKernels(func, {boolean()}); AddNullIfElseKernel(func); AddBinaryIfElseKernels(func, BaseBinaryTypes()); @@ -2813,6 +2814,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveCaseWhenKernels(func, NumericTypes()); AddPrimitiveCaseWhenKernels(func, TemporalTypes()); AddPrimitiveCaseWhenKernels(func, IntervalTypes()); + AddPrimitiveCaseWhenKernels(func, DurationTypes()); AddPrimitiveCaseWhenKernels(func, {boolean(), null()}); AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY, CaseWhenFunctor::Exec); @@ -2836,6 +2838,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveCoalesceKernels(func, NumericTypes()); AddPrimitiveCoalesceKernels(func, TemporalTypes()); AddPrimitiveCoalesceKernels(func, IntervalTypes()); + AddPrimitiveCoalesceKernels(func, DurationTypes()); AddPrimitiveCoalesceKernels(func, {boolean(), null()}); AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY, CoalesceFunctor::Exec); @@ -2861,6 +2864,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) { AddPrimitiveChooseKernels(func, NumericTypes()); AddPrimitiveChooseKernels(func, TemporalTypes()); AddPrimitiveChooseKernels(func, IntervalTypes()); + AddPrimitiveChooseKernels(func, DurationTypes()); AddPrimitiveChooseKernels(func, {boolean(), null()}); AddChooseKernel(func, Type::FIXED_SIZE_BINARY, ChooseFunctor::Exec); diff --git a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc index ded73f0371435..a9c5a1fc3c96f 100644 --- a/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc +++ b/cpp/src/arrow/compute/kernels/scalar_if_else_test.cc @@ -73,7 +73,7 @@ class TestIfElsePrimitive : public ::testing::Test {}; #ifdef ARROW_VALGRIND using IfElseNumericBasedTypes = ::testing::Types; + MonthIntervalType, DurationType>; using BaseBinaryArrowTypes = ::testing::Types; using ListArrowTypes = ::testing::Types; using IntegralArrowTypes = ::testing::Types; @@ -81,7 +81,8 @@ using IntegralArrowTypes = ::testing::Types; using IfElseNumericBasedTypes = ::testing::Types; + Time32Type, Time64Type, TimestampType, MonthIntervalType, + DurationType>; #endif TYPED_TEST_SUITE(TestIfElsePrimitive, IfElseNumericBasedTypes); @@ -505,6 +506,9 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) { {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)}, {boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); + CheckDispatchBest(name, + {boolean(), duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)}, + {boolean(), duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)}); CheckDispatchBest(name, {boolean(), date32(), date64()}, {boolean(), date64(), date64()}); CheckDispatchBest(name, {boolean(), date32(), date32()}, @@ -2500,6 +2504,11 @@ TEST(TestCaseWhen, DispatchBest) { {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()}, {struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)}); + CheckDispatchBest("case_when", + {struct_({field("", boolean())}), duration(TimeUnit::SECOND), + duration(TimeUnit::MILLI)}, + {struct_({field("", boolean())}), duration(TimeUnit::MILLI), + duration(TimeUnit::MILLI)}); CheckDispatchBest( "case_when", {struct_({field("", boolean())}), decimal128(38, 0), decimal128(1, 1)}, {struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)}); @@ -3350,6 +3359,8 @@ TEST(TestCoalesce, DispatchBest) { {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)}); CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)}, {timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)}); + CheckDispatchBest("coalesce", {duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)}, + {duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)}); CheckDispatchFails("coalesce", { sparse_union({field("a", boolean())}), dense_union({field("a", boolean())}), diff --git a/cpp/src/arrow/compute/kernels/test_util.h b/cpp/src/arrow/compute/kernels/test_util.h index 73762a1ac6758..11e77caeff861 100644 --- a/cpp/src/arrow/compute/kernels/test_util.h +++ b/cpp/src/arrow/compute/kernels/test_util.h @@ -185,7 +185,10 @@ template enable_if_decimal> default_type_instance() { return std::make_shared(5, 2); } - +template +enable_if_duration> default_type_instance() { + return std::make_shared(TimeUnit::type::SECOND); +} // Random Generator Helpers class RandomImpl { protected: