diff --git a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc index 13e92ba27ec2d..be6d1653b5722 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc @@ -27,6 +27,7 @@ #include "arrow/chunked_array.h" #include "arrow/compute/api_vector.h" #include "arrow/compute/exec.h" +#include "arrow/compute/kernel.h" #include "arrow/compute/kernels/codegen_internal.h" #include "arrow/compute/kernels/vector_selection_filter_internal.h" #include "arrow/compute/kernels/vector_selection_internal.h" @@ -49,8 +50,7 @@ using internal::CopyBitmap; using internal::CountSetBits; using internal::OptionalBitBlockCounter; -namespace compute { -namespace internal { +namespace compute::internal { namespace { @@ -863,20 +863,29 @@ Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult return Status::OK(); } -Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { - // Transform filter to selection indices and then use Take. +// Transform filter to selection indices and then use Take. +Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx, + const ExecSpan& batch, ExecResult* out) { std::shared_ptr indices; RETURN_NOT_OK(GetTakeIndices(batch[1].array, FilterState::Get(ctx).null_selection_behavior, ctx->memory_pool()) .Value(&indices)); + KernelContext take_ctx(*ctx); + TakeState state{TakeOptions::NoBoundsCheck()}; + take_ctx.SetState(&state); + ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length); + return take_exec(&take_ctx, take_batch, out); +} - Datum result; - RETURN_NOT_OK(Take(batch[0].array.ToArrayData(), Datum(indices), - TakeOptions::NoBoundsCheck(), ctx->exec_context()) - .Value(&result)); - out->value = result.array(); - return Status::OK(); +// Due to the special treatment with their Take kernels, we filter Struct and SparseUnion +// arrays by transforming filter to selection indices and call Take. +Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterWithTakeExec(StructTakeExec, ctx, batch, out); +} + +Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return FilterWithTakeExec(SparseUnionTakeExec, ctx, batch, out); } // ---------------------------------------------------------------------- @@ -1047,6 +1056,7 @@ void PopulateFilterKernels(std::vector* out) { {InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec}, {InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec}, {InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec}, {InputType(Type::STRUCT), plain_filter, StructFilterExec}, {InputType(Type::MAP), plain_filter, MapFilterExec}, @@ -1064,12 +1074,12 @@ void PopulateFilterKernels(std::vector* out) { {InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec}, {InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec}, {InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec}, + {InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec}, {InputType(Type::STRUCT), ree_filter, StructFilterExec}, {InputType(Type::MAP), ree_filter, MapFilterExec}, }; } -} // namespace internal -} // namespace compute +} // namespace compute::internal } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc index 23b8b75bfa024..98eb37e9c5fd2 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.cc @@ -45,8 +45,7 @@ namespace arrow { using internal::CheckIndexBounds; -namespace compute { -namespace internal { +namespace compute::internal { void RegisterSelectionFunction(const std::string& name, FunctionDoc doc, VectorKernel base_kernel, @@ -171,9 +170,6 @@ void VisitPlainxREEFilterOutputSegments( namespace { -using FilterState = OptionsWrapper; -using TakeState = OptionsWrapper; - // ---------------------------------------------------------------------- // Implement take for other data types where there is less performance // sensitivity by visiting the selected indices. @@ -741,6 +737,66 @@ struct DenseUnionSelectionImpl } }; +// We need a slightly different approach for SparseUnion. For Take, we can +// invoke Take on each child's data with boundschecking disabled. For +// Filter on the other hand, if we naively call Filter on each child, then the +// filter output length will have to be redundantly computed. Thus, for Filter +// we instead convert the filter to selection indices and then invoke take. + +// SparseUnion selection implementation. ONLY used for Take +struct SparseUnionSelectionImpl + : public Selection { + using Base = Selection; + LIFT_BASE_MEMBERS(); + + TypedBufferBuilder child_id_buffer_builder_; + const int8_t type_code_for_null_; + + SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch, + int64_t output_length, ExecResult* out) + : Base(ctx, batch, output_length, out), + child_id_buffer_builder_(ctx->memory_pool()), + type_code_for_null_( + checked_cast(*this->values.type).type_codes()[0]) {} + + template + Status GenerateOutput() { + SparseUnionArray typed_values(this->values.ToArrayData()); + Adapter adapter(this); + RETURN_NOT_OK(adapter.Generate( + [&](int64_t index) { + child_id_buffer_builder_.UnsafeAppend(typed_values.type_code(index)); + return Status::OK(); + }, + [&]() { + child_id_buffer_builder_.UnsafeAppend(type_code_for_null_); + return Status::OK(); + })); + return Status::OK(); + } + + Status Init() override { + RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length)); + return Status::OK(); + } + + Status Finish() override { + ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish()); + SparseUnionArray typed_values(this->values.ToArrayData()); + auto num_fields = typed_values.num_fields(); + auto num_rows = child_ids_buffer->size(); + BufferVector buffers{nullptr, std::move(child_ids_buffer)}; + *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0); + out->child_data.reserve(num_fields); + for (auto i = 0; i < num_fields; i++) { + ARROW_ASSIGN_OR_RAISE(auto child_datum, + Take(*typed_values.field(i), *this->selection.ToArrayData())); + out->child_data.emplace_back(std::move(child_datum).array()); + } + return Status::OK(); + } +}; + struct FSLSelectionImpl : public Selection { Int64Builder child_index_builder; @@ -909,6 +965,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* return TakeExec(ctx, batch, out); } +Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { + return TakeExec(ctx, batch, out); +} + Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { return TakeExec(ctx, batch, out); } @@ -917,6 +977,5 @@ Status MapTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) { return TakeExec>(ctx, batch, out); } -} // namespace internal -} // namespace compute +} // namespace compute::internal } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_selection_internal.h b/cpp/src/arrow/compute/kernels/vector_selection_internal.h index bcffdd820db3c..b9eba6ea6631f 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_selection_internal.h @@ -26,10 +26,12 @@ #include "arrow/compute/exec.h" #include "arrow/compute/function.h" #include "arrow/compute/kernel.h" +#include "arrow/compute/kernels/codegen_internal.h" -namespace arrow { -namespace compute { -namespace internal { +namespace arrow::compute::internal { + +using FilterState = OptionsWrapper; +using TakeState = OptionsWrapper; struct SelectionKernelData { InputType value_type; @@ -82,9 +84,8 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); +Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*); Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*); -} // namespace internal -} // namespace compute -} // namespace arrow +} // namespace arrow::compute::internal diff --git a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc index ab80127731ceb..612de8505d3ab 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_take_internal.cc @@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector* out) { {InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec}, {InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec}, {InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec}, + {InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec}, {InputType(Type::STRUCT), take_indices, StructTakeExec}, {InputType(Type::MAP), take_indices, MapTakeExec}, }; diff --git a/cpp/src/arrow/compute/kernels/vector_selection_test.cc b/cpp/src/arrow/compute/kernels/vector_selection_test.cc index 5b624911ff5fd..30e85c1f71089 100644 --- a/cpp/src/arrow/compute/kernels/vector_selection_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_selection_test.cc @@ -282,11 +282,6 @@ class TestFilterKernel : public ::testing::Test { const std::shared_ptr& expected) { DoAssertFilter(values, filter, expected); - if (values->type_id() == Type::DENSE_UNION) { - // Concatenation of dense union not supported - return; - } - // Check slicing: add M(=3) dummy values at the start and end of `values`, // add N(=2) dummy values at the start and end of `filter`. ARROW_SCOPED_TRACE("for sliced values and filter"); @@ -759,8 +754,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) { class TestFilterKernelWithUnion : public TestFilterKernel {}; TEST_F(TestFilterKernelWithUnion, FilterUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, null], [2, 222], [5, "hello"], @@ -769,31 +766,21 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) { [2, 111], [5, null] ])"; - this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); - this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ + this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]"); + this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([ [2, 222], [5, "hello"], [2, null], [2, 111], [5, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ + this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([ [2, null], [5, "hello"], [2, null] ])"); - this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); - - // Sliced - // (check this manually as concatenation of dense unions isn't supported: ARROW-4975) - auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4); - auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4); - auto expected = ArrayFromJSON(union_type, R"([ - [5, "hello"], - [2, null], - [2, 111] - ])"); - this->AssertFilter(values, filter, expected); + this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json); + } } class TestFilterKernelWithRecordBatch : public TestFilterKernel { @@ -1026,13 +1013,11 @@ void CheckTake(const std::shared_ptr& type, const std::string& values_ AssertTakeArrays(values, indices, expected); // Check sliced values - if (type->id() != Type::DENSE_UNION) { - ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2)); - ASSERT_OK_AND_ASSIGN(auto values_sliced, - Concatenate({values_filler, values, values_filler})); - values_sliced = values_sliced->Slice(2, values->length()); - AssertTakeArrays(values_sliced, indices, expected); - } + ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2)); + ASSERT_OK_AND_ASSIGN(auto values_sliced, + Concatenate({values_filler, values, values_filler})); + values_sliced = values_sliced->Slice(2, values->length()); + AssertTakeArrays(values_sliced, indices, expected); // Check sliced indices ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0})); @@ -1477,32 +1462,34 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) { class TestTakeKernelWithUnion : public TestTakeKernelTyped {}; TEST_F(TestTakeKernelWithUnion, TakeUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ - [2, null], + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, 222], + [2, null], [5, "hello"], [5, "eh"], [2, null], [2, 111], [5, null] ])"; - CheckTake(union_type, union_json, "[]", "[]"); - CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([ + CheckTake(union_type, union_json, "[]", "[]"); + CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([ [5, "eh"], [2, 222], [5, "eh"], [2, 222], [5, "eh"] ])"); - CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([ + CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([ [2, null], [5, "hello"], [2, 222], [5, null] ])"); - CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); - CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([ + CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json); + CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([ [2, null], [5, "hello"], [5, "hello"], @@ -1511,6 +1498,16 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) { [5, "hello"], [5, "hello"] ])"); + CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([ + [2, 222], + [2, null], + [2, null], + [2, null], + [5, "hello"], + [5, "hello"], + [5, "hello"] + ])"); + } } class TestPermutationsWithTake : public ::testing::Test { @@ -2162,8 +2159,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) { class TestDropNullKernelWithUnion : public TestDropNullKernelTyped {}; TEST_F(TestDropNullKernelWithUnion, DropNullUnion) { - auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5}); - auto union_json = R"([ + for (const auto& union_type : + {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}), + sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) { + auto union_json = R"([ [2, null], [2, 222], [5, "hello"], @@ -2172,7 +2171,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) { [2, 111], [5, null] ])"; - CheckDropNull(union_type, union_json, union_json); + CheckDropNull(union_type, union_json, union_json); + } } class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped { diff --git a/docs/source/cpp/compute.rst b/docs/source/cpp/compute.rst index 55e29588129b8..010a1ac78c895 100644 --- a/docs/source/cpp/compute.rst +++ b/docs/source/cpp/compute.rst @@ -1680,28 +1680,26 @@ These functions select and return a subset of their input. +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ | Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes | +===============+========+==============+==============+==============+=========================+===========+ -| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) | +| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) | +| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| drop_null | Unary | Any | - | Input type 1 | | \(1) \(2) | +| drop_null | Unary | Any | - | Input type 1 | | \(1) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) | +| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) | +| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) | +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+ -* \(1) Sparse unions are unsupported. - -* \(2) Each element in the input is appended to the output iff it is non-null. +* \(1) Each element in the input is appended to the output iff it is non-null. If the input is a record batch or table, any null value in a column drops the entire row. -* \(3) Each element in input 1 (the values) is appended to the output iff +* \(2) Each element in input 1 (the values) is appended to the output iff the corresponding element in input 2 (the filter) is true. How nulls in the filter are handled can be configured using FilterOptions. -* \(4) For each element *i* in input 2 (the indices), the *i*'th element +* \(3) For each element *i* in input 2 (the indices), the *i*'th element in input 1 (the values) is appended to the output. Containment tests