Skip to content

Commit

Permalink
apacheGH-37028: [C++] Add support for duration types to if_else funct…
Browse files Browse the repository at this point in the history
…ions (apache#37064)

### Rationale for this change

Support for duration types is missing from the if-else family of functions: `if_else`, `coalesce`, `choose`, and `case_when`.

### What changes are included in this PR?
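Add support for duration types to these functions: register duration kernels for `if_else`, `case_when`, `coalesce`, and `choose`, and extend `CommonTemporal` so that durations with different units resolve to a duration of the finest unit.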

### Are these changes tested?
Yes.

### Are there any user-facing changes?

No.

* Closes: apache#37028

Authored-by: Jin Shang <[email protected]>
Signed-off-by: Antoine Pitrou <[email protected]>
  • Loading branch information
js8544 authored Aug 14, 2023
1 parent 2e5694c commit c9d4685
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 4 deletions.
10 changes: 9 additions & 1 deletion cpp/src/arrow/compute/kernels/codegen_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
const std::string* timezone = nullptr;
bool saw_date32 = false;
bool saw_date64 = false;

bool saw_duration = false;
const TypeHolder* end = begin + count;
for (auto it = begin; it != end; it++) {
auto id = it->type->id();
Expand All @@ -271,6 +271,12 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
finest_unit = std::max(finest_unit, ty.unit());
continue;
}
case Type::DURATION: {
const auto& ty = checked_cast<const DurationType&>(*it->type);
finest_unit = std::max(finest_unit, ty.unit());
saw_duration = true;
continue;
}
default:
return TypeHolder(nullptr);
}
Expand All @@ -283,6 +289,8 @@ TypeHolder CommonTemporal(const TypeHolder* begin, size_t count) {
return date64();
} else if (saw_date32) {
return date32();
} else if (saw_duration) {
return duration(finest_unit);
}
return TypeHolder(nullptr);
}
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/compute/kernels/scalar_if_else.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2798,6 +2798,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveIfElseKernels(func, NumericTypes());
AddPrimitiveIfElseKernels(func, TemporalTypes());
AddPrimitiveIfElseKernels(func, IntervalTypes());
AddPrimitiveIfElseKernels(func, DurationTypes());
AddPrimitiveIfElseKernels(func, {boolean()});
AddNullIfElseKernel(func);
AddBinaryIfElseKernels(func, BaseBinaryTypes());
Expand All @@ -2813,6 +2814,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCaseWhenKernels(func, NumericTypes());
AddPrimitiveCaseWhenKernels(func, TemporalTypes());
AddPrimitiveCaseWhenKernels(func, IntervalTypes());
AddPrimitiveCaseWhenKernels(func, DurationTypes());
AddPrimitiveCaseWhenKernels(func, {boolean(), null()});
AddCaseWhenKernel(func, Type::FIXED_SIZE_BINARY,
CaseWhenFunctor<FixedSizeBinaryType>::Exec);
Expand All @@ -2836,6 +2838,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveCoalesceKernels(func, NumericTypes());
AddPrimitiveCoalesceKernels(func, TemporalTypes());
AddPrimitiveCoalesceKernels(func, IntervalTypes());
AddPrimitiveCoalesceKernels(func, DurationTypes());
AddPrimitiveCoalesceKernels(func, {boolean(), null()});
AddCoalesceKernel(func, Type::FIXED_SIZE_BINARY,
CoalesceFunctor<FixedSizeBinaryType>::Exec);
Expand All @@ -2861,6 +2864,7 @@ void RegisterScalarIfElse(FunctionRegistry* registry) {
AddPrimitiveChooseKernels(func, NumericTypes());
AddPrimitiveChooseKernels(func, TemporalTypes());
AddPrimitiveChooseKernels(func, IntervalTypes());
AddPrimitiveChooseKernels(func, DurationTypes());
AddPrimitiveChooseKernels(func, {boolean(), null()});
AddChooseKernel(func, Type::FIXED_SIZE_BINARY,
ChooseFunctor<FixedSizeBinaryType>::Exec);
Expand Down
15 changes: 13 additions & 2 deletions cpp/src/arrow/compute/kernels/scalar_if_else_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,16 @@ class TestIfElsePrimitive : public ::testing::Test {};
#ifdef ARROW_VALGRIND
using IfElseNumericBasedTypes =
::testing::Types<UInt32Type, FloatType, Date32Type, Time32Type, TimestampType,
MonthIntervalType>;
MonthIntervalType, DurationType>;
using BaseBinaryArrowTypes = ::testing::Types<BinaryType>;
using ListArrowTypes = ::testing::Types<ListType>;
using IntegralArrowTypes = ::testing::Types<Int32Type>;
#else
using IfElseNumericBasedTypes =
::testing::Types<UInt8Type, UInt16Type, UInt32Type, UInt64Type, Int8Type, Int16Type,
Int32Type, Int64Type, FloatType, DoubleType, Date32Type, Date64Type,
Time32Type, Time64Type, TimestampType, MonthIntervalType>;
Time32Type, Time64Type, TimestampType, MonthIntervalType,
DurationType>;
#endif

TYPED_TEST_SUITE(TestIfElsePrimitive, IfElseNumericBasedTypes);
Expand Down Expand Up @@ -505,6 +506,9 @@ TEST_F(TestIfElseKernel, IfElseDispatchBest) {
{boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), timestamp(TimeUnit::MILLI)},
{boolean(), timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest(name,
{boolean(), duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)},
{boolean(), duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
CheckDispatchBest(name, {boolean(), date32(), date64()},
{boolean(), date64(), date64()});
CheckDispatchBest(name, {boolean(), date32(), date32()},
Expand Down Expand Up @@ -2500,6 +2504,11 @@ TEST(TestCaseWhen, DispatchBest) {
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND), date32()},
{struct_({field("", boolean())}), timestamp(TimeUnit::SECOND),
timestamp(TimeUnit::SECOND)});
CheckDispatchBest("case_when",
{struct_({field("", boolean())}), duration(TimeUnit::SECOND),
duration(TimeUnit::MILLI)},
{struct_({field("", boolean())}), duration(TimeUnit::MILLI),
duration(TimeUnit::MILLI)});
CheckDispatchBest(
"case_when", {struct_({field("", boolean())}), decimal128(38, 0), decimal128(1, 1)},
{struct_({field("", boolean())}), decimal256(39, 1), decimal256(39, 1)});
Expand Down Expand Up @@ -3350,6 +3359,8 @@ TEST(TestCoalesce, DispatchBest) {
{timestamp(TimeUnit::SECOND), timestamp(TimeUnit::SECOND)});
CheckDispatchBest("coalesce", {timestamp(TimeUnit::SECOND), timestamp(TimeUnit::MILLI)},
{timestamp(TimeUnit::MILLI), timestamp(TimeUnit::MILLI)});
CheckDispatchBest("coalesce", {duration(TimeUnit::SECOND), duration(TimeUnit::MILLI)},
{duration(TimeUnit::MILLI), duration(TimeUnit::MILLI)});
CheckDispatchFails("coalesce", {
sparse_union({field("a", boolean())}),
dense_union({field("a", boolean())}),
Expand Down
5 changes: 4 additions & 1 deletion cpp/src/arrow/compute/kernels/test_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,10 @@ template <typename T>
enable_if_decimal<T, std::shared_ptr<DataType>> default_type_instance() {
return std::make_shared<T>(5, 2);
}

/// \brief Build the default type instance for a duration type `T`.
///
/// This overload is selected only when `T` satisfies `enable_if_duration`.
/// Duration types need a time unit to be constructed; this helper always
/// uses seconds.
///
/// \return a shared pointer to a newly constructed `T` with unit SECOND
template <typename T>
enable_if_duration<T, std::shared_ptr<DataType>> default_type_instance() {
  // Fixed unit used for every default-constructed duration instance.
  constexpr auto kDefaultUnit = TimeUnit::type::SECOND;
  return std::make_shared<T>(kDefaultUnit);
}
// Random Generator Helpers
class RandomImpl {
protected:
Expand Down

0 comments on commit c9d4685

Please sign in to comment.