From 89c815a3dcf4d122de243f13072ac2243004609b Mon Sep 17 00:00:00 2001
From: Ray Douglass
Date: Mon, 24 Jul 2023 17:08:53 -0400
Subject: [PATCH 01/10] v23.10 Updates [skip ci]

From 58ba5a482d484f4eda3a5c82466431113416c587 Mon Sep 17 00:00:00 2001
From: Chuang Zhu <111838961+chuangz0@users.noreply.github.com>
Date: Thu, 17 Aug 2023 20:57:04 +0800
Subject: [PATCH 02/10] PR: Use top-k from RAFT (#53)

Closes #5

- Fix bugs in `cpp/tests/wholegraph_ops/wholegraph_csr_weighted_sample_without_replacement_tests.cu`
  and `cpp/tests/graph_ops/csr_add_self_loop_utils.cu`.
- Use `raft::warp_sort` (`select_k`) to implement weighted sampling without replacement; when
  `sample_count > 256`, `cub::DeviceSegmentedSort` is used instead.
- Remove `block_radix_topk.cuh` and replace `block_topk` in `embedding_cache_func.cuh` with
  `raft::warp_sort`.

Authors:
  - Chuang Zhu (https://github.com/chuangz0)

Approvers:
  - https://github.com/dongxuy04
  - Brad Rees (https://github.com/BradReesWork)

URL: https://github.com/rapidsai/wholegraph/pull/53
---
 cpp/src/wholegraph_ops/block_radix_topk.cuh   | 371 -----------
 ...ighted_sample_without_replacement_func.cuh | 622 ++++++++++--------
 .../functions/embedding_cache_func.cuh        | 200 +++---
 .../graph_ops/csr_add_self_loop_utils.cu      |   2 +-
 .../graph_sampling_test_utils.cu              |  54 +-
 ...ighted_sample_without_replacement_tests.cu |  16 +-
 ...aph_weighted_sample_without_replacement.py |  39 +-
 7 files changed, 482 insertions(+), 822 deletions(-)
 delete mode 100644 cpp/src/wholegraph_ops/block_radix_topk.cuh

diff --git a/cpp/src/wholegraph_ops/block_radix_topk.cuh b/cpp/src/wholegraph_ops/block_radix_topk.cuh
deleted file mode 100644
index 624c07510..000000000
--- a/cpp/src/wholegraph_ops/block_radix_topk.cuh
+++ /dev/null
@@ -1,371 +0,0 @@
-/*
- * Copyright (c) 2019-2023, NVIDIA CORPORATION.
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- *     http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
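(Editorial note, not part of the patch.) The sampling scheme the new kernels build on is the A-Res reservoir-sampling algorithm linked in the kernel comments further below: each neighbor is keyed by a weight-biased random draw and the k largest keys form the sample; the patch performs that selection on the GPU with `raft::warp_sort` / `select_k` per input node. The host-side sketch below only illustrates the keying idea under those assumptions; `a_res_sample` is a hypothetical helper, not code from this PR.

```cpp
// Illustrative sketch of A-Res weighted sampling without replacement
// (https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res).
// Assumes positive weights and k <= weights.size(); not part of this patch.
#include <algorithm>
#include <cmath>
#include <numeric>
#include <random>
#include <vector>

std::vector<int> a_res_sample(const std::vector<float>& weights, int k, unsigned seed)
{
  std::mt19937 rng(seed);
  std::uniform_real_distribution<float> uniform(0.0f, 1.0f);

  // Each item i draws key u_i^(1 / w_i); larger weights tend to get larger keys.
  std::vector<float> keys(weights.size());
  for (std::size_t i = 0; i < weights.size(); ++i) {
    keys[i] = std::pow(uniform(rng), 1.0f / weights[i]);
  }

  // Keep the indices of the k largest keys (the GPU kernels do this step
  // with a per-block warp-sort top-k instead of a host partial sort).
  std::vector<int> idx(weights.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return keys[a] > keys[b]; });
  idx.resize(k);
  return idx;
}
```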
- */ -#pragma once - -#include -#include -#include -#include - -namespace wholegraph_ops { - -template -class BlockRadixTopKGlobalMemory { - static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), - "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); - static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); - using KeyTraits = cub::Traits; - using UnsignedBits = typename KeyTraits::UnsignedBits; - using BlockScanT = cub::BlockScan; - static constexpr int RADIX_SIZE = (1 << RADIX_BITS); - static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; - using BinBlockLoad = cub::BlockLoad; - using BinBlockStore = cub::BlockStore; - struct _TempStorage { - typename BlockScanT::TempStorage scan_storage; - union { - typename BinBlockLoad::TempStorage load_storage; - typename BinBlockStore::TempStorage store_storage; - } load_store; - union { - int shared_bins[RADIX_SIZE]; - }; - int share_target_k; - int share_bucket_id; - }; - - public: - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - __device__ __forceinline__ BlockRadixTopKGlobalMemory(TempStorage& temp_storage) - : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; - __device__ __forceinline__ void radixTopKGetThreshold( - const KeyT* data, int k, int size, KeyT& topK, bool& topk_is_unique) - { - assert(k < size && k > 0); - int target_k = k; - UnsignedBits key_pattern = 0; - int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; - for (; digit_pos >= 0; digit_pos -= RADIX_BITS) { - UpdateSharedBins(data, size, digit_pos, key_pattern); - InclusiveScanBins(); - UpdateTopK(digit_pos, target_k, key_pattern); - if (target_k == 0) break; - } - if (target_k == 0) { - key_pattern -= 1; - topk_is_unique = true; - } else { - topk_is_unique = false; - } - if (GREATER) key_pattern = ~key_pattern; - UnsignedBits topK_unsigned = KeyTraits::TwiddleOut(key_pattern); - topK = reinterpret_cast(topK_unsigned); - } - - private: - __device__ __forceinline__ void UpdateSharedBins(const KeyT* key, - int size, - int digit_pos, - UnsignedBits key_pattern) - { - for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { - temp_storage_.shared_bins[id] = 0; - } - cub::CTA_SYNC(); - UnsignedBits key_mask = ((UnsignedBits)(-1)) << ((UnsignedBits)(digit_pos + RADIX_BITS)); -#pragma unroll - for (int idx = tid_; idx < size; idx += BLOCK_SIZE) { - KeyT key_data = key[idx]; - UnsignedBits twiddled_data = KeyTraits::TwiddleIn(reinterpret_cast(key_data)); - if (GREATER) twiddled_data = ~twiddled_data; - UnsignedBits digit_in_radix = cub::BFE(twiddled_data, digit_pos, RADIX_BITS); - if ((twiddled_data & key_mask) == (key_pattern & key_mask)) { - atomicAdd(&temp_storage_.shared_bins[digit_in_radix], 1); - } - } - cub::CTA_SYNC(); - } - __device__ __forceinline__ void InclusiveScanBins() - { - int items[SCAN_ITEMS_PER_THREAD]; - BinBlockLoad(temp_storage_.load_store.load_storage) - .Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); - cub::CTA_SYNC(); - BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); - cub::CTA_SYNC(); - BinBlockStore(temp_storage_.load_store.store_storage) - .Store(temp_storage_.shared_bins, items, RADIX_SIZE); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void UpdateTopK(int digit_pos, - int& target_k, - UnsignedBits& target_pattern) - { - for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) { - int prev_count = (idx == 0) ? 
0 : temp_storage_.shared_bins[idx - 1]; - int cur_count = temp_storage_.shared_bins[idx]; - if (prev_count <= target_k && cur_count > target_k) { - temp_storage_.share_target_k = target_k - prev_count; - temp_storage_.share_bucket_id = idx; - } - } - cub::CTA_SYNC(); - target_k = temp_storage_.share_target_k; - int target_bucket_id = temp_storage_.share_bucket_id; - UnsignedBits key_segment = ((UnsignedBits)target_bucket_id) << ((UnsignedBits)digit_pos); - target_pattern |= key_segment; - } - _TempStorage& temp_storage_; - int tid_; -}; - -template -class BlockRadixTopKRegister { - static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), - "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); - static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); - using KeyTraits = cub::Traits; - using UnsignedBits = typename KeyTraits::UnsignedBits; - using BlockScanT = cub::BlockScan; - static constexpr int RADIX_SIZE = (1 << RADIX_BITS); - static constexpr bool KEYS_ONLY = std::is_same::value; - static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; - using BinBlockLoad = cub::BlockLoad; - using BinBlockStore = cub::BlockStore; - using BlockExchangeKey = cub::BlockExchange; - using BlockExchangeValue = cub::BlockExchange; - - using _ExchangeKeyTempStorage = typename BlockExchangeKey::TempStorage; - using _ExchangeValueTempStorage = typename BlockExchangeValue::TempStorage; - typedef union ExchangeKeyTempStorageType { - _ExchangeKeyTempStorage key_storage; - } ExchKeyTempStorageType; - typedef union ExchangeKeyValueTempStorageType { - _ExchangeKeyTempStorage key_storage; - _ExchangeValueTempStorage value_storage; - } ExchKeyValueTempStorageType; - using _ExchangeType = - typename std::conditional::type; - - struct _TempStorage { - typename BlockScanT::TempStorage scan_storage; - union { - typename BinBlockLoad::TempStorage load_storage; - typename BinBlockStore::TempStorage store_storage; - } load_store; - union { - int shared_bins[RADIX_SIZE]; - _ExchangeType exchange_storage; - }; - int share_target_k; - int share_bucket_id; - int share_prev_count; - }; - - public: - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - __device__ __forceinline__ BlockRadixTopKRegister(TempStorage& temp_storage) - : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; - __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - if (k == valid_count) return; - TopKGenRank(keys, k, valid_count); - int is_valid[ITEMS_PER_THREAD]; - GenValidArray(is_valid, k); - BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged( - keys, keys, ranks_, is_valid); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], - ValueT (&values)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - if (k == valid_count) return; - TopKGenRank(keys, k, valid_count); - int is_valid[ITEMS_PER_THREAD]; - GenValidArray(is_valid, k); - BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged( - keys, keys, ranks_, is_valid); - cub::CTA_SYNC(); - BlockExchangeValue{temp_storage_.exchange_storage.value_storage}.ScatterToStripedFlagged( - values, values, ranks_, is_valid); - cub::CTA_SYNC(); - } - - private: - __device__ __forceinline__ void TopKGenRank(KeyT (&keys)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - assert(k <= BLOCK_SIZE * ITEMS_PER_THREAD); - 
assert(k <= valid_count); - UnsignedBits(&unsigned_keys)[ITEMS_PER_THREAD] = - reinterpret_cast(keys); - search_mask_ = 0; - top_k_mask_ = 0; - -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - int idx = KEY * BLOCK_SIZE + tid_; - unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]); - if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; - if (idx < valid_count) search_mask_ |= (1U << KEY); - } - - int target_k = k; - int prefix_k = 0; - - for (int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) { - UpdateSharedBins(unsigned_keys, digit_pos, prefix_k); - InclusiveScanBins(); - UpdateTopK(unsigned_keys, digit_pos, target_k, prefix_k, digit_pos == 0); - if (target_k == 0) break; - } - -#pragma unroll - for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; - unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]); - } - } - __device__ __forceinline__ void GenValidArray(int (&is_valid)[ITEMS_PER_THREAD], int k) - { -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if ((top_k_mask_ & (1U << KEY)) && ranks_[KEY] < k) { - is_valid[KEY] = 1; - } else { - is_valid[KEY] = 0; - } - } - } - __device__ __forceinline__ void UpdateSharedBins(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], - int digit_pos, - int prefix_k) - { - for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { - temp_storage_.shared_bins[id] = 0; - } - cub::CTA_SYNC(); -// #define USE_MATCH -#ifdef USE_MATCH - int lane_mask = cub::LaneMaskLt(); -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - bool is_search = search_mask_ & (1U << KEY); - int bucket_idx = -1; - if (is_search) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - bucket_idx = (int)digit_in_radix; - } - int warp_match_mask = __match_any_sync(0xffffffff, bucket_idx); - int same_count = __popc(warp_match_mask); - int idx_in_same_bucket = __popc(warp_match_mask & lane_mask); - int same_bucket_root_lane = __ffs(warp_match_mask) - 1; - int same_bucket_start_idx; - if (idx_in_same_bucket == 0 && is_search) { - same_bucket_start_idx = atomicAdd(&temp_storage_.shared_bins[bucket_idx], same_count); - } - same_bucket_start_idx = - __shfl_sync(0xffffffff, same_bucket_start_idx, same_bucket_root_lane, 32); - if (is_search) { ranks_[KEY] = same_bucket_start_idx + idx_in_same_bucket + prefix_k; } - } -#else -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - bool is_search = search_mask_ & (1U << KEY); - int bucket_idx = -1; - if (is_search) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - bucket_idx = (int)digit_in_radix; - ranks_[KEY] = atomicAdd(&temp_storage_.shared_bins[bucket_idx], 1) + prefix_k; - } - } -#endif - cub::CTA_SYNC(); - } - __device__ __forceinline__ void InclusiveScanBins() - { - int items[SCAN_ITEMS_PER_THREAD]; - BinBlockLoad(temp_storage_.load_store.load_storage) - .Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); - cub::CTA_SYNC(); - BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); - cub::CTA_SYNC(); - BinBlockStore(temp_storage_.load_store.store_storage) - .Store(temp_storage_.shared_bins, items, RADIX_SIZE); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void UpdateTopK(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], - int digit_pos, - int& target_k, - int& prefix_k, - bool mark_equal) - { - for (int idx = tid_; (idx < 
RADIX_SIZE); idx += BLOCK_SIZE) { - int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1]; - int cur_count = temp_storage_.shared_bins[idx]; - if (prev_count <= target_k && cur_count > target_k) { - temp_storage_.share_target_k = target_k - prev_count; - temp_storage_.share_bucket_id = idx; - temp_storage_.share_prev_count = prev_count; - } - } - cub::CTA_SYNC(); - target_k = temp_storage_.share_target_k; - prefix_k += temp_storage_.share_prev_count; - int target_bucket_id = temp_storage_.share_bucket_id; -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if (search_mask_ & (1U << KEY)) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - if (digit_in_radix < target_bucket_id) { - top_k_mask_ |= (1U << KEY); - search_mask_ &= ~(1U << KEY); - } else if (digit_in_radix > target_bucket_id) { - search_mask_ &= ~(1U << KEY); - } else { - if (mark_equal) top_k_mask_ |= (1U << KEY); - } - if (digit_in_radix <= target_bucket_id) { - int prev_count = - (digit_in_radix == 0) ? 0 : temp_storage_.shared_bins[digit_in_radix - 1]; - ranks_[KEY] += prev_count; - } - } - } - cub::CTA_SYNC(); - } - - _TempStorage& temp_storage_; - int tid_; - int ranks_[ITEMS_PER_THREAD]; - unsigned int search_mask_; - unsigned int top_k_mask_; -}; - -} // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index 22a97fd19..a2915cd00 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -14,22 +14,26 @@ * limitations under the License. */ #pragma once +#include #include +#include +#include +#include #include #include +#include "raft/matrix/detail/select_warpsort.cuh" +#include "raft/util/cuda_dev_essentials.cuh" +#include "wholememory_ops/output_memory_handle.hpp" +#include "wholememory_ops/raft_random.cuh" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" #include #include #include #include #include -#include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" -#include "wholememory_ops/temp_memory_handle.hpp" -#include "wholememory_ops/thrust_allocator.hpp" - -#include "block_radix_topk.cuh" #include "cuda_macros.hpp" #include "error.hpp" #include "sample_comm.cuh" @@ -54,16 +58,14 @@ __device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PC } template -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_large_kernel( + unsigned int BLOCK_SIZE> +__launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_kernel( wholememory_gref_t wm_csr_row_ptr, wholememory_array_description_t wm_csr_row_ptr_desc, wholememory_gref_t wm_csr_col_ptr, @@ -74,18 +76,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const int input_node_count, const int max_sample_count, unsigned long long random_seed, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, const int* target_neighbor_offset, - WMIdType* output, - int* src_lid, - int64_t* out_edge_gid, - WeightKeyType* weight_keys_buff) + WeightKeyType* output_weighted_keys, + NeighborIdxType* output_idxs, + bool need_random = true) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - 
wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -93,13 +91,57 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; int neighbor_count = (int)(end - start); + if (neighbor_count <= max_sample_count) { need_random = false; } + + PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + int output_offset = target_neighbor_offset[input_idx]; + output_weighted_keys += output_offset; + output_idxs += output_offset; + for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { + WeightType thread_weight = csr_weight_ptr_gen[start + id]; + output_weighted_keys[id] = + need_random ? static_cast(gen_key_from_weight(thread_weight, rng)) + : (static_cast(thread_weight)); + output_idxs[id] = static_cast(id); + } +} + +template +__launch_bounds__(BLOCK_SIZE) __global__ + void weighted_sample_select_k_kernel(wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + const NeighborIdxType* sorted_idxs, + const int* target_neighbor_offset, + WMIdType* output, + LocalIdType* src_lid, + int64_t* out_edge_gid) +{ + int input_idx = blockIdx.x; + if (input_idx >= input_node_count) return; + wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); + wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); + IdType nid = input_nodes[input_idx]; + int64_t start = csr_row_ptr_gen[nid]; + int64_t end = csr_row_ptr_gen[nid + 1]; + int neighbor_count = (int)(end - start); + + int offset = sample_offset[input_idx]; - WeightKeyType* weight_keys_local_buff = weight_keys_buff + target_neighbor_offset[input_idx]; - int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { - int neighbor_idx = sample_id; - int original_neighbor_idx = neighbor_idx; + int original_neighbor_idx = sample_id; IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; output[offset + sample_id] = gid; if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; @@ -108,83 +150,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen } return; } - - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { - WeightType thread_weight = csr_weight_ptr_gen[start + id]; - weight_keys_local_buff[id] = - NeedRandom ? 
static_cast(gen_key_from_weight(thread_weight, rng)) - : (static_cast(thread_weight)); - } - - __syncthreads(); - - WeightKeyType topk_val; - bool topk_is_unique; - - using BlockRadixSelectT = - std::conditional_t, - BlockRadixTopKGlobalMemory>; - __shared__ typename BlockRadixSelectT::TempStorage share_storage; - - BlockRadixSelectT{share_storage}.radixTopKGetThreshold( - weight_keys_local_buff, max_sample_count, neighbor_count, topk_val, topk_is_unique); - __shared__ int cnt; - - if (threadIdx.x == 0) { cnt = 0; } - __syncthreads(); - - for (int i = threadIdx.x; i < max_sample_count; i += BLOCK_SIZE) { - if (src_lid) src_lid[offset + i] = (LocalIdType)input_idx; - } - - // We use atomicAdd 1 operations instead of binaryScan to calculate the write - // index, since we do not need to keep the relative positions of element. - - if (topk_is_unique) { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key <= topk_val) : (key >= topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } - } else { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key < topk_val) : (key > topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } - __syncthreads(); - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = (key == topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - if (write_index >= max_sample_count) break; - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } + int neighbor_offset = target_neighbor_offset[input_idx]; + for (int sample_id = threadIdx.x; sample_id < max_sample_count; sample_id += BLOCK_SIZE) { + int original_neighbor_idx = sorted_idxs[neighbor_offset + sample_id]; + IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; + output[offset + sample_id] = gid; + if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (out_edge_gid) + out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); } } @@ -216,21 +189,30 @@ __global__ void get_sample_count_and_neighbor_count_without_replacement_kernel( } } +// to avoid queue.store() store keys or values in output. +struct null_store_t {}; +struct null_store_op { + template + constexpr auto operator()(const Type& in, UnusedArgs...) const + { + return null_store_t{}; + } +}; + // A-RES algorithmn // https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res -// max_sample_count should <=(BLOCK_SIZE*ITEMS_PER_THREAD*/4) otherwise,need to -// change the template parameters of BlockRadixTopK. 
-template class WarpSortClass, + int Capacity, + typename IdType, typename LocalIdType, typename WeightType, + typename NeighborIdxType, typename WMIdType, typename WMOffsetType, typename WMWeightType, - unsigned int ITEMS_PER_THREAD, - unsigned int BLOCK_SIZE, - bool NeedRandom = true, - bool Ascending = false> -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_kernel( + bool NEED_RANDOM = true, + bool ASCENDING = false> +__launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_kernel( wholememory_gref_t wm_csr_row_ptr, wholememory_array_description_t wm_csr_row_ptr_desc, wholememory_gref_t wm_csr_col_ptr, @@ -244,13 +226,12 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, - int* src_lid, + LocalIdType* src_lid, int64_t* out_edge_gid) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; - int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - + int gidx = threadIdx.x + blockIdx.x * blockDim.x; wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -258,86 +239,153 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen IdType nid = input_nodes[input_idx]; int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; - int neighbor_count = (int)(end - start); + int neighbor_count = static_cast(end - start); int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { - for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { + for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += blockDim.x) { int neighbor_idx = sample_id; int original_neighbor_idx = neighbor_idx; IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; output[offset + sample_id] = gid; - if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (src_lid) src_lid[offset + sample_id] = input_idx; if (out_edge_gid) out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); } return; } else { - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - - float weight_keys[ITEMS_PER_THREAD]; - int neighbor_idxs[ITEMS_PER_THREAD]; - - using BlockRadixTopKT = - std::conditional_t, - BlockRadixTopKRegister>; - - __shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage; - - const int tx = threadIdx.x; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = BLOCK_SIZE * i + tx; + extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; + using bq_t = raft::matrix::detail::select::warpsort:: + block_sort; + + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(max_sample_count, warp_smem); + PCGenerator rng(random_seed, static_cast(gidx), static_cast(0)); + const int per_thread_lim = neighbor_count + raft::laneId(); + for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) { + WeightType weight_key = + WarpSortClass::kDummy; if (idx < neighbor_count) { WeightType thread_weight = csr_weight_ptr_gen[start + idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = idx; + weight_key = NEED_RANDOM ? 
gen_key_from_weight(thread_weight, rng) : thread_weight; } + queue.add(weight_key, idx); } - const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD)) - ? neighbor_count - : (BLOCK_SIZE * ITEMS_PER_THREAD); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, valid_count); + queue.done(smem_buf_bytes); + __syncthreads(); - const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count; - - for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE; idx_offset < neighbor_count; - idx_offset += stride) { -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int local_idx = BLOCK_SIZE * i + tx - max_sample_count; - // [0,BLOCK_SIZE*ITEMS_PER_THREAD-max_sample_count) - int target_idx = idx_offset + local_idx; - if (local_idx >= 0 && target_idx < neighbor_count) { - WeightType thread_weight = csr_weight_ptr_gen[start + target_idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = target_idx; - } + NeighborIdxType* smem_topk_idx = reinterpret_cast(smem_buf_bytes); + queue.store(static_cast(nullptr), smem_topk_idx, null_store_op{}); + __syncthreads(); + for (int idx = threadIdx.x; idx < max_sample_count; idx += blockDim.x) { + NeighborIdxType local_original_idx = static_cast(smem_topk_idx[idx]); + if (src_lid) { src_lid[offset + idx] = static_cast(input_idx); } + output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; + if (out_edge_gid) { + out_edge_gid[offset + idx] = static_cast(start + local_original_idx); } - const int iter_valid_count = ((neighbor_count - idx_offset) >= stride) - ? (BLOCK_SIZE * ITEMS_PER_THREAD) - : (max_sample_count + neighbor_count - idx_offset); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, iter_valid_count); - __syncthreads(); } -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = i * BLOCK_SIZE + tx; - if (idx < max_sample_count) { - if (src_lid) src_lid[offset + idx] = (LocalIdType)input_idx; - LocalIdType local_original_idx = neighbor_idxs[i]; - output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + idx] = static_cast(start + local_original_idx); - } + }; +} + +template