GH-36905: [C++] Add support for SparseUnion to selection functions #36906

Merged
merged 6 commits into from
Aug 10, 2023
Changes from 3 commits
@@ -1047,6 +1047,7 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), plain_filter, StructFilterExec},
{InputType(Type::MAP), plain_filter, MapFilterExec},

@@ -1064,6 +1065,7 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), ree_filter, StructFilterExec},
{InputType(Type::MAP), ree_filter, MapFilterExec},
};
77 changes: 77 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
@@ -741,6 +741,63 @@ struct DenseUnionSelectionImpl
}
};

struct SparseUnionSelectionImpl
: public Selection<SparseUnionSelectionImpl, SparseUnionType> {
using Base = Selection<SparseUnionSelectionImpl, SparseUnionType>;
LIFT_BASE_MEMBERS();

TypedBufferBuilder<int8_t> child_id_buffer_builder_;
std::vector<int8_t> type_codes_;

SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch,
int64_t output_length, ExecResult* out)
: Base(ctx, batch, output_length, out),
child_id_buffer_builder_(ctx->memory_pool()),
type_codes_(checked_cast<const UnionType&>(*this->values.type).type_codes()) {}

template <typename Adapter>
Status GenerateOutput() {
SparseUnionArray typed_values(this->values.ToArrayData());
Adapter adapter(this);
RETURN_NOT_OK(adapter.Generate(
[&](int64_t index) {
int8_t child_id = typed_values.child_id(index);
child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
Member:
This seems to be doing a pointless back-and-forth between type codes and child ids?

Suggested change
- int8_t child_id = typed_values.child_id(index);
- child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
+ child_id_buffer_builder_.UnsafeAppend(typed_values.type_code(index));

Collaborator Author:
Fixed

// TODO(jinshang): We use a naive approach for now: apply take for each child
// array. There is room for optimization because the unselected child arrays can
// have any value at this slot.
return Status::OK();
},
[&]() {
int8_t child_id = 0;
child_id_buffer_builder_.UnsafeAppend(type_codes_[child_id]);
return Status::OK();
}));
return Status::OK();
}

Status Init() override {
RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
return Status::OK();
}

Status Finish() override {
ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish());
SparseUnionArray typed_values(this->values.ToArrayData());
auto num_fields = typed_values.num_fields();
auto num_rows = child_ids_buffer->size();
BufferVector buffers{nullptr, std::move(child_ids_buffer)};
*out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
out->child_data.reserve(num_fields);
for (auto i = 0; i < num_fields; i++) {
ARROW_ASSIGN_OR_RAISE(auto child_datum,
Take(*typed_values.field(i), *this->selection.ToArrayData()));
out->child_data.emplace_back(std::move(child_datum).array());
}
return Status::OK();
}
};
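
A rough, standard-library-only sketch of the "take each child with the same indices" approach that the TODO above describes. It does not use Arrow; `ToySparseUnion`, `TakeChild`, and `TakeSparseUnion` are hypothetical names, and the union is assumed to have exactly two children (int32 under type code 2, string under type code 5, as in the tests). A null index is given the first child's type code, mirroring the null branch of `GenerateOutput`.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <vector>

// Hypothetical stand-in for a sparse union of (int32, string); not an Arrow type.
// Every child has the same length as the union, and type_codes[i] names the child
// that is "live" at slot i.
struct ToySparseUnion {
  std::vector<int8_t> type_codes;                   // one entry per slot
  std::vector<std::optional<int32_t>> ints;         // child for type code 2
  std::vector<std::optional<std::string>> strings;  // child for type code 5
};

// Gather child[i] for every index; a null index yields a null child slot.
template <typename T>
std::vector<std::optional<T>> TakeChild(
    const std::vector<std::optional<T>>& child,
    const std::vector<std::optional<int64_t>>& indices) {
  std::vector<std::optional<T>> out;
  out.reserve(indices.size());
  for (const auto& idx : indices) {
    if (idx.has_value()) {
      assert(*idx >= 0 && static_cast<std::size_t>(*idx) < child.size());
      out.push_back(child[static_cast<std::size_t>(*idx)]);
    } else {
      out.push_back(std::nullopt);
    }
  }
  return out;
}

// Naive take: copy the type code per selected slot, then take every child with
// the same indices, including the children that are not live at that slot.
ToySparseUnion TakeSparseUnion(const ToySparseUnion& values,
                               const std::vector<std::optional<int64_t>>& indices,
                               int8_t first_child_type_code) {
  ToySparseUnion out;
  out.type_codes.reserve(indices.size());
  for (const auto& idx : indices) {
    if (idx.has_value()) {
      // The type code is copied straight through; no detour via child ids.
      out.type_codes.push_back(values.type_codes[static_cast<std::size_t>(*idx)]);
    } else {
      out.type_codes.push_back(first_child_type_code);
    }
  }
  out.ints = TakeChild(values.ints, indices);
  out.strings = TakeChild(values.strings, indices);
  return out;
}
```

The cost noted in the TODO is visible here: both children are gathered for every output slot even though only one of them is meaningful at each position.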

struct FSLSelectionImpl : public Selection<FSLSelectionImpl, FixedSizeListType> {
Int64Builder child_index_builder;

@@ -863,6 +920,22 @@ Status DenseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResul
return FilterExec<DenseUnionSelectionImpl>(ctx, batch, out);
}

Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
Member:
Nit: move this into vector_filter_internal.cc along StructFilterExec? (can probably also share some code between them...)

Collaborator Author:
Moved and extracted a FilterWithTakeExec.

// Transform filter to selection indices and then use Take.
std::shared_ptr<ArrayData> indices;
RETURN_NOT_OK(GetTakeIndices(batch[1].array,
FilterState::Get(ctx).null_selection_behavior,
ctx->memory_pool())
.Value(&indices));

Datum result;
RETURN_NOT_OK(Take(batch[0].array.ToArrayData(), Datum(indices),
Member:
Perhaps we can call SparseUnionTakeExec directly instead of going through the function lookup and execution machinery again?

Collaborator Author:
Done

TakeOptions::NoBoundsCheck(), ctx->exec_context())
.Value(&result));
out->value = result.array();
return Status::OK();
}
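
For readers unfamiliar with the filter-to-indices step used above, here is a minimal sketch of the idea in plain C++. It is not the Arrow `GetTakeIndices` implementation; `NullSelection` and `FilterToTakeIndices` are hypothetical names, and the two modes are assumed to correspond to dropping a null filter slot versus emitting a null output slot.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical mirror of the two null-handling modes; not the Arrow enum itself.
enum class NullSelection { kDrop, kEmitNull };

// Convert a boolean filter (with possible nulls) into take indices.
// true  -> emit the position
// false -> emit nothing
// null  -> emit nothing (kDrop) or a null index (kEmitNull)
std::vector<std::optional<int64_t>> FilterToTakeIndices(
    const std::vector<std::optional<bool>>& filter, NullSelection null_selection) {
  std::vector<std::optional<int64_t>> indices;
  for (std::size_t i = 0; i < filter.size(); ++i) {
    if (!filter[i].has_value()) {
      if (null_selection == NullSelection::kEmitNull) {
        indices.push_back(std::nullopt);
      }
    } else if (*filter[i]) {
      indices.push_back(static_cast<int64_t>(i));
    }
  }
  return indices;
}
```

Once the indices exist, filtering reduces to a take, which is why the filter kernel can reuse the take implementation for sparse unions.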

Status MapFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return FilterExec<ListSelectionImpl<MapType>>(ctx, batch, out);
}
@@ -909,6 +982,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
return TakeExec<DenseUnionSelectionImpl>(ctx, batch, out);
}

Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<SparseUnionSelectionImpl>(ctx, batch, out);
}

Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<StructSelectionImpl>(ctx, batch, out);
}
2 changes: 2 additions & 0 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.h
@@ -73,6 +73,7 @@ Status ListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionFilterExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapFilterExec(KernelContext*, const ExecSpan&, ExecResult*);

Status VarBinaryTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
@@ -82,6 +83,7 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);

@@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec},
{InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec},
{InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec},
{InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec},
{InputType(Type::STRUCT), take_indices, StructTakeExec},
{InputType(Type::MAP), take_indices, MapTakeExec},
};
53 changes: 31 additions & 22 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
@@ -759,8 +759,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) {
class TestFilterKernelWithUnion : public TestFilterKernel {};

TEST_F(TestFilterKernelWithUnion, FilterUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
-  auto union_json = R"([
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
@@ -769,31 +771,32 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
[2, 111],
[5, null]
])";
-  this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
-  this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
+    this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
+    this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
[2, 222],
[5, "hello"],
[2, null],
[2, 111],
[5, null]
])");
-  this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
+    this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
[2, null],
[5, "hello"],
[2, null]
])");
-  this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);
+    this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);

-  // Sliced
-  // (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
-  auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
-  auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
-  auto expected = ArrayFromJSON(union_type, R"([
+    // Sliced
+    // (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
+    auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
+    auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
+    auto expected = ArrayFromJSON(union_type, R"([
[5, "hello"],
[2, null],
[2, 111]
])");
-  this->AssertFilter(values, filter, expected);
+    this->AssertFilter(values, filter, expected);
+  }
}

class TestFilterKernelWithRecordBatch : public TestFilterKernel {
@@ -1477,8 +1480,10 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {
class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};

TEST_F(TestTakeKernelWithUnion, TakeUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
-  auto union_json = R"([
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
@@ -1487,22 +1492,22 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[2, 111],
[5, null]
])";
-  CheckTake(union_type, union_json, "[]", "[]");
-  CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
+    CheckTake(union_type, union_json, "[]", "[]");
+    CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
-  CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
+    CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
[2, null],
[5, "hello"],
[2, 222],
[5, null]
])");
-  CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
-  CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
+    CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
+    CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
@@ -1511,6 +1516,7 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[5, "hello"],
[5, "hello"]
])");
+  }
}

class TestPermutationsWithTake : public ::testing::Test {
@@ -2162,8 +2168,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) {
class TestDropNullKernelWithUnion : public TestDropNullKernelTyped<UnionType> {};

TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
-  auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
-  auto union_json = R"([
+  for (const auto& union_type :
+       {dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
+        sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
+    auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
@@ -2172,7 +2180,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
[2, 111],
[5, null]
])";
-  CheckDropNull(union_type, union_json, union_json);
+    CheckDropNull(union_type, union_json, union_json);
+  }
}

class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped<RecordBatch> {
18 changes: 8 additions & 10 deletions docs/source/cpp/compute.rst
@@ -1680,28 +1680,26 @@ These functions select and return a subset of their input.
 +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
 | Function name | Arity  | Input type 1 | Input type 2 | Output type  | Options class           | Notes     |
 +===============+========+==============+==============+==============+=========================+===========+
-| array_filter  | Binary | Any          | Boolean      | Input type 1 | :struct:`FilterOptions` | \(1) \(3) |
+| array_filter  | Binary | Any          | Boolean      | Input type 1 | :struct:`FilterOptions` | \(2)      |
 +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| array_take    | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(1) \(4) |
+| array_take    | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(3)      |
 +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| drop_null     | Unary  | Any          | -            | Input type 1 |                         | \(1) \(2) |
+| drop_null     | Unary  | Any          | -            | Input type 1 |                         | \(1)      |
 +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| filter        | Binary | Any          | Boolean      | Input type 1 | :struct:`FilterOptions` | \(1) \(3) |
+| filter        | Binary | Any          | Boolean      | Input type 1 | :struct:`FilterOptions` | \(3)      |
 +---------------+--------+--------------+--------------+--------------+-------------------------+-----------+
-| take          | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(1) \(4) |
+| take          | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(4)      |
Member:
I suppose this should be

Suggested change
-| take          | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(4)      |
+| take          | Binary | Any          | Integer      | Input type 1 | :struct:`TakeOptions`   | \(3)      |

Collaborator Author:
Fixed, and the previous line was also wrong.

+---------------+--------+--------------+--------------+--------------+-------------------------+-----------+

-* \(1) Sparse unions are unsupported.
-
-* \(2) Each element in the input is appended to the output iff it is non-null.
+* \(1) Each element in the input is appended to the output iff it is non-null.
   If the input is a record batch or table, any null value in a column drops
   the entire row.

-* \(3) Each element in input 1 (the values) is appended to the output iff
+* \(2) Each element in input 1 (the values) is appended to the output iff
   the corresponding element in input 2 (the filter) is true. How
   nulls in the filter are handled can be configured using FilterOptions.

-* \(4) For each element *i* in input 2 (the indices), the *i*'th element
+* \(3) For each element *i* in input 2 (the indices), the *i*'th element
   in input 1 (the values) is appended to the output.
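
The row-dropping rule in the drop_null note can be illustrated with a small sketch in plain C++ (not Arrow code). `Column` and `RowsWithoutNulls` are hypothetical names, and all columns are assumed to have the same length.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical column type: a nullable int column, for illustration only.
using Column = std::vector<std::optional<int>>;

// Returns the indices of rows in which every column is non-null, i.e. the rows
// that would survive when a single null in any column drops the whole row.
std::vector<std::size_t> RowsWithoutNulls(const std::vector<Column>& columns) {
  std::vector<std::size_t> keep;
  if (columns.empty()) return keep;
  const std::size_t num_rows = columns.front().size();
  for (std::size_t row = 0; row < num_rows; ++row) {
    bool all_valid = true;
    for (const Column& col : columns) {
      if (!col[row].has_value()) {
        all_valid = false;
        break;
      }
    }
    if (all_valid) keep.push_back(row);
  }
  return keep;
}
```

The surviving row positions can then be used as take indices on every column, which keeps the columns aligned.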

Containment tests