Skip to content

Commit

Permalink
Last-minute improvements:
Browse files Browse the repository at this point in the history
- avoid copying type codes array
- remove skipping of sliced tests on dense unions
- add union-take test with null indices
  • Loading branch information
pitrou committed Aug 10, 2023
1 parent cddf9c8 commit 200fd5f
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 30 deletions.
7 changes: 4 additions & 3 deletions cpp/src/arrow/compute/kernels/vector_selection_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -750,13 +750,14 @@ struct SparseUnionSelectionImpl
LIFT_BASE_MEMBERS();

TypedBufferBuilder<int8_t> child_id_buffer_builder_;
std::vector<int8_t> type_codes_;
const int8_t type_code_for_null_;

SparseUnionSelectionImpl(KernelContext* ctx, const ExecSpan& batch,
int64_t output_length, ExecResult* out)
: Base(ctx, batch, output_length, out),
child_id_buffer_builder_(ctx->memory_pool()),
type_codes_(checked_cast<const UnionType&>(*this->values.type).type_codes()) {}
type_code_for_null_(
checked_cast<const UnionType&>(*this->values.type).type_codes()[0]) {}

template <typename Adapter>
Status GenerateOutput() {
Expand All @@ -768,7 +769,7 @@ struct SparseUnionSelectionImpl
return Status::OK();
},
[&]() {
child_id_buffer_builder_.UnsafeAppend(type_codes_[0]);
child_id_buffer_builder_.UnsafeAppend(type_code_for_null_);
return Status::OK();
}));
return Status::OK();
Expand Down
45 changes: 18 additions & 27 deletions cpp/src/arrow/compute/kernels/vector_selection_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -282,11 +282,6 @@ class TestFilterKernel : public ::testing::Test {
const std::shared_ptr<Array>& expected) {
DoAssertFilter(values, filter, expected);

if (values->type_id() == Type::DENSE_UNION) {
// Concatenation of dense union not supported
return;
}

// Check slicing: add M(=3) dummy values at the start and end of `values`,
// add N(=2) dummy values at the start and end of `filter`.
ARROW_SCOPED_TRACE("for sliced values and filter");
Expand Down Expand Up @@ -785,17 +780,6 @@ TEST_F(TestFilterKernelWithUnion, FilterUnion) {
[2, null]
])");
this->AssertFilter(union_type, union_json, "[1, 1, 1, 1, 1, 1, 1]", union_json);

// Sliced
// (check this manually as concatenation of dense unions isn't supported: ARROW-4975)
auto values = ArrayFromJSON(union_type, union_json)->Slice(2, 4);
auto filter = ArrayFromJSON(boolean(), "[0, 1, 1, null, 0, 1, 1]")->Slice(2, 4);
auto expected = ArrayFromJSON(union_type, R"([
[5, "hello"],
[2, null],
[2, 111]
])");
this->AssertFilter(values, filter, expected);
}
}

Expand Down Expand Up @@ -1029,13 +1013,11 @@ void CheckTake(const std::shared_ptr<DataType>& type, const std::string& values_
AssertTakeArrays(values, indices, expected);

// Check sliced values
if (type->id() != Type::DENSE_UNION) {
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);
}
ASSERT_OK_AND_ASSIGN(auto values_filler, MakeArrayOfNull(type, 2));
ASSERT_OK_AND_ASSIGN(auto values_sliced,
Concatenate({values_filler, values, values_filler}));
values_sliced = values_sliced->Slice(2, values->length());
AssertTakeArrays(values_sliced, indices, expected);

// Check sliced indices
ASSERT_OK_AND_ASSIGN(auto zero, MakeScalar(index_type, int8_t{0}));
Expand Down Expand Up @@ -1484,30 +1466,30 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
{dense_union({field("a", int32()), field("b", utf8())}, {2, 5}),
sparse_union({field("a", int32()), field("b", utf8())}, {2, 5})}) {
auto union_json = R"([
[2, null],
[2, 222],
[2, null],
[5, "hello"],
[5, "eh"],
[2, null],
[2, 111],
[5, null]
])";
CheckTake(union_type, union_json, "[]", "[]");
CheckTake(union_type, union_json, "[3, 1, 3, 1, 3]", R"([
CheckTake(union_type, union_json, "[3, 0, 3, 0, 3]", R"([
[5, "eh"],
[2, 222],
[5, "eh"],
[2, 222],
[5, "eh"]
])");
CheckTake(union_type, union_json, "[4, 2, 1, 6]", R"([
CheckTake(union_type, union_json, "[4, 2, 0, 6]", R"([
[2, null],
[5, "hello"],
[2, 222],
[5, null]
])");
CheckTake(union_type, union_json, "[0, 1, 2, 3, 4, 5, 6]", union_json);
CheckTake(union_type, union_json, "[0, 2, 2, 2, 2, 2, 2]", R"([
CheckTake(union_type, union_json, "[1, 2, 2, 2, 2, 2, 2]", R"([
[2, null],
[5, "hello"],
[5, "hello"],
Expand All @@ -1516,6 +1498,15 @@ TEST_F(TestTakeKernelWithUnion, TakeUnion) {
[5, "hello"],
[5, "hello"]
])");
CheckTake(union_type, union_json, "[0, null, 1, null, 2, 2, 2]", R"([
[2, 222],
[2, null],
[2, null],
[2, null],
[5, "hello"],
[5, "hello"],
[5, "hello"]
])");
}
}

Expand Down

0 comments on commit 200fd5f

Please sign in to comment.