diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 049607b3ea5c8..98eb37e9c5fd2 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -750,13 +750,14 @@ struct SparseUnionSelectionImpl LIFT_BASE_MEMBERS(); TypedBufferBuilder child_id_buffer_builder_; - std::vector type_codes_; + const int8_t type_code_for_null_; SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, int64_t output_length, ExecResult* out) : Base(ctx, batch, output_length, out), child_id_buffer_builder_(ctx->memory_pool()), - type_codes_(checked_cast(*this->values.type).type_codes()) {} + type_code_for_null_( + checked_cast(*this->values.type).type_codes()[0]) {} template Status GenerateOutput() { @@ -768,7 +769,7 @@ struct SparseUnionSelectionImpl return Status::OK(); }, [&]() { - child_id_buffer_builder_.UnsafeAppend(type_codes_[0]); + child_id_buffer_builder_.UnsafeAppend(type_code_for_null_); return Status::OK(); })); return Status::OK(); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index ac5bfe449c49c..30e85c1f71089 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -282,11 +282,6 @@ class TestFilterKernel : public ::testing::Test { const std::shared_ptr& expected) { DoAssertFilter(values, filter, expected); - if (values->type_id() == Type::DENSE_UNION) { - // Concatenation of dense union not supported - return; - } - // Check slicing: add M(=3) dummy values at the start and end of `values`, // add N(=2) dummy values at the start and end of `filter`. ARROW_SCOPED_TRACE("for sliced values and filter"); @@ -785,17 +780,6 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) { [2, null] ])"); this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); - - // Sliced - // (check this manually as concatenation of dense unions isn't supported: ARROW-4975) - auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4); - auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4); - auto expected = ArrayFromJSON(union_type, R"([ - [5, "hello"], - [2, null], - [2, 111] - ])"); - this->AssertFilter(values, filter, expected); } } @@ -1029,13 +1013,11 @@ void CheckTake(const std::shared_ptr& type, const std::string& values_ AssertTakeArrays(values, indices, expected); // Check sliced values - if (type->id() != Type::DENSE_UNION) { - ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2)); - ASSERT_OK_AND_ASSIGN(auto values_sliced, - Concatenate({values_filler, values, values_filler})); - values_sliced = values_sliced->Slice(2, values->length()); - AssertTakeArrays(values_sliced, indices, expected); - } + ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2)); + ASSERT_OK_AND_ASSIGN(auto values_sliced, + Concatenate({values_filler, values, values_filler})); + values_sliced = values_sliced->Slice(2, values->length()); + AssertTakeArrays(values_sliced, indices, expected); // Check sliced indices ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0})); @@ -1484,8 +1466,8 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { auto union_json = R"([ - [2, null], [2, 222], + [2, null], [5, "hello"], [5, "eh"], [2, null], @@ -1493,21 +1475,21 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [5, null] ])"; CheckTake(union_type, union_json, "[]", "[]"); - CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ + CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([ [5, "eh"], [2, 222], [5, "eh"], [2, 222], [5, "eh"] ])"); - CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([ + CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([ [2, null], [5, "hello"], [2, 222], [5, null] ])"); CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); - CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([ [2, null], [5, "hello"], [5, "hello"], @@ -1516,6 +1498,15 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [5, "hello"], [5, "hello"] ])"); + CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([ + [2, 222], + [2, null], + [2, null], + [2, null], + [5, "hello"], + [5, "hello"], + [5, "hello"] + ])"); } }