Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-36905: [C++] Add support for SparseUnion to selection functions #36906

Merged
merged 6 commits into from
Aug 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions cpp/src/arrow/compute/kernels/vector_selection_filter_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "arrow/chunked_array.h"
#include "arrow/compute/api_vector.h"
#include "arrow/compute/exec.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/codegen_internal.h"
#include "arrow/compute/kernels/vector_selection_filter_internal.h"
#include "arrow/compute/kernels/vector_selection_internal.h"
Expand All @@ -49,8 +50,7 @@ using internal::CopyBitmap;
using internal::CountSetBits;
using internal::OptionalBitBlockCounter;

namespace compute {
namespace internal {
namespace compute::internal {

namespace {

Expand Down Expand Up @@ -863,20 +863,29 @@ Status ExtensionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult
return Status::OK();
}

Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
// Transform filter to selection indices and then use Take.
// Transform filter to selection indices and then use Take.
Status FilterWithTakeExec(const ArrayKernelExec& take_exec, KernelContext* ctx,
const ExecSpan& batch, ExecResult* out) {
std::shared_ptr<ArrayData> indices;
RETURN_NOT_OK(GetTakeIndices(batch[1].array,
FilterState::Get(ctx).null_selection_behavior,
ctx->memory_pool())
.Value(&indices));
KernelContext take_ctx(*ctx);
TakeState state{TakeOptions::NoBoundsCheck()};
take_ctx.SetState(&state);
ExecSpan take_batch({batch[0], ArraySpan(*indices)}, batch.length);
return take_exec(&take_ctx, take_batch, out);
}

Datum result;
RETURN_NOT_OK(Take(batch[0].array.ToArrayData(), Datum(indices),
TakeOptions::NoBoundsCheck(), ctx->exec_context())
.Value(&result));
out->value = result.array();
return Status::OK();
// Due to the special treatment with their Take kernels, we filter Struct and SparseUnion
// arrays by transforming filter to selection indices and call Take.
Status StructFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return FilterWithTakeExec(StructTakeExec, ctx, batch, out);
}

Status SparseUnionFilterExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return FilterWithTakeExec(SparseUnionTakeExec, ctx, batch, out);
}

// ----------------------------------------------------------------------
Expand Down Expand Up @@ -1047,6 +1056,7 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), plain_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), plain_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), plain_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), plain_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), plain_filter, StructFilterExec},
{InputType(Type::MAP), plain_filter, MapFilterExec},

