Skip to content

Commit

Permalink
add test
Browse files Browse the repository at this point in the history
  • Loading branch information
R-JunmingChen committed Nov 15, 2023
1 parent 60bcd37 commit 186b047
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 10 deletions.
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,7 @@ set(ARROW_SRCS
util/debug.cc
util/decimal.cc
util/delimiting.cc
util/dict_util.cc
util/formatting.cc
util/future.cc
util/hashing.cc
Expand Down
46 changes: 46 additions & 0 deletions cpp/src/arrow/array/array_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,21 @@ class TestArray : public ::testing::Test {
MemoryPool* pool_;
};

// Builds a dictionary array from JSON and checks both its physical and its
// logical null bookkeeping.
//
// \param dict_type type handed to DictArrayFromJSON
// \param input_dictionary_json JSON text for the dictionary values
// \param input_index_json JSON text for the indices
// \param expected_null_count value Array::null_count() must return
// \param expected_logical_null_count value Array::ComputeLogicalNullCount()
//        must return
// \param expected_may_have_nulls value ArrayData::MayHaveNulls() must return
// \param expected_may_have_logical_nulls value
//        ArrayData::MayHaveLogicalNulls() must return
//
// The checks use ASSERT_EQ, so the first mismatch ends this helper early and
// the remaining expectations are not evaluated.
void CheckDictionaryNullCount(const std::shared_ptr<DataType>& dict_type,
                              const std::string& input_dictionary_json,
                              const std::string& input_index_json,
                              int64_t expected_null_count,
                              int64_t expected_logical_null_count,
                              bool expected_may_have_nulls,
                              bool expected_may_have_logical_nulls) {
  // Note the argument order: DictArrayFromJSON receives the index JSON before
  // the dictionary JSON, the reverse of this helper's parameter order.
  std::shared_ptr<arrow::Array> arr =
      DictArrayFromJSON(dict_type, input_index_json, input_dictionary_json);

  ASSERT_EQ(expected_null_count, arr->null_count());
  ASSERT_EQ(expected_logical_null_count, arr->ComputeLogicalNullCount());
  ASSERT_EQ(expected_may_have_nulls, arr->data()->MayHaveNulls());
  ASSERT_EQ(expected_may_have_logical_nulls, arr->data()->MayHaveLogicalNulls());
}

TEST_F(TestArray, TestNullCount) {
// These are placeholders
auto data = std::make_shared<Buffer>(nullptr, 0);
Expand Down Expand Up @@ -127,6 +142,37 @@ TEST_F(TestArray, TestNullCount) {
ASSERT_EQ(0, ree_no_nulls->ComputeLogicalNullCount());
ASSERT_FALSE(ree_no_nulls->data()->MayHaveNulls());
ASSERT_FALSE(ree_no_nulls->data()->MayHaveLogicalNulls());

// dictionary type
std::shared_ptr<arrow::DataType> type;
std::shared_ptr<arrow::DataType> dict_type;

for (const auto& index_type : all_dictionary_index_types()) {
ARROW_SCOPED_TRACE("index_type = ", index_type->ToString());

type = boolean();
dict_type = dictionary(index_type, type);
// no null value
CheckDictionaryNullCount(dict_type, "[]", "[]", 0, 0, false, false);
CheckDictionaryNullCount(dict_type, "[true, false]", "[0, 1, 0]", 0, 0, false, false);

// only indices contain null value
CheckDictionaryNullCount(dict_type, "[true, false]", "[null, 0, 1]", 1, 1, true,
true);
CheckDictionaryNullCount(dict_type, "[true, false]", "[null, null]", 2, 2, true,
true);

// only dictionary contains null value
CheckDictionaryNullCount(dict_type, "[null, true]", "[]", 0, 0, false, true);
CheckDictionaryNullCount(dict_type, "[null, true, false]", "[0, 1, 0]", 0, 2, false,
true);

// both indices and dictionary contain null value
CheckDictionaryNullCount(dict_type, "[null, true, false]", "[0, 1, 0, null]", 1, 3,
true, true);
CheckDictionaryNullCount(dict_type, "[null, true, null, false]", "[null, 1, 0, 2, 3]",
1, 3, true, true);
}
}

TEST_F(TestArray, TestSlicePreservesAllNullCount) {
Expand Down
14 changes: 11 additions & 3 deletions cpp/src/arrow/array/data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,12 @@
#include "arrow/type_traits.h"
#include "arrow/util/binary_view_util.h"
#include "arrow/util/bitmap_ops.h"
#include "arrow/util/dict_util.h"
#include "arrow/util/logging.h"
#include "arrow/util/macros.h"
#include "arrow/util/ree_util.h"
#include "arrow/util/slice_util_internal.h"
#include "arrow/util/union_util.h"
#include "arrow/util/dict_util.h"

namespace arrow {

Expand Down Expand Up @@ -94,6 +94,10 @@ bool RunEndEncodedMayHaveLogicalNulls(const ArrayData& data) {
return ArraySpan(data).MayHaveLogicalNulls();
}

bool DictionaryMayHaveLogicalNulls(const ArrayData& data) {
return ArraySpan(data).MayHaveLogicalNulls();
}

BufferSpan PackVariadicBuffers(util::span<const std::shared_ptr<Buffer>> buffers) {
return {const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(buffers.data())),
static_cast<int64_t>(buffers.size() * sizeof(std::shared_ptr<Buffer>))};
Expand Down Expand Up @@ -175,7 +179,7 @@ int64_t ArrayData::GetNullCount() const {
}

int64_t ArrayData::ComputeLogicalNullCount() const {
if (this->buffers[0]) {
if (this->buffers[0] && this->type->id() != Type::DICTIONARY) {
return GetNullCount();
}
return ArraySpan(*this).ComputeLogicalNullCount();
Expand Down Expand Up @@ -521,7 +525,7 @@ int64_t ArraySpan::ComputeLogicalNullCount() const {
if (t == Type::RUN_END_ENCODED) {
return ree_util::LogicalNullCount(*this);
}
if(t == Type::DICTIONARY){
if (t == Type::DICTIONARY) {
return dict_util::LogicalNullCount(*this);
}
return GetNullCount();
Expand Down Expand Up @@ -621,6 +625,10 @@ bool ArraySpan::RunEndEncodedMayHaveLogicalNulls() const {
return ree_util::ValuesArray(*this).MayHaveLogicalNulls();
}

bool ArraySpan::DictionaryMayHaveLogicalNulls() const {
return this->GetNullCount() != 0 || this->dictionary().GetNullCount() != 0;
}

// ----------------------------------------------------------------------
// Implement internal::GetArrayView

Expand Down
14 changes: 11 additions & 3 deletions cpp/src/arrow/array/data.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ struct ArrayData;

namespace internal {
// ----------------------------------------------------------------------
// Null handling for types without a validity bitmap
// Null handling for types without a validity bitmap and the dictionary type

ARROW_EXPORT bool IsNullSparseUnion(const ArrayData& data, int64_t i);
ARROW_EXPORT bool IsNullDenseUnion(const ArrayData& data, int64_t i);
ARROW_EXPORT bool IsNullRunEndEncoded(const ArrayData& data, int64_t i);

ARROW_EXPORT bool UnionMayHaveLogicalNulls(const ArrayData& data);
ARROW_EXPORT bool RunEndEncodedMayHaveLogicalNulls(const ArrayData& data);
ARROW_EXPORT bool DictionaryMayHaveLogicalNulls(const ArrayData& data);
} // namespace internal

// When slicing, we do not know the null count of the sliced range without
Expand Down Expand Up @@ -280,7 +281,7 @@ struct ARROW_EXPORT ArrayData {

/// \brief Return true if the validity bitmap may have 0's in it, or if the
/// child arrays (in the case of types without a validity bitmap) may have
/// nulls
/// nulls, or if the dictionary of a dictionary array may have nulls.
///
/// This is not a drop-in replacement for MayHaveNulls, as historically
/// MayHaveNulls() has been used to check for the presence of a validity
Expand Down Expand Up @@ -325,6 +326,9 @@ struct ARROW_EXPORT ArrayData {
if (t == Type::RUN_END_ENCODED) {
return internal::RunEndEncodedMayHaveLogicalNulls(*this);
}
if (t == Type::DICTIONARY) {
return internal::DictionaryMayHaveLogicalNulls(*this);
}
return null_count.load() != 0;
}

Expand Down Expand Up @@ -505,7 +509,7 @@ struct ARROW_EXPORT ArraySpan {

/// \brief Return true if the validity bitmap may have 0's in it, or if the
/// child arrays (in the case of types without a validity bitmap) may have
/// nulls
/// nulls, or if the dictionary of a dictionary array may have nulls.
///
/// \see ArrayData::MayHaveLogicalNulls
bool MayHaveLogicalNulls() const {
Expand All @@ -519,6 +523,9 @@ struct ARROW_EXPORT ArraySpan {
if (t == Type::RUN_END_ENCODED) {
return RunEndEncodedMayHaveLogicalNulls();
}
if (t == Type::DICTIONARY) {
return DictionaryMayHaveLogicalNulls();
}
return null_count != 0;
}

Expand Down Expand Up @@ -560,6 +567,7 @@ struct ARROW_EXPORT ArraySpan {

bool UnionMayHaveLogicalNulls() const;
bool RunEndEncodedMayHaveLogicalNulls() const;
bool DictionaryMayHaveLogicalNulls() const;
};

namespace internal {
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/arrow/util/dict_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
// under the License.

#include "arrow/util/dict_util.h"
#include "array_dict.h"
#include "arrow/array/array_dict.h"
#include "arrow/util/bit_util.h"
#include "arrow/util/checked_cast.h"

Expand All @@ -33,10 +33,9 @@ int64_t LogicalNullCount(const ArraySpan& span) {
using CType = typename IndexArrowType::c_type;
const CType* indices_data = span.GetValues<CType>(1);
auto index_length = span.length;
CType dict_len = static_cast<CType>(span.dictionary().length);
int64_t null_count = 0;
for (int64_t i = 0; i < index_length; i++) {
if (!bit_util::GetBit(indices_null_bit_map, i)) {
if (indices_null_bit_map != nullptr && !bit_util::GetBit(indices_null_bit_map, i)) {
null_count++;
continue;
}
Expand All @@ -56,7 +55,7 @@ int64_t LogicalNullCount(const ArraySpan& span) {
return span.GetNullCount();
}

const auto& dict_array_type = internal::checked_cast<DictionaryType>(*span.type);
const auto& dict_array_type = internal::checked_cast<const DictionaryType&>(*span.type);
switch (dict_array_type.index_type()->id()) {
case Type::UINT8:
return LogicalNullCount<UInt8Type>(span);
Expand Down

0 comments on commit 186b047

Please sign in to comment.