Skip to content

Commit

Permalink
GH-36905: selection functions support sparse union
Browse files Browse the repository at this point in the history
  • Loading branch information
js8544 committed Jul 26, 2023
1 parent 2f039f1 commit cba22d6
Show file tree
Hide file tree
Showing 6 changed files with 125 additions and 32 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1047,6 +1047,7 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), plain_filter, StructFilterExec},
{InputType(Type::MAP), plain_filter, MapFilterExec},

Expand All @@ -1064,6 +1065,7 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), ree_filter, StructFilterExec},
{InputType(Type::MAP), ree_filter, MapFilterExec},
};
Expand Down
81 changes: 81 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -741,6 +741,79 @@ struct DenseUnionSelectionImpl
}
};

// Selection (filter / take) implementation for sparse union values.
//
// For every selected slot the output receives the type code of the source slot,
// and each child of the output is produced by taking the same slot position
// from the corresponding child of the input. Because every child is taken with
// exactly the same positions, a single shared list of indices is accumulated
// and reused for all children.
struct SparseUnionSelectionImpl
    : public Selection<SparseUnionSelectionImpl, SparseUnionType> {
  using Base = Selection<SparseUnionSelectionImpl, SparseUnionType>;
  LIFT_BASE_MEMBERS();

  // Type codes of the output union, one entry per emitted slot.
  TypedBufferBuilder<int8_t> child_id_buffer_builder_;
  // Type codes declared by the union type, indexed by child id.
  std::vector<int8_t> type_codes_;
  // Source slot positions (null for emitted nulls), shared by all children.
  // 64-bit because slot positions arrive as int64_t and must not be narrowed.
  Int64Builder child_indices_builder_;

  SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch,
                           int64_t output_length, ExecResult* out)
      : Base(ctx, batch, output_length, out),
        child_id_buffer_builder_(ctx->memory_pool()),
        type_codes_(checked_cast<const UnionType&>(*this->values.type).type_codes()),
        child_indices_builder_(ctx->memory_pool()) {}

  // Walks the selection through `Adapter`, recording for each emitted slot its
  // type code and the source position to take from the children.
  template <typename Adapter>
  Status GenerateOutput() {
    SparseUnionArray typed_values(this->values.ToArrayData());
    Adapter adapter(this);
    RETURN_NOT_OK(adapter.Generate(
        // Valid selection: copy the type code and remember the source position.
        [&](int64_t index) {
          int8_t child_id = typed_values.child_id(index);
          child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
          // NOTE: this is the naive approach of taking every child at every
          // selected slot. There is room for optimization because the
          // unselected child arrays can have any value at this slot.
          child_indices_builder_.UnsafeAppend(index);
          return Status::OK();
        },
        // Null selection: emit the first child's type code and a null index, so
        // that taking from the children yields a null at this slot.
        [&]() {
          child_id_buffer_builder_.UnsafeAppend(type_codes_[0]);
          child_indices_builder_.UnsafeAppendNull();
          return Status::OK();
        }));
    return Status::OK();
  }

  // Reserves room for `output_length` entries up front; GenerateOutput relies
  // on this because it only uses the Unsafe* append methods.
  Status Init() override {
    RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
    return child_indices_builder_.Reserve(output_length);
  }

  // Assembles the output ArrayData: no validity buffer, the type code buffer,
  // and one child per union field obtained by taking the shared indices.
  Status Finish() override {
    ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish());
    ARROW_ASSIGN_OR_RAISE(auto child_indices_array, child_indices_builder_.Finish());
    SparseUnionArray typed_values(this->values.ToArrayData());
    auto num_fields = typed_values.num_fields();
    // Type codes are one byte each, so the byte size equals the slot count.
    auto num_rows = child_ids_buffer->size();
    BufferVector buffers{nullptr, std::move(child_ids_buffer)};
    *out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
    out->child_data.reserve(num_fields);
    for (auto i = 0; i < num_fields; i++) {
      ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Array> child_array,
                            Take(*typed_values.field(i), *child_indices_array));
      out->child_data.push_back(child_array->data());
    }
    return Status::OK();
  }
};

struct FSLSelectionImpl : public Selection<FSLSelectionImpl, FixedSizeListType> {
Int64Builder child_index_builder;

Expand Down Expand Up @@ -863,6 +936,10 @@ Status DenseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResul
return FilterExec<DenseUnionSelectionImpl>(ctx, batch, out);
}

// Filter kernel entry point registered for sparse union values: forwards to the
// generic FilterExec driver instantiated with SparseUnionSelectionImpl.
Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  using Impl = SparseUnionSelectionImpl;
  return FilterExec<Impl>(ctx, batch, out);
}

