From cba22d6855d27b2f9a340906bb1f7eb8396bd988 Mon Sep 17 00:00:00 2001 From: Jin Shang Date: Thu, 27 Jul 2023 07:13:31 +0800 Subject: [PATCH] GH-36905: selection functions support sparse union --- .../vector_selection_filter_internal.cc | 2 + .../kernels/vector_selection_internal.cc | 81 +++++++++++++++++++ .../kernels/vector_selection_internal.h | 2 + .../kernels/vector_selection_take_internal.cc | 1 + .../compute/kernels/vector_selection_test.cc | 53 +++++++----- docs/source/cpp/compute.rst | 18 ++--- 6 files changed, 125 insertions(+), 32 deletions(-) diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc index 13e92ba27ec2d..b1acfabb5fddb 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc @@ -1047,6 +1047,7 @@ void PopulateFilterKernels(std::vector* out) { {InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec}, {InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec}, {InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec}, {InputType(Type::STRUCT), plain_filter, StructFilterExec}, {InputType(Type::MAP), plain_filter, MapFilterExec}, @@ -1064,6 +1065,7 @@ void PopulateFilterKernels(std::vector* out) { {InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec}, {InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec}, {InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec}, {InputType(Type::STRUCT), ree_filter, StructFilterExec}, {InputType(Type::MAP), ree_filter, MapFilterExec}, }; diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 23b8b75bfa024..be5b12409fa2b 100644 --- 
a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -741,6 +741,79 @@ struct DenseUnionSelectionImpl } }; +struct SparseUnionSelectionImpl + : public Selection { + using Base = Selection; + LIFT_BASE_MEMBERS(); + + TypedBufferBuilder child_id_buffer_builder_; + std::vector type_codes_; + std::vector child_indices_builders_; + + SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, + int64_t output_length, ExecResult* out) + : Base(ctx, batch, output_length, out), + child_id_buffer_builder_(ctx->memory_pool()), + type_codes_(checked_cast(*this->values.type).type_codes()), + child_indices_builders_(type_codes_.size()) { + for (auto& child_indices_builder : child_indices_builders_) { + child_indices_builder = Int32Builder(ctx->memory_pool()); + } + } + + template + Status GenerateOutput() { + SparseUnionArray typed_values(this->values.ToArrayData()); + Adapter adapter(this); + RETURN_NOT_OK(adapter.Generate( + [&](int64_t index) { + int8_t child_id = typed_values.child_id(index); + child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); + // TODO(jinshang): We use a naive approach for now: apply take for each child + // array. There is room for optimization because the unselected child arrays can + // have any value at this slot. 
+ for (auto& child_indices_builder : child_indices_builders_) { + child_indices_builder.UnsafeAppend(index); + } + return Status::OK(); + }, + [&]() { + int8_t child_id = 0; + child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]); + for (auto& child_indices_builder : child_indices_builders_) { + child_indices_builder.UnsafeAppendNull(); + } + return Status::OK(); + })); + return Status::OK(); + } + + Status Init() override { + RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length)); + for (auto& child_index_builder : child_indices_builders_) { + RETURN_NOT_OK(child_index_builder.Reserve(output_length)); + } + return Status::OK(); + } + + Status Finish() override { + ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish()); + SparseUnionArray typed_values(this->values.ToArrayData()); + auto num_fields = typed_values.num_fields(); + auto num_rows = child_ids_buffer->size(); + BufferVector buffers{nullptr, std::move(child_ids_buffer)}; + *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0); + for (auto i = 0; i < num_fields; i++) { + ARROW_ASSIGN_OR_RAISE(auto child_indices_array, + child_indices_builders_[i].Finish()); + ARROW_ASSIGN_OR_RAISE(std::shared_ptr child_array, + Take(*typed_values.field(i), *child_indices_array)); + out->child_data.push_back(child_array->data()); + } + return Status::OK(); + } +}; + struct FSLSelectionImpl : public Selection { Int64Builder child_index_builder; @@ -863,6 +936,10 @@ Status DenseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResul return FilterExec(ctx, batch, out); } +Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterExec(ctx, batch, out); +} + Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { return FilterExec>(ctx, batch, out); } @@ -909,6 +986,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* return TakeExec(ctx, batch, out); } 
+Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return TakeExec(ctx, batch, out); +} + Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { return TakeExec(ctx, batch, out); } diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h index bcffdd820db3c..76cfa28c65ac7 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h @@ -73,6 +73,7 @@ Status ListFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeListFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*); +Status SparseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*); Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*); @@ -82,6 +83,7 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); +Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*); diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index ab80127731ceb..612de8505d3ab 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector* out) { {InputType(Type::LARGE_LIST), 
take_indices, LargeListTakeExec}, {InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec}, {InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec}, + {InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec}, {InputType(Type::STRUCT), take_indices, StructTakeExec}, {InputType(Type::MAP), take_indices, MapTakeExec}, }; diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index 5b624911ff5fd..ac5bfe449c49c 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -759,8 +759,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) { class TestFilterKernelWithUnion : public TestFilterKernel {}; TEST_F(TestFilterKernelWithUnion, FilterUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, null], [2, 222], [5, "hello"], @@ -769,31 +771,32 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) { [2, 111], [5, null] ])"; - this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); - this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ + this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); + this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ [2, 222], [5, "hello"], [2, null], [2, 111], [5, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ + this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ [2, null], [5, "hello"], [2, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); + this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); - // Sliced - // 
(check this manually as concatenation of dense unions isn't supported: ARROW-4975) - auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4); - auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4); - auto expected = ArrayFromJSON(union_type, R"([ + // Sliced + // (check this manually as concatenation of dense unions isn't supported: ARROW-4975) + auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4); + auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4); + auto expected = ArrayFromJSON(union_type, R"([ [5, "hello"], [2, null], [2, 111] ])"); - this->AssertFilter(values, filter, expected); + this->AssertFilter(values, filter, expected); + } } class TestFilterKernelWithRecordBatch : public TestFilterKernel { @@ -1477,8 +1480,10 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) { class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithUnion, TakeUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, null], [2, 222], [5, "hello"], @@ -1487,22 +1492,22 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [2, 111], [5, null] ])"; - CheckTake(union_type, union_json, "[]", "[]"); - CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ + CheckTake(union_type, union_json, "[]", "[]"); + CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ [5, "eh"], [2, 222], [5, "eh"], [2, 222], [5, "eh"] ])"); - CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([ + CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([ [2, null], [5, "hello"], [2, 222], [5, null] ])"); - CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); - CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + 
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); + CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ [2, null], [5, "hello"], [5, "hello"], @@ -1511,6 +1516,7 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [5, "hello"], [5, "hello"] ])"); + } } class TestPermutationsWithTake : public ::testing::Test { @@ -2162,8 +2168,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) { class TestDropNullKernelWithUnion : public TestDropNullKernelTyped {}; TEST_F(TestDropNullKernelWithUnion, DropNullUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, null], [2, 222], [5, "hello"], @@ -2172,7 +2180,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) { [2, 111], [5, null] ])"; - CheckDropNull(union_type, union_json, union_json); + CheckDropNull(union_type, union_json, union_json); + } } class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 55e29588129b8..b17625420beba 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1680,28 +1680,26 @@ These functions select and return a subset of their input. 
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ | Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes | +===============+========+==============+==============+==============+=========================+===========+ -| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) | +| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) | +| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| drop_null | Unary | Any | - | Input type 1 | | \(1) \(2) | +| drop_null | Unary | Any | - | Input type 1 | | \(1) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) | +| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) | +| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -* \(1) Sparse unions are unsupported. - -* \(2) Each element in the input is appended to the output iff it is non-null. +* \(1) Each element in the input is appended to the output iff it is non-null. If the input is a record batch or table, any null value in a column drops the entire row. 
-* \(3) Each element in input 1 (the values) is appended to the output iff +* \(2) Each element in input 1 (the values) is appended to the output iff the corresponding element in input 2 (the filter) is true. How nulls in the filter are handled can be configured using FilterOptions. -* \(4) For each element *i* in input 2 (the indices), the *i*'th element +* \(3) For each element *i* in input 2 (the indices), the *i*'th element in input 1 (the values) is appended to the output. Containment tests