Skip to content

Commit

Permalink
response to review comments & optimize repeat
Browse files Browse the repository at this point in the history
  • Loading branch information
rhdong committed Sep 24, 2024
1 parent 92890e1 commit 224742c
Show file tree
Hide file tree
Showing 6 changed files with 126 additions and 70 deletions.
123 changes: 65 additions & 58 deletions cpp/include/raft/core/bitset.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -63,49 +63,75 @@ _RAFT_DEVICE void bitset_view<bitset_t, index_t>::set(const index_t sample_index
}

template <typename bitset_t, typename index_t>
struct bitset_copy_functor {
const bitset_t* bitset_ptr;
bitset_t* output_device_ptr;
index_t valid_bits;
index_t bits_per_element;
index_t total_bits;

bitset_copy_functor(const bitset_t* _bitset_ptr,
bitset_t* _output_device_ptr,
index_t _valid_bits,
index_t _bits_per_element,
index_t _total_bits)
: bitset_ptr(_bitset_ptr),
output_device_ptr(_output_device_ptr),
valid_bits(_valid_bits),
bits_per_element(_bits_per_element),
total_bits(_total_bits)
{
}
void bitset_view<bitset_t, index_t>::count(const raft::resources& res,
raft::device_scalar_view<index_t> count_gpu_scalar) const
{
auto max_len = raft::make_host_scalar_view<const index_t, index_t>(&bitset_len_);
auto values = raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, n_elements());
raft::popc(res, values, max_len, count_gpu_scalar);
}