Expand All @@ -1064,12 +1074,12 @@ void PopulateFilterKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), ree_filter, LargeListFilterExec},
{InputType(Type::FIXED_SIZE_LIST), ree_filter, FSLFilterExec},
{InputType(Type::DENSE_UNION), ree_filter, DenseUnionFilterExec},
{InputType(Type::SPARSE_UNION), ree_filter, SparseUnionFilterExec},
{InputType(Type::STRUCT), ree_filter, StructFilterExec},
{InputType(Type::MAP), ree_filter, MapFilterExec},
};
}

} // namespace internal
} // namespace compute
} // namespace compute::internal

} // namespace arrow
73 changes: 66 additions & 7 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ namespace arrow {

using internal::CheckIndexBounds;

namespace compute {
namespace internal {
namespace compute::internal {

void RegisterSelectionFunction(const std::string& name, FunctionDoc doc,
VectorKernel base_kernel,
Expand Down Expand Up @@ -171,9 +170,6 @@ void VisitPlainxREEFilterOutputSegments(

namespace {

using FilterState = OptionsWrapper<FilterOptions>;
using TakeState = OptionsWrapper<TakeOptions>;

// ----------------------------------------------------------------------
// Implement take for other data types where there is less performance
// sensitivity by visiting the selected indices.
Expand Down Expand Up @@ -741,6 +737,66 @@ struct DenseUnionSelectionImpl
}
};

// We need a slightly different approach for SparseUnion. For Take, we can
// invoke Take on each child's data with boundschecking disabled. For
// Filter on the other hand, if we naively call Filter on each child, then the
// filter output length will have to be redundantly computed. Thus, for Filter
// we instead convert the filter to selection indices and then invoke take.

// SparseUnion selection implementation. ONLY used for Take
struct SparseUnionSelectionImpl
: public Selection<SparseUnionSelectionImpl, SparseUnionType> {
using Base = Selection<SparseUnionSelectionImpl, SparseUnionType>;
LIFT_BASE_MEMBERS();

TypedBufferBuilder<int8_t> child_id_buffer_builder_;
const int8_t type_code_for_null_;

SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch,
int64_t output_length, ExecResult* out)
: Base(ctx, batch, output_length, out),
child_id_buffer_builder_(ctx->memory_pool()),
type_code_for_null_(
checked_cast<const UnionType&>(*this->values.type).type_codes()[0]) {}

template <typename Adapter>
Status GenerateOutput() {
SparseUnionArray typed_values(this->values.ToArrayData());
Adapter adapter(this);
RETURN_NOT_OK(adapter.Generate(
[&](int64_t index) {
child_id_buffer_builder_.UnsafeAppend(typed_values.type_code(index));
return Status::OK();
},
[&]() {
child_id_buffer_builder_.UnsafeAppend(type_code_for_null_);
return Status::OK();
}));
return Status::OK();
}

Status Init() override {
RETURN_NOT_OK(child_id_buffer_builder_.Reserve(output_length));
return Status::OK();
}

Status Finish() override {
ARROW_ASSIGN_OR_RAISE(auto child_ids_buffer, child_id_buffer_builder_.Finish());
SparseUnionArray typed_values(this->values.ToArrayData());
auto num_fields = typed_values.num_fields();
auto num_rows = child_ids_buffer->size();
BufferVector buffers{nullptr, std::move(child_ids_buffer)};
*out = ArrayData(typed_values.type(), num_rows, std::move(buffers), 0);
out->child_data.reserve(num_fields);
for (auto i = 0; i < num_fields; i++) {
ARROW_ASSIGN_OR_RAISE(auto child_datum,
Take(*typed_values.field(i), *this->selection.ToArrayData()));
out->child_data.emplace_back(std::move(child_datum).array());
}
return Status::OK();
}
};

struct FSLSelectionImpl : public Selection<FSLSelectionImpl, FixedSizeListType> {
Int64Builder child_index_builder;

Expand Down Expand Up @@ -909,6 +965,10 @@ Status DenseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult*
return TakeExec<DenseUnionSelectionImpl>(ctx, batch, out);
}

Status SparseUnionTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<SparseUnionSelectionImpl>(ctx, batch, out);
}

Status StructTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<StructSelectionImpl>(ctx, batch, out);
}
Expand All @@ -917,6 +977,5 @@ Status MapTakeExec(KernelContext* ctx, const ExecSpan& batch, ExecResult* out) {
return TakeExec<ListSelectionImpl<MapType>>(ctx, batch, out);
}

} // namespace internal
} // namespace compute
} // namespace compute::internal
} // namespace arrow
13 changes: 7 additions & 6 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,12 @@
#include "arrow/compute/exec.h"
#include "arrow/compute/function.h"
#include "arrow/compute/kernel.h"
#include "arrow/compute/kernels/codegen_internal.h"