// Filter kernel entry point registered for map values: forwards to the generic
// FilterExec driver instantiated with the list selection implementation for
// MapType.
Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  using Impl = ListSelectionImpl<MapType>;
  return FilterExec<Impl>(ctx, batch, out);
}
Expand Down Expand Up @@ -909,6 +986,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
return TakeExec<DenseUnionSelectionImpl>(ctx, batch, out);
}

// Take kernel entry point registered for sparse union values: forwards to the
// generic TakeExec driver instantiated with SparseUnionSelectionImpl.
Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  using Impl = SparseUnionSelectionImpl;
  return TakeExec<Impl>(ctx, batch, out);
}

// Take kernel entry point registered for struct values: forwards to the generic
// TakeExec driver instantiated with StructSelectionImpl.
Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
  using Impl = StructSelectionImpl;
  return TakeExec<Impl>(ctx, batch, out);
}
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ Status ListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);

Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Expand All @@ -82,6 +83,7 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec},
{InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec},
{InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec},
{InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec},
{InputType(Type::STRUCT), take_indices, StructTakeExec},
{InputType(Type::MAP), take_indices, MapTakeExec},
};
Expand Down
53 changes: 31 additions & 22 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -759,8 +759,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) {
class TestFilterKernelWithUnion : public TestFilterKernel {};

TEST_F(TestFilterKernelWithUnion, FilterUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
Expand All @@ -769,31 +771,32 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
[2, 111],
[5, null]
])";
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
[2, 222],
[5, "hello"],
[2, null],
[2, 111],
[5, null]
])");
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
[2, null],
[5, "hello"],
[2, null]
])");
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);

// Sliced
// (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
auto expected = ArrayFromJSON(union_type, R"([
// Sliced
// (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
auto expected = ArrayFromJSON(union_type, R"([
[5, "hello"],
[2, null],
[2, 111]
])");
this->AssertFilter(values, filter, expected);
this->AssertFilter(values, filter, expected);
}
}

class TestFilterKernelWithRecordBatch : public TestFilterKernel {
Expand Down Expand Up @@ -1477,8 +1480,10 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {
class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};

TEST_F(TestTakeKernelWithUnion, TakeUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
Expand All @@ -1487,22 +1492,22 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[2, 111],
[5, null]
])";
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
[2, null],
[5, "hello"],
[2, 222],
[5, null]
])");
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
Expand All @@ -1511,6 +1516,7 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[5, "hello"],
[5, "hello"]
])");
}
}

class TestPermutationsWithTake : public ::testing::Test {
Expand Down Expand Up @@ -2162,8 +2168,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) {
class TestDropNullKernelWithUnion : public TestDropNullKernelTyped<UnionType> {};

TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
Expand All @@ -2172,7 +2180,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
[2, 111],
[5, null]
])";
CheckDropNull(union_type, union_json, union_json);
CheckDropNull(union_type, union_json, union_json);
}
}

class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped<RecordBatch> {
Expand Down
18 changes: 8 additions & 10 deletions docs/source/cpp/compute.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1680,28 +1680,26 @@ These functions select and return a subset of their input.
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
| Function name | Arity | Input type 1 | Input type 2 | Output type | Options class | Notes |
+===============+========+==============+==============+==============+=========================+===========+
| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) |
| array_filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) |
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) |
| array_take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) |
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
| drop_null | Unary | Any | - | Input type 1 | | \(1) \(2) |
| drop_null | Unary | Any | - | Input type 1 | | \(1) |
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(1) \(3) |
| filter | Binary | Any | Boolean | Input type 1 | :struct:`FilterOptions` | \(2) |
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(1) \(4) |
| take | Binary | Any | Integer | Input type 1 | :struct:`TakeOptions` | \(3) |
+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+

* \(1) Sparse unions are unsupported.

* \(2) Each element in the input is appended to the output iff it is non-null.
* \(1) Each element in the input is appended to the output iff it is non-null.
If the input is a record batch or table, any null value in a column drops
the entire row.

* \(3) Each element in input 1 (the values) is appended to the output iff
* \(2) Each element in input 1 (the values) is appended to the output iff
the corresponding element in input 2 (the filter) is true. How
nulls in the filter are handled can be configured using FilterOptions.

* \(4) For each element *i* in input 2 (the indices), the *i*'th element
* \(3) For each element *i* in input 2 (the indices), the *i*'th element
in input 1 (the values) is appended to the output.

Containment tests
Expand Down

0 comments on commit cba22d6

Please sign in to comment.