template <typename bitset_t, typename index_t>
RAFT_KERNEL bitset_repeat_kernel(const bitset_t* src,
bitset_t* output,
index_t src_bit_len,
index_t repeat_times)
{
constexpr index_t bits_per_element = sizeof(bitset_t) * 8;
int output_idx = blockIdx.x * blockDim.x + threadIdx.x;

index_t total_bits = src_bit_len * repeat_times;
index_t output_size = (total_bits + bits_per_element - 1) / bits_per_element;
index_t src_size = (src_bit_len + bits_per_element - 1) / bits_per_element;

__device__ void operator()(index_t i)
{
if (i < total_bits) {
index_t src_bit_index = i % valid_bits;
index_t dst_bit_index = i;
if (output_idx < output_size) {
bitset_t result = 0;
index_t bit_written = 0;

index_t src_element_index = src_bit_index / bits_per_element;
index_t src_bit_offset = src_bit_index % bits_per_element;
index_t start_bit = output_idx * bits_per_element;

index_t dst_element_index = dst_bit_index / bits_per_element;
index_t dst_bit_offset = dst_bit_index % bits_per_element;
while (bit_written < bits_per_element && start_bit + bit_written < total_bits) {
index_t bit_idx = (start_bit + bit_written) % src_bit_len;
index_t src_word_idx = bit_idx / bits_per_element;
index_t src_offset = bit_idx % bits_per_element;

bitset_t src_element = bitset_ptr[src_element_index];
bitset_t src_bit = (src_element >> src_bit_offset) & 1;
index_t remaining_bits = min(bits_per_element - bit_written, src_bit_len - bit_idx);

if (src_bit) {
atomicOr(output_device_ptr + dst_element_index, bitset_t(1) << dst_bit_offset);
} else {
atomicAnd(output_device_ptr + dst_element_index, ~(bitset_t(1) << dst_bit_offset));
bitset_t src_value = (src[src_word_idx] >> src_offset);

if (src_offset + remaining_bits > bits_per_element) {
bitset_t next_value = src[(src_word_idx + 1) % src_size];
src_value |= (next_value << (bits_per_element - src_offset));
}
src_value &= ((bitset_t{1} << remaining_bits) - 1);
result |= (src_value << bit_written);
bit_written += remaining_bits;
}
output[output_idx] = result;
}
};
}

/**
 * @brief Launches bitset_repeat_kernel to write `repeat_times` back-to-back copies of the
 *        first `src_bit_len` bits of `d_src` into `d_output`.
 *
 * The grid is sized so that there is at least one thread per output word, where the number
 * of output words is ceil(src_bit_len * repeat_times / (sizeof(bitset_t) * 8)). The kernel
 * is enqueued on the CUDA stream obtained from `handle`; this function does not synchronize.
 *
 * NOTE(review): the launch is not followed by an error check, and the block count is
 * narrowed to `int` — confirm both are acceptable for the expected input sizes.
 *
 * @tparam bitset_t Word type of the bitset storage.
 * @tparam index_t  Integer type used for bit counts.
 *
 * @param[in]  handle       RAFT resources providing the CUDA stream.
 * @param[in]  d_src        Source bitset words (presumably device memory — verify against caller).
 * @param[out] d_output     Destination words; must hold the number of words described above.
 * @param[in]  src_bit_len  Number of valid bits in the source bitset.
 * @param[in]  repeat_times Number of copies to write.
 */
template <typename bitset_t, typename index_t>
void bitset_repeat(raft::resources const& handle,
                   const bitset_t* d_src,
                   bitset_t* d_output,
                   index_t src_bit_len,
                   index_t repeat_times)
{
  // Zero source bits or zero repetitions produce no output, so no kernel is launched.
  if (src_bit_len == 0 || repeat_times == 0) { return; }

  constexpr index_t kBitsPerWord = sizeof(bitset_t) * 8;
  constexpr int kBlockSize       = 128;

  // Round the total bit count up to whole bitset_t words.
  const index_t n_output_bits  = src_bit_len * repeat_times;
  const index_t n_output_words = (n_output_bits + kBitsPerWord - 1) / kBitsPerWord;

  // Ceil-division: enough blocks to give every output word its own thread.
  const int n_blocks = (n_output_words + kBlockSize - 1) / kBlockSize;

  auto stream = resource::get_cuda_stream(handle);
  bitset_repeat_kernel<<<n_blocks, kBlockSize, 0, stream>>>(
    d_src, d_output, src_bit_len, repeat_times);
}

template <typename bitset_t, typename index_t>
void bitset_view<bitset_t, index_t>::repeat(const raft::resources& res,
Expand All @@ -125,37 +151,18 @@ void bitset_view<bitset_t, index_t>::repeat(const raft::resources& res,
raft::resource::get_cuda_stream(res));
}
} else {
index_t valid_bits = bitset_len_;
index_t total_bits = valid_bits * times;
index_t output_row_elements = (total_bits + bits_per_element - 1) / bits_per_element;
thrust::for_each_n(thrust_policy,
thrust::counting_iterator<index_t>(0),
total_bits,
bitset_copy_functor<bitset_t, index_t>(
bitset_ptr_, output_device_ptr, valid_bits, bits_per_element, total_bits));
bitset_repeat(res, bitset_ptr_, output_device_ptr, bitset_len_, times);
}
}

/**
 * @brief Computes the fraction of bits in the view that are NOT set.
 *
 * Delegates the population count to `count(res)` instead of repeating the popc / copy /
 * sync sequence here, so the set bits are counted exactly once.
 *
 * @param[in] res RAFT resources forwarded to `count`.
 * @return (size - number_of_set_bits) / size as a double; 1.0 when the bitset is empty.
 */
template <typename bitset_t, typename index_t>
double bitset_view<bitset_t, index_t>::sparsity(const raft::resources& res) const
{
  index_t size_h = this->size();

  // An empty bitset is reported as fully sparse; this also guards the division below.
  if (0 == size_h) { return static_cast<double>(1.0); }

  // Host-side number of set bits. NOTE(review): assumes count() and size() refer to the
  // same bit length (bitset_len_) — confirm against the class definition.
  index_t count_h = this->count(res);

  return static_cast<double>((1.0 * (size_h - count_h)) / (1.0 * size_h));
}