namespace arrow {
namespace compute {
namespace internal {
namespace arrow::compute::internal {

using FilterState = OptionsWrapper<FilterOptions>;
using TakeState = OptionsWrapper<TakeOptions>;

struct SelectionKernelData {
InputType value_type;
Expand Down Expand Up @@ -82,9 +84,8 @@ Status ListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status LargeListTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status FSLTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status DenseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status SparseUnionTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status StructTakeExec(KernelContext*, const ExecSpan&, ExecResult*);
Status MapTakeExec(KernelContext*, const ExecSpan&, ExecResult*);

} // namespace internal
} // namespace compute
} // namespace arrow
} // namespace arrow::compute::internal
Original file line number Diff line number Diff line change
Expand Up @@ -844,6 +844,7 @@ void PopulateTakeKernels(std::vector<SelectionKernelData>* out) {
{InputType(Type::LARGE_LIST), take_indices, LargeListTakeExec},
{InputType(Type::FIXED_SIZE_LIST), take_indices, FSLTakeExec},
{InputType(Type::DENSE_UNION), take_indices, DenseUnionTakeExec},
{InputType(Type::SPARSE_UNION), take_indices, SparseUnionTakeExec},
{InputType(Type::STRUCT), take_indices, StructTakeExec},
{InputType(Type::MAP), take_indices, MapTakeExec},
};
Expand Down
80 changes: 40 additions & 40 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,6 @@ class TestFilterKernel : public ::testing::Test {
const std::shared_ptr<Array>& expected) {
DoAssertFilter(values, filter, expected);

if (values->type_id() == Type::DENSE_UNION) {
// Concatenation of dense union not supported
return;
}

// Check slicing: add M(=3) dummy values at the start and end of `values`,
// add N(=2) dummy values at the start and end of `filter`.
ARROW_SCOPED_TRACE("for sliced values and filter");
Expand Down Expand Up @@ -759,8 +754,10 @@ TEST_F(TestFilterKernelWithStruct, FilterStruct) {
class TestFilterKernelWithUnion : public TestFilterKernel {};

TEST_F(TestFilterKernelWithUnion, FilterUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
Expand All @@ -769,31 +766,21 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
[2, 111],
[5, null]
])";
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
this->AssertFilter(union_type, union_json, "[0, 0, 0, 0, 0, 0, 0]", "[]");
this->AssertFilter(union_type, union_json, "[0, 1, 1, null, 0, 1, 1]", R"([
[2, 222],
[5, "hello"],
[2, null],
[2, 111],
[5, null]
])");
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
this->AssertFilter(union_type, union_json, "[1, 0, 1, 0, 1, 0, 0]", R"([
[2, null],
[5, "hello"],
[2, null]
])");
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);

// Sliced
// (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
auto expected = ArrayFromJSON(union_type, R"([
[5, "hello"],
[2, null],
[2, 111]
])");
this->AssertFilter(values, filter, expected);
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);
}
}

class TestFilterKernelWithRecordBatch : public TestFilterKernel {
Expand Down Expand Up @@ -1026,13 +1013,11 @@ void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values_
AssertTakeArrays(values, indices, expected);

// Check sliced values
if (type->id() != Type::DENSE_UNION) {
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);
}
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);

// Check sliced indices
ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0}));
Expand Down Expand Up @@ -1477,32 +1462,34 @@ TEST_F(TestTakeKernelWithStruct, TakeStruct) {
class TestTakeKernelWithUnion : public TestTakeKernelTyped<UnionType> {};

TEST_F(TestTakeKernelWithUnion, TakeUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
[2, null],
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, 222],
[2, null],
[5, "hello"],
[5, "eh"],
[2, null],
[2, 111],
[5, null]
])";
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([
[2, null],
[5, "hello"],
[2, 222],
[5, null]
])");
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
Expand All @@ -1511,6 +1498,16 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[5, "hello"],
[5, "hello"]
])");
CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([
[2, 222],
[2, null],
[2, null],
[2, null],
[5, "hello"],
[5, "hello"],
[5, "hello"]
])");
}
}

class TestPermutationsWithTake : public ::testing::Test {
Expand Down Expand Up @@ -2162,8 +2159,10 @@ TEST_F(TestDropNullKernelWithStruct, DropNullStruct) {
class TestDropNullKernelWithUnion : public TestDropNullKernelTyped<UnionType> {};

TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
auto union_type = dense_union({field("a", int32()), field("b", utf8())}, {2, 5});
auto union_json = R"([
for (const auto& union_type :
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[5, "hello"],
Expand All @@ -2172,7 +2171,8 @@ TEST_F(TestDropNullKernelWithUnion, DropNullUnion) {
[2, 111],
[5, null]
])";
CheckDropNull(union_type, union_json, union_json);
CheckDropNull(union_type, union_json, union_json);
}
}

class TestDropNullKernelWithRecordBatch : public TestDropNullKernelTyped<RecordBatch> {
Expand Down
Loading
Loading