template <typename bitset_t, typename index_t>
Expand Down Expand Up @@ -253,7 +260,7 @@ template <typename bitset_t, typename index_t>
void bitset<bitset_t, index_t>::count(const raft::resources& res,
raft::device_scalar_view<index_t> count_gpu_scalar)
{
auto max_len = raft::make_host_scalar_view<index_t>(&bitset_len_);
auto max_len = raft::make_host_scalar_view<const index_t, index_t>(&bitset_len_);
auto values =
raft::make_device_vector_view<const bitset_t, index_t>(bitset_.data(), n_elements());
raft::popc(res, values, max_len, count_gpu_scalar);
Expand Down
23 changes: 23 additions & 0 deletions cpp/include/raft/core/bitset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,29 @@ struct bitset_view {
{
return raft::make_device_vector_view<const bitset_t, index_t>(bitset_ptr_, n_elements());
}
/**
 * @brief Counts the bits set to true and stores the result in `count_gpu_scalar`.
 *
 * This overload only writes the result through the given device scalar view; it does not
 * return the count to the caller on the host.
 *
 * @param[in] res RAFT resources
 * @param[out] count_gpu_scalar Device scalar that receives the number of set bits
 */
void count(const raft::resources& res, raft::device_scalar_view<index_t> count_gpu_scalar) const;
/**
 * @brief Returns the number of bits set to true as a host value.
 *
 * Allocates a temporary device scalar, runs the device-scalar overload of `count` into it,
 * copies the result to the host and synchronizes the stream before returning, so this call
 * blocks the calling thread until the count is available.
 *
 * @param res RAFT resources
 * @return index_t Number of bits set to true
 */
auto count(const raft::resources& res) const -> index_t
{
  // Initialise with an integral zero: the scalar holds index_t, not a floating-point value.
  auto count_gpu_scalar = raft::make_device_scalar<index_t>(res, index_t{0});
  count(res, count_gpu_scalar.view());

  index_t count_cpu = 0;
  raft::update_host(
    &count_cpu, count_gpu_scalar.data_handle(), 1, resource::get_cuda_stream(res));
  // The copy above is issued on the stream; wait for it before reading count_cpu.
  resource::sync_stream(res);
  return count_cpu;
}

/**
* @brief Repeats the bitset data and copies it to the output device pointer.
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/util/detail/popc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,12 @@ namespace raft::detail {
*/
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
device_vector_view<value_t, index_t> values,
raft::host_scalar_view<index_t> max_len,
device_vector_view<const value_t, index_t> values,
raft::host_scalar_view<const index_t, index_t> max_len,
raft::device_scalar_view<index_t> counter)
{
auto values_size = values.size();
auto values_matrix = raft::make_device_matrix_view<value_t, index_t, col_major>(
auto values_matrix = raft::make_device_matrix_view<const value_t, index_t, col_major>(
values.data_handle(), values_size, 1);
auto counter_vector = raft::make_device_vector_view<index_t, index_t>(counter.data_handle(), 1);

Expand Down
4 changes: 2 additions & 2 deletions cpp/include/raft/util/popc.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ namespace raft {
*/
template <typename value_t, typename index_t>
void popc(const raft::resources& res,
device_vector_view<value_t, index_t> values,
raft::host_scalar_view<index_t> max_len,
device_vector_view<const value_t, index_t> values,
raft::host_scalar_view<const index_t, index_t> max_len,
raft::device_scalar_view<index_t> counter)
{
detail::popc(res, values, max_len, counter);
Expand Down
18 changes: 16 additions & 2 deletions cpp/test/core/bitset.cu
Original file line number Diff line number Diff line change
Expand Up @@ -195,7 +195,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
ASSERT_TRUE(hostVecMatch(bitset_ref, bitset_result, raft::Compare<bitset_t>()));

// test sparsity, repeat and eval_n_elements
if constexpr (std::is_same_v<bitset_t, uint32_t> || std::is_same_v<bitset_t, uint64_t>) {
{
auto my_bitset_view = my_bitset.view();
auto sparsity_result = my_bitset_view.sparsity(res);
auto sparsity_ref = sparsity_cpu_bitset(bitset_ref, size_t(spec.bitset_len));
Expand All @@ -217,7 +217,20 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
update_host(
bitset_repeat_result.data(), repeat_device.data_handle(), repeat_device.size(), stream);
ASSERT_EQ(bitset_repeat_ref.size(), bitset_repeat_result.size());
ASSERT_TRUE(hostVecMatch(bitset_repeat_ref, bitset_repeat_result, raft::Compare<bitset_t>()));

index_t errors = 0;
static constexpr index_t len_per_item = sizeof(bitset_t) * 8;
bitset_t tail_len = (index_t(spec.bitset_len * spec.repeat_times) % len_per_item);
bitset_t tail_mask =
tail_len ? (bitset_t)((bitset_t{1} << tail_len) - bitset_t{1}) : ~bitset_t{0};
for (index_t i = 0; i < bitset_repeat_ref.size(); i++) {
if (i == bitset_repeat_ref.size() - 1) {
errors += (bitset_repeat_ref[i] & tail_mask) != (bitset_repeat_result[i] & tail_mask);
} else {
errors += (bitset_repeat_ref[i] != bitset_repeat_result[i]);
}
}
ASSERT_EQ(errors, 0);

// recheck the sparsity after repeat
sparsity_result =
Expand Down Expand Up @@ -246,6 +259,7 @@ class BitsetTest : public testing::TestWithParam<test_spec_bitset> {
ASSERT_EQ(my_bitset.none(res), false);
}
};
// auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10, 2});

auto inputs_bitset = ::testing::Values(test_spec_bitset{32, 5, 10, 101},
test_spec_bitset{100, 30, 10, 13},
Expand Down
22 changes: 17 additions & 5 deletions cpp/test/util/popc.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class PopcTest : public ::testing::TestWithParam<PopcInputs<index_t>> {
index_t bit_position = index % (8 * sizeof(bits_t));

if (((element >> bit_position) & 1) == 0) {
element |= (static_cast<index_t>(1) << bit_position);
element |= (static_cast<bits_t>(1) << bit_position);
num_ones--;
}
}
Expand All @@ -101,7 +101,7 @@ class PopcTest : public ::testing::TestWithParam<PopcInputs<index_t>> {
raft::make_device_vector_view<const bits_t, index_t>(bits_d.data(), bits_d.size());

index_t max_len = params.n_rows * params.n_cols;
auto max_len_view = raft::make_host_scalar_view<index_t>(&max_len);
auto max_len_view = raft::make_host_scalar_view<const index_t, index_t>(&max_len);

index_t nnz_actual_h = 0;
rmm::device_scalar<index_t> nnz_actual_d(0, stream);
Expand All @@ -123,8 +123,17 @@ class PopcTest : public ::testing::TestWithParam<PopcInputs<index_t>> {
index_t nnz_expected;
};

using PopcTestI32 = PopcTest<int32_t>;
TEST_P(PopcTestI32, Result) { Run(); }
using PopcTestI32_U32 = PopcTest<int32_t, uint32_t>;
TEST_P(PopcTestI32_U32, Result) { Run(); }

using PopcTestI32_U64 = PopcTest<int32_t, uint64_t>;
TEST_P(PopcTestI32_U64, Result) { Run(); }

using PopcTestI32_U16 = PopcTest<int32_t, uint16_t>;
TEST_P(PopcTestI32_U16, Result) { Run(); }

using PopcTestI32_U8 = PopcTest<int32_t, uint8_t>;
TEST_P(PopcTestI32_U8, Result) { Run(); }

template <typename index_t>
const std::vector<PopcInputs<index_t>> popc_inputs = {
Expand Down Expand Up @@ -154,6 +163,9 @@ const std::vector<PopcInputs<index_t>> popc_inputs = {
{2, 33, 0.2},
};

INSTANTIATE_TEST_CASE_P(PopcTest, PopcTestI32, ::testing::ValuesIn(popc_inputs<int32_t>));
INSTANTIATE_TEST_CASE_P(PopcTest, PopcTestI32_U32, ::testing::ValuesIn(popc_inputs<int32_t>));
INSTANTIATE_TEST_CASE_P(PopcTest, PopcTestI32_U64, ::testing::ValuesIn(popc_inputs<int32_t>));
INSTANTIATE_TEST_CASE_P(PopcTest, PopcTestI32_U16, ::testing::ValuesIn(popc_inputs<int32_t>));
INSTANTIATE_TEST_CASE_P(PopcTest, PopcTestI32_U8, ::testing::ValuesIn(popc_inputs<int32_t>));

} // namespace raft

0 comments on commit 224742c

Please sign in to comment.