diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index 32fd6fcde..a851cce06 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -100,3 +100,7 @@ ConfigureNVBench(RETRIEVE_BENCH # - reduce_by_key benchmarks ---------------------------------------------------------------------- set(RBK_BENCH_SRC "${CMAKE_CURRENT_SOURCE_DIR}/reduce_by_key/reduce_by_key.cu") ConfigureBench(RBK_BENCH "${RBK_BENCH_SRC}") + +################################################################################################### +set(PRIORITY_QUEUE_BENCH_SRC "${CMAKE_CURRENT_SOURCE_DIR}/priority_queue/priority_queue_bench.cu") +ConfigureBench(PRIORITY_QUEUE_BENCH "${PRIORITY_QUEUE_BENCH_SRC}") diff --git a/benchmarks/priority_queue/priority_queue_bench.cu b/benchmarks/priority_queue/priority_queue_bench.cu new file mode 100644 index 000000000..b40e142c5 --- /dev/null +++ b/benchmarks/priority_queue/priority_queue_bench.cu @@ -0,0 +1,93 @@ +/* + * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include + +#include + +#include +#include +#include + +using namespace cuco; + +template +struct pair_less { + __host__ __device__ bool operator()(const T& a, const T& b) const { return a.first < b.first; } +}; + +template +static void generate_kv_pairs_uniform(OutputIt output_begin, OutputIt output_end) +{ + std::random_device rd; + std::mt19937 gen{rd()}; + + const auto num_keys = std::distance(output_begin, output_end); + + for (auto i = 0; i < num_keys; ++i) { + output_begin[i] = {static_cast(gen()), static_cast(gen())}; + } +} + +template +static void BM_insert(::benchmark::State& state) +{ + for (auto _ : state) { + state.PauseTiming(); + + priority_queue, pair_less>> pq(NumKeys); + + std::vector> h_pairs(NumKeys); + generate_kv_pairs_uniform(h_pairs.begin(), h_pairs.end()); + const thrust::device_vector> d_pairs(h_pairs); + + state.ResumeTiming(); + pq.push(d_pairs.begin(), d_pairs.end()); + cudaDeviceSynchronize(); + } +} + +template +static void BM_delete(::benchmark::State& state) +{ + for (auto _ : state) { + state.PauseTiming(); + + priority_queue, pair_less>> pq(NumKeys); + + std::vector> h_pairs(NumKeys); + generate_kv_pairs_uniform(h_pairs.begin(), h_pairs.end()); + thrust::device_vector> d_pairs(h_pairs); + + pq.push(d_pairs.begin(), d_pairs.end()); + cudaDeviceSynchronize(); + + state.ResumeTiming(); + pq.pop(d_pairs.begin(), d_pairs.end()); + cudaDeviceSynchronize(); + } +} + +BENCHMARK_TEMPLATE(BM_insert, int, int, 128'000'000)->Unit(benchmark::kMillisecond); + +BENCHMARK_TEMPLATE(BM_delete, int, int, 128'000'000)->Unit(benchmark::kMillisecond); + +BENCHMARK_TEMPLATE(BM_insert, int, int, 256'000'000)->Unit(benchmark::kMillisecond); + +BENCHMARK_TEMPLATE(BM_delete, int, int, 256'000'000)->Unit(benchmark::kMillisecond); diff --git a/include/cuco/detail/priority_queue.inl b/include/cuco/detail/priority_queue.inl new file mode 100644 index 000000000..189166c51 --- /dev/null +++ b/include/cuco/detail/priority_queue.inl 
@@ -0,0 +1,197 @@
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// NOTE(review): include targets, template parameter lists, allocator_traits arguments and the
// kernel launch configurations in this file were lost when the patch was flattened. They are
// restored here from the names and layouts the code itself uses — confirm against the original
// commit.
#include <cuco/detail/error.hpp>
#include <cuco/detail/priority_queue_kernels.cuh>

#include <cmath>

namespace cuco {

/**
 * Constructs an empty priority queue able to hold at least initial_capacity elements.
 *
 * The capacity is rounded up to whole nodes of node_size_ elements, with a minimum of one node.
 * The heap allocation holds one node more than the capacity: slot kPBufferIdx (0) is the
 * partial buffer and the root lives at slot kRootIdx (1). The size counter, the partial-buffer
 * size and all locks are zeroed asynchronously on stream.
 *
 * @param initial_capacity Number of elements the queue must be able to hold
 * @param allocator Allocator from which the int, T and size_t allocators are constructed
 * @param stream Stream on which the device counters and locks are zeroed
 */
template <typename T, typename Compare, typename Allocator>
priority_queue<T, Compare, Allocator>::priority_queue(std::size_t initial_capacity,
                                                      Allocator const& allocator,
                                                      cudaStream_t stream)
  : int_allocator_{allocator}, t_allocator_{allocator}, size_t_allocator_{allocator}
{
  node_size_ = 1024;

  // Round up to the nearest multiple of node size. At least one node is always kept so that a
  // requested capacity of zero still yields a well-defined (single-level) heap.
  const std::size_t requested_nodes = (initial_capacity + node_size_ - 1) / node_size_;
  const int nodes = (requested_nodes == 0) ? 1 : static_cast<int>(requested_nodes);

  node_capacity_ = nodes;

  // Largest power of two not exceeding nodes, i.e. 1 << floor(log2(nodes)), computed with
  // integers only
  int lowest_level_start = 1;
  while (lowest_level_start <= nodes / 2) {
    lowest_level_start *= 2;
  }
  lowest_level_start_ = lowest_level_start;

  // Allocate device variables

  d_size_ = std::allocator_traits<decltype(int_allocator_)>::allocate(int_allocator_, 1);

  CUCO_CUDA_TRY(cudaMemsetAsync(d_size_, 0, sizeof(int), stream));

  d_p_buffer_size_ =
    std::allocator_traits<decltype(size_t_allocator_)>::allocate(size_t_allocator_, 1);

  CUCO_CUDA_TRY(cudaMemsetAsync(d_p_buffer_size_, 0, sizeof(std::size_t), stream));

  d_heap_ = std::allocator_traits<decltype(t_allocator_)>::allocate(
    t_allocator_, node_capacity_ * node_size_ + node_size_);

  d_locks_ =
    std::allocator_traits<decltype(int_allocator_)>::allocate(int_allocator_, node_capacity_ + 1);

  CUCO_CUDA_TRY(cudaMemsetAsync(d_locks_, 0, sizeof(int) * (node_capacity_ + 1), stream));
}

/**
 * Releases every device allocation made by the constructor, using the same element counts.
 */
template <typename T, typename Compare, typename Allocator>
priority_queue<T, Compare, Allocator>::~priority_queue()
{
  std::allocator_traits<decltype(int_allocator_)>::deallocate(int_allocator_, d_size_, 1);
  std::allocator_traits<decltype(size_t_allocator_)>::deallocate(
    size_t_allocator_, d_p_buffer_size_, 1);
  std::allocator_traits<decltype(t_allocator_)>::deallocate(
    t_allocator_, d_heap_, node_capacity_ * node_size_ + node_size_);
  std::allocator_traits<decltype(int_allocator_)>::deallocate(
    int_allocator_, d_locks_, node_capacity_ + 1);
}

/**
 * Launches the push kernel that inserts the elements of [first, last) into the queue.
 *
 * The launch is asynchronous with respect to the host; only launch errors are reported here.
 *
 * @param first Iterator to the first element to insert
 * @param last Iterator one past the last element to insert
 * @param stream Stream on which the kernel is launched
 */
template <typename T, typename Compare, typename Allocator>
template <typename InputIt>
void priority_queue<T, Compare, Allocator>::push(InputIt first, InputIt last, cudaStream_t stream)
{
  constexpr int block_size = 256;

  const int num_nodes  = static_cast<int>((last - first) / node_size_) + 1;
  const int num_blocks = std::min(64000, num_nodes);

  // Dynamic shared memory laid out by detail::get_shared_memory_layout: 2 * (block_size + 1)
  // ints of merge-path intersections followed by two node-sized arrays of T.
  // NOTE(review): the original launch configuration was stripped; this restores a size that is
  // sufficient for that layout — confirm against the class's own shared-memory helper.
  const std::size_t shmem_bytes =
    2 * (block_size + 1) * sizeof(int) + 2 * node_size_ * sizeof(T);

  detail::push_kernel<<<num_blocks, block_size, shmem_bytes, stream>>>(first,
                                                                       last - first,
                                                                       d_heap_,
                                                                       d_size_,
                                                                       node_size_,
                                                                       d_locks_,
                                                                       d_p_buffer_size_,
                                                                       lowest_level_start_,
                                                                       compare_);

  CUCO_CUDA_TRY(cudaGetLastError());
}

/**
 * Launches the pop kernel that removes (last - first) elements from the queue and writes them
 * to [first, last).
 *
 * The launch is asynchronous with respect to the host; only launch errors are reported here.
 *
 * @param first Iterator to the first output slot
 * @param last Iterator one past the last output slot
 * @param stream Stream on which the kernel is launched
 */
template <typename T, typename Compare, typename Allocator>
template <typename OutputIt>
void priority_queue<T, Compare, Allocator>::pop(OutputIt first, OutputIt last, cudaStream_t stream)
{
  constexpr int block_size = 256;
  const int pop_size       = last - first;

  const int num_nodes  = static_cast<int>(pop_size / node_size_) + 1;
  const int num_blocks = std::min(64000, num_nodes);

  // Same shared-memory footprint as push (see detail::get_shared_memory_layout).
  // NOTE(review): restored launch configuration — confirm against the original commit.
  const std::size_t shmem_bytes =
    2 * (block_size + 1) * sizeof(int) + 2 * node_size_ * sizeof(T);

  detail::pop_kernel<<<num_blocks, block_size, shmem_bytes, stream>>>(first,
                                                                      pop_size,
                                                                      d_heap_,
                                                                      d_size_,
                                                                      node_size_,
                                                                      d_locks_,
                                                                      d_p_buffer_size_,
                                                                      lowest_level_start_,
                                                                      node_capacity_,
                                                                      compare_);

  CUCO_CUDA_TRY(cudaGetLastError());
}

/**
 * Device-side push of [first, last) performed cooperatively by group g.
 *
 * Full nodes are pushed one at a time; a trailing remainder smaller than a node is pushed
 * through the partial buffer.
 *
 * @param g Cooperative group performing the push
 * @param first Iterator to the first element to insert
 * @param last Iterator one past the last element to insert
 * @param temp_storage Scratch memory partitioned by detail::get_shared_memory_layout
 */
template <typename T, typename Compare, typename Allocator>
template <typename CG, typename InputIt>
__device__ void priority_queue<T, Compare, Allocator>::device_mutable_view::push(
  CG const& g, InputIt first, InputIt last, void* temp_storage)
{
  const detail::shared_memory_layout<T> shmem =
    detail::get_shared_memory_layout<T>((int*)temp_storage, g.size(), node_size_);

  const auto push_size = last - first;
  for (std::size_t i = 0; i < push_size / node_size_; i++) {
    detail::push_single_node(g,
                             first + i * node_size_,
                             d_heap_,
                             d_size_,
                             node_size_,
                             d_locks_,
                             lowest_level_start_,
                             shmem,
                             compare_);
  }

  if (push_size % node_size_ != 0) {
    detail::push_partial_node(g,
                              first + (push_size / node_size_) * node_size_,
                              push_size % node_size_,
                              d_heap_,
                              d_size_,
                              node_size_,
                              d_locks_,
                              d_p_buffer_size_,
                              lowest_level_start_,
                              shmem,
                              compare_);
  }
}

/**
 * Device-side pop into [first, last) performed cooperatively by group g.
 *
 * Full nodes are popped one at a time; a trailing remainder smaller than a node is popped with
 * pop_partial_node.
 *
 * @param g Cooperative group performing the pop
 * @param first Iterator to the first output slot
 * @param last Iterator one past the last output slot
 * @param temp_storage Scratch memory partitioned by detail::get_shared_memory_layout
 */
template <typename T, typename Compare, typename Allocator>
template <typename CG, typename OutputIt>
__device__ void priority_queue<T, Compare, Allocator>::device_mutable_view::pop(
  CG const& g, OutputIt first, OutputIt last, void* temp_storage)
{
  const detail::shared_memory_layout<T> shmem =
    detail::get_shared_memory_layout<T>((int*)temp_storage, g.size(), node_size_);

  const auto pop_size = last - first;
  for (std::size_t i = 0; i < pop_size / node_size_; i++) {
    detail::pop_single_node(g,
                            first + i * node_size_,
                            d_heap_,
                            d_size_,
                            node_size_,
                            d_locks_,
                            d_p_buffer_size_,
                            lowest_level_start_,
                            node_capacity_,
                            shmem,
                            compare_);
  }

  if (pop_size % node_size_ != 0) {
    // Only the remainder is left to pop here. pop_partial_node requires fewer than node_size_
    // elements, so passing the whole request size would be wrong whenever at least one full
    // node was popped above.
    detail::pop_partial_node(g,
                             first + (pop_size / node_size_) * node_size_,
                             pop_size % node_size_,
                             d_heap_,
                             d_size_,
                             node_size_,
                             d_locks_,
                             d_p_buffer_size_,
                             lowest_level_start_,
                             node_capacity_,
                             shmem,
                             compare_);
  }
}

}  // namespace cuco
diff --git a/include/cuco/detail/priority_queue_kernels.cuh b/include/cuco/detail/priority_queue_kernels.cuh
new file mode 100644
index 000000000..6ec36233b
--- /dev/null
+++ b/include/cuco/detail/priority_queue_kernels.cuh
@@ -0,0 +1,1236 @@
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

// NOTE(review): both include targets below were stripped from the flattened patch. The second
// follows from the cooperative_groups alias; the first is presumably a cuco detail header —
// confirm against the original commit.
#include <cuco/detail/pair.cuh>

#include <cooperative_groups.h>

namespace cuco {
namespace detail {
namespace cg = cooperative_groups;

// Node slot of the partial buffer within the heap array
constexpr int kPBufferIdx = 0;
// Node slot of the heap's root within the heap array
constexpr int kRootIdx = 1;

/*
 * Struct to hold pointers to the temp storage used by the priority
 * queue's kernels and functions.
+ * Ideally, this temp storage is in shared memory + */ +template +struct shared_memory_layout { + int* intersections; + T* a; + T* b; +}; + +/* + * Get the shared memory layout for a given group dimension + * and node size. + * + * @param s Pointer to the beginning of the section of shared memory to + * partition + * @param dim Size of the cooperative group the memory will be used by + * @param node_size Size of the nodes in this priority queue + * @returns The memory layout for the given group dimension and node size + */ +template +__device__ shared_memory_layout get_shared_memory_layout(int* s, int dim, std::size_t node_size) +{ + shared_memory_layout result; + result.intersections = s; + result.a = (T*)(s + 2 * (dim + 1)); + result.b = result.a + node_size; + return result; +} + +/** + * Acquires lock l for the current thread block + * The entire thread block must call the function + * + * @param g The cooperative group that will acquire the lock + * @param l Pointer to the lock to be acquired + */ +template +__device__ void acquire_lock(CG const& g, int* l) +{ + if (g.thread_rank() == 0) { + while (atomicCAS(l, 0, 1) != 0) + ; + } + __threadfence(); + g.sync(); +} + +/** + * Releases lock l for the current thread block + * + * @param g The cooperative group that will release the lock + * @param l Pointer to the lock to be released + */ +template +__device__ void release_lock(CG const& g, int* l) +{ + if (g.thread_rank() == 0) { atomicExch(l, 0); } +} + +/** + * Copy pairs from src to dst + * + * @param g The cooperative group that will perform the copy + * @param dst_start Iterator to the beginning of the destination array + * @param src_start Iterator to the beginning of the source array + * @param src_end Iterator to the end of the source array + */ +template +__device__ void copy_pairs(CG const& g, InputIt1 dst_start, InputIt2 src_start, InputIt2 src_end) +{ + auto dst = dst_start + g.thread_rank(); + for (auto src = src_start + g.thread_rank(); src < 
src_end; dst += g.size(), src += g.size()) { + *dst = *src; + } +} + +/** + * Copy node_size pairs from src to dst + * + * @param g The cooperative group that will perform the copy + * @param dst_start Iterator to the beginning of the destination array + * @param src_start Iterator to the beginning of the source array + * @param num_pairs Number of pairs to copy + */ +template +__device__ void copy_pairs(CG const& g, + InputIt1 dst_start, + InputIt2 src_start, + std::size_t num_pairs) +{ + copy_pairs(g, dst_start, src_start, src_start + num_pairs); +} + +/** + * Merge arrays a and b of size node_size by key, putting the + * node_size elements with the lowest keys in lo, sorted by key, and the + * node_size elements with the highest keys in hi, sorted by key + * + * @param g The cooperative group that will perform the merge and sort + * @param a The first array of pairs to be merged, sorted by key + * @param b The second array of pairs to be merged, sorted by key + * @param lo The array in which the node_size elements with the lowest keys + * will be placed when the merge is completed + * @param hi The array in which the node_size elements with the highest keys + * will be placed when the merge is completed + * @param node_size The size of arrays a, b, lo, and hi + * @param shmem The shared memory layout for this cooperative group + * @param compare Comparison operator ordering the elements to be merged + */ +template +__device__ void merge_and_sort(CG const& g, + T* a, + T* b, + T* lo, + T* hi, + std::size_t node_size, + shared_memory_layout shmem, + Compare const& compare) +{ + merge_and_sort(g, a, b, lo, hi, node_size, node_size, node_size, shmem, compare); +} + +/** + * Merge array a of size num_elements_a and array b of size num_elements_b + * by key. If num_elements_a + num_elements_b <= node_size, all merged elements + * will be placed in lo. 
Otherwise, the node_size lowest merged elements will + * be placed in lo, and the rest of the elements will be placed in hi. + * + * @param g The cooperative group that will perform the merge and sort + * @param a The first array of pairs to be merged, sorted by key + * @param b The second array of pairs to be merged, sorted by key + * @param lo The array in which the node_size elements with the lowest keys + * will be placed when the merge is completed + * @param hi The array in which the node_size elements with the highest keys + * will be placed when the merge is completed, + * if num_elements_a + num_elements_b > node_size. May be nullptr in + * the case that num_elements_a + num_elements_b < node_size. + * @param num_elements_a The number of pairs in array a + * @param num_elements_b The number of pairs in array b + * @param node_size The size of arrays hi and lo, in other words how many + * elements to insert into lo before starting insertion into + * hi + * @param shmem The shared memory layout for this cooperative group + * @param compare Comparison operator ordering the elements to be merged + */ +template +__device__ void merge_and_sort(CG const& g, + T* a, + T* b, + T* lo, + T* hi, + std::size_t num_elements_a, + std::size_t num_elements_b, + std::size_t node_size, + shared_memory_layout shmem, + Compare const& compare) +{ + const int lane = g.thread_rank(); + const int dim = g.size(); + + if (num_elements_a == node_size && compare(a[node_size - 1], b[0])) { + copy_pairs(g, lo, a, num_elements_a); + + copy_pairs(g, hi, b, num_elements_b); + return; + } + + if (num_elements_b == node_size && compare(b[node_size - 1], a[0])) { + copy_pairs(g, hi, a, num_elements_a); + + copy_pairs(g, lo, b, num_elements_b); + return; + } + + // Array of size 2 * (blockDim.x + 1) + int* const intersections = shmem.intersections; + + if (lane == 0) { + intersections[0] = 0; + intersections[1] = 0; + + intersections[2 * dim] = node_size; + intersections[2 * dim + 1] = 
node_size; + } + + // Calculate the diagonal spacing + const int p = 2 * node_size / dim; + + // There will be one less diagonal than threads + if (threadIdx.x != 0) { + // i + j = (p * threadIdx.x - 1) + const int j_bl = min((int)node_size - 1, p * lane - 1); + const int i_bl = (p * lane - 1) - j_bl; + + const int diag_len = min(p * lane, (int)node_size - i_bl); + + // Will be the location of the rightmost one + // in the merge-path grid in terms of array a + int rightmost_one = i_bl - 1; + + // Location of leftmost zero + int leftmost_zero = i_bl + diag_len; + + // Binary search along the diagonal + while (leftmost_zero - rightmost_one > 1) { + const int i = (rightmost_one + leftmost_zero) / 2; + const int j = (p * lane - 1) - i; + + if (i >= num_elements_a) { + leftmost_zero = i; + } else if (j >= num_elements_b || compare(a[i], b[j])) { + rightmost_one = i; + } else { + leftmost_zero = i; + } + } + + intersections[2 * lane] = leftmost_zero; + intersections[2 * lane + 1] = (p * lane - 1) - leftmost_zero + 1; + } + + g.sync(); + + // Get the intersection that starts this partition + int i = intersections[2 * lane]; + int j = intersections[2 * lane + 1]; + + // Get the intersection that ends this partition + const int i_max = min(intersections[2 * (lane + 1)], (int)num_elements_a); + const int j_max = min(intersections[2 * (lane + 1) + 1], (int)num_elements_b); + + // Insert location into the output array + int ins_loc = lane * p; + + // Merge our partition into the output arrays + while (i < i_max && j < j_max) { + T next_element; + if (compare(a[i], b[j])) { + next_element = a[i]; + i++; + } else { + next_element = b[j]; + j++; + } + if (ins_loc < node_size) { + lo[ins_loc] = next_element; + } else { + hi[ins_loc - node_size] = next_element; + } + ins_loc++; + } + + // Insert the any remaining elements in a + while (i < i_max) { + if (ins_loc < node_size) { + lo[ins_loc] = a[i]; + i++; + } else { + hi[ins_loc - node_size] = a[i]; + i++; + } + ins_loc++; + } + + 
// Insert any remaining elements in b + while (j < j_max) { + if (ins_loc < node_size) { + lo[ins_loc] = b[j]; + j++; + } else { + hi[ins_loc - node_size] = b[j]; + j++; + } + ins_loc++; + } +} + +/** + * Sorts the len pairs at start by key + * + * @param g The cooperative group that will perform the sort + * @param start Pointer to the array to be sorted + * @param len Number of pairs to be sorted + * @param node_size A power of two corresponding to the number of pairs + * temp can contain + * @param temp A temporary array containing space for at least the nearest + * power of two greater than len pairs + * @param compare Comparison operator ordering the elements to be sorted + */ +template +__device__ void pb_sort( + CG const& g, T* start, std::size_t len, std::size_t node_size, T* temp, Compare const& compare) +{ + const int lane = g.thread_rank(); + const int dim = g.size(); + + char* const mask = (char*)temp; + + for (int i = lane; i < node_size; i += dim) { + mask[i] = i < len; + } + g.sync(); + + // Build a bitonic sequence + for (int width = 2; width < node_size; width *= 2) { + for (int jump = width / 2; jump >= 1; jump /= 2) { + for (int i = lane; i < node_size / 2; i += dim) { + const int start_jump = width / 2; + const int left = (i / jump) * jump * 2 + i % jump; + const int right = left + jump; + if ((i / start_jump) % 2 == 0) { + if (!mask[left] || (mask[right] && !compare(start[left], start[right]))) { + auto temp = start[left]; + start[left] = start[right]; + start[right] = temp; + + auto temp_mask = mask[left]; + mask[left] = mask[right]; + mask[right] = temp_mask; + } + } else { + if (!mask[right] || (mask[left] && compare(start[left], start[right]))) { + auto temp = start[left]; + start[left] = start[right]; + start[right] = temp; + + auto temp_mask = mask[left]; + mask[left] = mask[right]; + mask[right] = temp_mask; + } + } + } + g.sync(); + } + } + + // Merge to get the sorted result + for (int jump = node_size / 2; jump >= 1; jump /= 2) { + 
for (int i = lane; i < node_size / 2; i += dim) { + const int left = (i / jump) * jump * 2 + i % jump; + const int right = left + jump; + if (!mask[left] || (mask[right] && !compare(start[left], start[right]))) { + auto temp = start[left]; + start[left] = start[right]; + start[right] = temp; + + auto temp_mask = mask[left]; + mask[left] = mask[right]; + mask[right] = temp_mask; + } + } + g.sync(); + } +} + +/** + * Reverses the bits after the most significant set bit in x + * i.e. if x is 1abc..xyz in binary returns 1zyx...cba + * + * @param x The number whose lower bits will be reversed + * @return The number with all bits after the most significant + * set bit reversed + */ +__device__ int bit_reverse_perm(int x) +{ + const int clz = __clz(x); + + const int bits = sizeof(int) * 8; + const int high_bit = 1 << ((bits - 1) - clz); + const int mask = high_bit - 1; + + const int masked = x & mask; + const int rev = __brev(masked) >> (clz + 1); + + return high_bit | rev; +} + +/** + * Given x, the idx of a node, return when that node is inserted, + * i.e. if x is 6 and lowest_level_start > 6, return 5 since the node + * at element 6 will be the 5th to be inserted with the bit reversal + * permutation. This operation is its own inverse. 
+ * + * @param x The index to operate on + * @param lowest_level_start Index of the first node in the last level of the + * heap + */ +__device__ int insertion_order_index(int x, int lowest_level_start) +{ + assert(x > 0); + + if (x >= lowest_level_start) { return x; } + + return bit_reverse_perm(x); +} + +/** + * Find the index of the parent of the node at index x + * + * @param x The index to operate on + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @return The index of the parent of x + */ +__device__ int parent(int x, int lowest_level_start) +{ + assert(x > 0); + if (x >= lowest_level_start) { return bit_reverse_perm(x) / 2; } + + return x / 2; +} + +/** + * Find the index of the left child of the node at index x + * + * @param x The index to operate on + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @return The index of the left child of x + */ +__device__ int left_child(int x, int lowest_level_start) +{ + assert(x > 0); + int result = x * 2; + + if (result >= lowest_level_start) { result = bit_reverse_perm(result); } + + return result; +} + +/** + * Find the index of the right child of the node at index x + * + * @param x The index to operate on + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @return The index of the right child of x + */ +__device__ int right_child(int x, int lowest_level_start) +{ + assert(x > 0); + int result = x * 2 + 1; + + if (result >= lowest_level_start) { result = bit_reverse_perm(result); } + + return result; +} + +/** + * swim node cur_node up the heap + * Pre: g must hold the lock corresponding to cur_node + * + * @param g The cooperative group that will perform the operation + * @param cur_node Index of the node to swim + * @param heap The array of pairs that stores the heap itself + * @param size Pointer to the number of pairs currently in the heap + * @param node_size Size of the nodes in the heap 
+ * @param locks Array of locks, one for each node in the heap + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @param shmem The shared memory layout for this cooperative group + * @param compare Comparison operator ordering the elements in the heap + */ +template +__device__ void swim(CG const& g, + int cur_node, + T* heap, + int* size, + std::size_t node_size, + int* locks, + int lowest_level_start, + shared_memory_layout shmem, + Compare const& compare) +{ + const int lane = g.thread_rank(); + const int dim = g.size(); + + int cur_parent = parent(cur_node, lowest_level_start); + + // swim the new node up the tree + while (cur_node != 1) { + acquire_lock(g, &(locks[cur_parent])); + + // If the heap property is already satisfied for this node and its + // parent we are done + if (!compare(heap[cur_node * node_size], heap[cur_parent * node_size + node_size - 1])) { + release_lock(g, &(locks[cur_parent])); + break; + } + + merge_and_sort(g, + &heap[cur_parent * node_size], + &heap[cur_node * node_size], + shmem.a, + shmem.b, + node_size, + shmem, + compare); + + g.sync(); + + copy_pairs(g, &heap[cur_parent * node_size], shmem.a, node_size); + copy_pairs(g, &heap[cur_node * node_size], shmem.b, node_size); + + g.sync(); + + release_lock(g, &(locks[cur_node])); + cur_node = cur_parent; + cur_parent = parent(cur_node, lowest_level_start); + } + + release_lock(g, &(locks[cur_node])); +} + +/** + * sink the root down the heap + * Pre: g must hold the root's lock + * + * @param g The cooperative group that will perform the operation + * @param heap The array of pairs that stores the heap itself + * @param size Pointer to the number of pairs currently in the heap + * @param node_size Size of the nodes in the heap + * @param locks Array of locks, one for each node in the heap + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @param node_capacity Max capacity of the heap in nodes + * @param 
shmem The shared memory layout for this cooperative group + * @param compare Comparison operator ordering the elements in the heap + */ +template +__device__ void sink(CG const& g, + T* heap, + int* size, + std::size_t node_size, + int* locks, + std::size_t* p_buffer_size, + int lowest_level_start, + int node_capacity, + shared_memory_layout shmem, + Compare const& compare) +{ + std::size_t cur = kRootIdx; + + const int dim = g.size(); + + // sink the node + while (insertion_order_index(left_child(cur, lowest_level_start), lowest_level_start) <= + node_capacity) { + const std::size_t left = left_child(cur, lowest_level_start); + const std::size_t right = right_child(cur, lowest_level_start); + + acquire_lock(g, &locks[left]); + + // The left node might have been removed + // since the while loop condition, in which + // case we are already at the bottom of the heap + if (insertion_order_index(left, lowest_level_start) > *size) { + release_lock(g, &locks[left]); + break; + } + + std::size_t lo; + + if (insertion_order_index(right, lowest_level_start) <= node_capacity) { + acquire_lock(g, &locks[right]); + + // Note that even with the bit reversal permutation, + // we can never have a right child without a left child + // + // If we have both children, merge and sort them + if (insertion_order_index(right, lowest_level_start) <= *size) { + std::size_t hi; + + // In order to ensure we preserve the heap property, + // we put the largest node_size elements in the child + // that previously contained the largest element + if (!compare(heap[(left + 1) * node_size - 1], heap[(right + 1) * node_size - 1])) { + hi = left; + lo = right; + } else { + lo = left; + hi = right; + } + + // Skip the merge and sort if the nodes are already correctly + // sorted + if (!compare(heap[(lo + 1) * node_size - 1], heap[hi * node_size])) { + merge_and_sort(g, + &heap[left * node_size], + &heap[right * node_size], + shmem.a, + shmem.b, + node_size, + shmem, + compare); + + g.sync(); + + 
copy_pairs(g, &heap[hi * node_size], shmem.b, node_size); + copy_pairs(g, &heap[lo * node_size], shmem.a, node_size); + + g.sync(); + } + release_lock(g, &locks[hi]); + } else { + lo = left; + release_lock(g, &locks[right]); + } + } else { + lo = left; + } + + merge_and_sort(g, + &heap[lo * node_size], + &heap[cur * node_size], + shmem.a, + shmem.b, + node_size, + shmem, + compare); + + g.sync(); + + copy_pairs(g, &heap[lo * node_size], shmem.b, node_size); + copy_pairs(g, &heap[cur * node_size], shmem.a, node_size); + + g.sync(); + + release_lock(g, &locks[cur]); + + cur = lo; + } + release_lock(g, &locks[cur]); +} + +/** + * Add exactly node_size elements into the heap from + * elements + * + * @param g The cooperative group that will perform the push + * @param elements Iterator for the elements to be inserted + * @param heap The array of pairs that stores the heap itself + * @param size Pointer to the number of pairs currently in the heap + * @param node_size Size of the nodes in the heap + * @param locks Array of locks, one for each node in the heap + * @param lowest_level_start Index of the first node in the last level of the + * heap + * @param shmem The shared memory layout for this cooperative group + * @param compare Comparison operator ordering the elements in the heap + */ +template +__device__ void push_single_node(CG const& g, + InputIt elements, + T* heap, + int* size, + std::size_t node_size, + int* locks, + int lowest_level_start, + shared_memory_layout shmem, + Compare const& compare) +{ + const int lane = g.thread_rank(); + const int dim = g.size(); + + copy_pairs(g, shmem.a, elements, elements + node_size); + + g.sync(); + + pb_sort(g, shmem.a, node_size, node_size, shmem.b, compare); + + int* const cur_node_temp = (int*)shmem.intersections; + if (lane == 0) { *cur_node_temp = atomicAdd(size, 1) + 1; } + g.sync(); + + const int cur_node = insertion_order_index(*cur_node_temp, lowest_level_start); + + acquire_lock(g, &(locks[cur_node])); + + 
// NOTE(review): the reviewed chunk begins inside push_single_node; the statements up to the
// closing brace below are the unchanged tail of that function.
  copy_pairs(g, &heap[cur_node * node_size], shmem.a, node_size);

  g.sync();

  swim(g, cur_node, heap, size, node_size, locks, lowest_level_start, shmem, compare);
}

// NOTE(review): the angle-bracketed template parameter/argument lists in this section were
// missing from the reviewed text and are reconstructed from how each name is used
// (e.g. `shared_memory_layout<T>`, `get_shared_memory_layout<T>`); confirm against the
// original header before merging.

/**
 * Remove exactly node_size elements from the heap and place them
 * in elements
 *
 * If the heap currently holds no node (`*size == 0`), node_size elements are
 * copied from the start of `heap` (the partial buffer slot), the partial
 * buffer is marked empty, and the root lock is released before returning.
 *
 * @param g The cooperative group that will perform the pop
 * @param elements Iterator to the elements to write to
 * @param heap The array of pairs that stores the heap itself
 * @param size Pointer to the number of nodes currently in the heap
 * @param node_size Size of the nodes in the heap
 * @param locks Array of locks, one for each node in the heap
 * @param p_buffer_size Number of pairs in the heap's partial buffer
 * @param lowest_level_start Index of the first node in the last level of the
 *                           heap
 * @param node_capacity Maximum capacity of the heap in nodes
 * @param shmem The shared memory layout for this cooperative group
 * @param compare Comparison operator ordering the elements in the heap
 */
template <typename OutputIt, typename T, typename Compare, typename CG>
__device__ void pop_single_node(CG const& g,
                                OutputIt elements,
                                T* heap,
                                int* size,
                                std::size_t node_size,
                                int* locks,
                                std::size_t* p_buffer_size,
                                int lowest_level_start,
                                int node_capacity,
                                shared_memory_layout<T> shmem,
                                Compare const& compare)
{
  const int lane = g.thread_rank();

  acquire_lock(g, &locks[kRootIdx]);
  if (*size == 0) {
    copy_pairs(g, elements, heap, node_size);

    if (lane == 0) { *p_buffer_size = 0; }
    g.sync();

    // The root lock was taken above; it must be dropped on this early exit as
    // well, otherwise the next operation on the queue can never acquire it.
    release_lock(g, &locks[kRootIdx]);
    return;
  }

  // Find the target node (the last one inserted) and
  // decrement the size

  const std::size_t tar = insertion_order_index(*size, lowest_level_start);

  if (tar != kRootIdx) { acquire_lock(g, &locks[tar]); }

  g.sync();

  if (lane == 0) { *size -= 1; }
  g.sync();

  // Copy the root to the output array

  copy_pairs(g, elements, &heap[node_size], &heap[node_size] + node_size);

  g.sync();

  // Copy the target node to the root

  if (tar != kRootIdx) {
    copy_pairs(g, &heap[node_size], &heap[tar * node_size], node_size);

    release_lock(g, &locks[tar]);

    g.sync();
  }

  // Merge and sort the root and the partial buffer

  merge_and_sort(g,
                 &heap[node_size],
                 &heap[kPBufferIdx],
                 shmem.a,
                 shmem.b,
                 node_size,
                 *p_buffer_size,
                 node_size,
                 shmem,
                 compare);

  g.sync();

  copy_pairs(g, &heap[node_size], shmem.a, node_size);

  copy_pairs(g, heap, shmem.b, *p_buffer_size);

  g.sync();

  sink(g,
       heap,
       size,
       node_size,
       locks,
       p_buffer_size,
       lowest_level_start,
       node_capacity,
       shmem,
       compare);
}

/**
 * Remove num_elements < node_size elements from the heap and place them
 * in elements
 *
 * @param g The cooperative group that will perform the pop
 * @param elements Iterator to the elements to write to
 * @param num_elements The number of elements to remove
 * @param heap The array of pairs that stores the heap itself
 * @param size Pointer to the number of nodes currently in the heap
 * @param node_size Size of the nodes in the heap
 * @param locks Array of locks, one for each node in the heap
 * @param p_buffer_size Number of pairs in the heap's partial buffer
 * @param lowest_level_start Index of the first node in the last level of the
 *                           heap
 * @param node_capacity Maximum capacity of the heap in nodes
 * @param shmem The shared memory layout for this cooperative group
 * @param compare Comparison operator ordering the elements in the heap
 */
template <typename OutputIt, typename T, typename Compare, typename CG>
__device__ void pop_partial_node(CG const& g,
                                 OutputIt elements,
                                 std::size_t num_elements,
                                 T* heap,
                                 int* size,
                                 std::size_t node_size,
                                 int* locks,
                                 std::size_t* p_buffer_size,
                                 int lowest_level_start,
                                 int node_capacity,
                                 shared_memory_layout<T> shmem,
                                 Compare const& compare)
{
  const int lane = g.thread_rank();

  acquire_lock(g, &locks[kRootIdx]);

  if (*size == 0) {
    // No full node exists: serve the request from the partial buffer and
    // shift the remaining buffer contents to the front.
    copy_pairs(g, elements, heap, num_elements);
    g.sync();

    const std::size_t n_p_buffer_size = *p_buffer_size - num_elements;

    copy_pairs(g, shmem.a, heap + num_elements, n_p_buffer_size);

    g.sync();

    copy_pairs(g, heap, shmem.a, n_p_buffer_size);

    if (lane == 0) { *p_buffer_size = n_p_buffer_size; }

    // Every other release in this file is preceded by a group barrier so that
    // all lanes have finished writing the protected data; do the same here.
    g.sync();

    release_lock(g, &locks[kRootIdx]);
  } else {
    copy_pairs(g, elements, &heap[kRootIdx * node_size], num_elements);
    g.sync();

    if (*p_buffer_size >= num_elements) {
      // The partial buffer can refill the root: merge what is left of the
      // root with the buffer, keep a full node as root and the rest as buffer.
      merge_and_sort(g,
                     &heap[kPBufferIdx],
                     &heap[kRootIdx * node_size] + num_elements,
                     shmem.a,
                     shmem.b,
                     *p_buffer_size,
                     node_size - num_elements,
                     node_size,
                     shmem,
                     compare);

      g.sync();

      if (lane == 0) { *p_buffer_size = *p_buffer_size - num_elements; }

      g.sync();

      copy_pairs(g, &heap[kRootIdx * node_size], shmem.a, node_size);
      copy_pairs(g, &heap[kPBufferIdx], shmem.b, *p_buffer_size);

      g.sync();

      sink(g,
           heap,
           size,
           node_size,
           locks,
           p_buffer_size,
           lowest_level_start,
           node_capacity,
           shmem,
           compare);
    } else {
      // The partial buffer cannot refill the root: fold the rest of the root
      // into the buffer and replace the root with the last inserted node.
      merge_and_sort(g,
                     &heap[kPBufferIdx],
                     &heap[kRootIdx * node_size] + num_elements,
                     shmem.a,
                     (T*)nullptr,
                     *p_buffer_size,
                     node_size - num_elements,
                     node_size,
                     shmem,
                     compare);

      g.sync();

      copy_pairs(g, &heap[kPBufferIdx], shmem.a, *p_buffer_size + node_size - num_elements);

      const int tar = insertion_order_index(*size, lowest_level_start);
      g.sync();

      // Only one lane may perform this read-modify-write; previously every
      // thread of the group executed `+=` / `-=` on the same global counter.
      if (lane == 0) { *p_buffer_size = *p_buffer_size + node_size - num_elements; }

      g.sync();

      if (lane == 0) { *size -= 1; }

      if (tar != kRootIdx) {
        acquire_lock(g, &locks[tar]);

        copy_pairs(g, &heap[kRootIdx * node_size], &heap[tar * node_size], node_size);

        g.sync();

        release_lock(g, &locks[tar]);

        merge_and_sort(g,
                       &heap[node_size],
                       &heap[kPBufferIdx],
                       shmem.a,
                       shmem.b,
                       node_size,
                       *p_buffer_size,
                       node_size,
                       shmem,
                       compare);
        g.sync();

        copy_pairs(g, &heap[node_size], shmem.a, node_size);

        copy_pairs(g, heap, shmem.b, *p_buffer_size);

        g.sync();

        sink(g,
             heap,
             size,
             node_size,
             locks,
             p_buffer_size,
             lowest_level_start,
             node_capacity,
             shmem,
             compare);
      } else {
        release_lock(g, &locks[kRootIdx]);
      }
    }
  }
}

/**
 * Add p_ins_size < node_size elements into the heap from
 * elements
 *
 * @param g The cooperative group that will perform the push
 * @param elements The array of elements to add
 * @param p_ins_size The number of elements to be inserted
 * @param heap The array of pairs that stores the heap itself
 * @param size Pointer to the number of nodes currently in the heap
 * @param node_size Size of the nodes in the heap
 * @param locks Array of locks, one for each node in the heap
 * @param p_buffer_size The size of the partial buffer
 * @param lowest_level_start Index of the first node in the last level of the
 *                           heap
 * @param shmem The shared memory layout for this cooperative group
 * @param compare Comparison operator ordering the elements in the heap
 */
template <typename InputIt, typename T, typename Compare, typename CG>
__device__ void push_partial_node(CG const& g,
                                  InputIt elements,
                                  std::size_t p_ins_size,
                                  T* heap,
                                  int* size,
                                  std::size_t node_size,
                                  int* locks,
                                  std::size_t* p_buffer_size,
                                  int lowest_level_start,
                                  shared_memory_layout<T> shmem,
                                  Compare const& compare)
{
  const int lane = g.thread_rank();

  acquire_lock(g, &locks[kRootIdx]);

  copy_pairs(g, shmem.b, elements, p_ins_size);

  pb_sort(g, shmem.b, p_ins_size, node_size, shmem.a, compare);

  // There is enough data for a new node, in which case we
  // construct a new node and insert it
  if (*p_buffer_size + p_ins_size >= node_size) {
    // Lane 0 reserves the new node and broadcasts its ordinal through shared
    // memory so that every lane computes the same heap index.
    int* const cur_node_temp = shmem.intersections;
    if (lane == 0) { *cur_node_temp = atomicAdd(size, 1) + 1; }
    g.sync();

    const int cur_node = insertion_order_index(*cur_node_temp, lowest_level_start);

    if (cur_node != kRootIdx) { acquire_lock(g, &(locks[cur_node])); }

    g.sync();

    merge_and_sort(g,
                   shmem.b,
                   &heap[kPBufferIdx],
                   &heap[cur_node * node_size],
                   shmem.a,
                   p_ins_size,
                   *p_buffer_size,
                   node_size,
                   shmem,
                   compare);

    // Same barrier-before-update pattern as the other branch below.
    g.sync();

    if (lane == 0) { *p_buffer_size = (*p_buffer_size + p_ins_size) - node_size; }

    g.sync();

    copy_pairs(g, heap, shmem.a, *p_buffer_size);

    // All lanes must have finished rewriting the partial buffer before the
    // root lock that protects it is released.
    g.sync();

    if (cur_node != kRootIdx) { release_lock(g, &locks[kRootIdx]); }

    swim(g, cur_node, heap, size, node_size, locks, lowest_level_start, shmem, compare);

  } else {
    // There are not enough elements for a new node,
    // in which case we merge and sort the root and
    // the elements to be inserted and then the root
    // and the partial buffer

    merge_and_sort(g,
                   shmem.b,
                   &heap[kPBufferIdx],
                   shmem.a,
                   (T*)nullptr,
                   p_ins_size,
                   *p_buffer_size,
                   node_size,
                   shmem,
                   compare);

    g.sync();

    if (lane == 0) { *p_buffer_size += p_ins_size; }

    g.sync();

    copy_pairs(g, heap, shmem.a, *p_buffer_size);

    g.sync();

    if (*size > 0) {
      merge_and_sort(g,
                     &heap[node_size],
                     &heap[kPBufferIdx],
                     shmem.a,
                     shmem.b,
                     node_size,
                     *p_buffer_size,
                     node_size,
                     shmem,
                     compare);
      g.sync();

      copy_pairs(g, heap, shmem.b, *p_buffer_size);

      copy_pairs(g, &heap[node_size], shmem.a, node_size);

      g.sync();
    }
    release_lock(g, &locks[kRootIdx]);
  }
}

/**
 * Add num_elements elements into the heap from
 * elements
 *
 * Each thread block pushes whole nodes of node_size elements, striding over
 * the input by gridDim.x nodes; block 0 then inserts the remaining
 * num_elements % node_size elements as a partial node. The kernel uses
 * dynamically sized shared memory (`extern __shared__`).
 *
 * @param elements The array of elements to add
 * @param num_elements The number of elements to be inserted
 * @param heap The array of pairs that stores the heap itself
 * @param size Pointer to the number of nodes currently in the heap
 * @param node_size Size of the nodes in the heap
 * @param locks Array of locks, one for each node in the heap
 * @param p_buffer_size Number of pairs in the heap's partial buffer
 * @param lowest_level_start The first index of the heaps lowest layer
 * @param compare Comparison operator ordering the elements in the heap
 */
template <typename InputIt, typename T, typename Compare>
__global__ void push_kernel(InputIt elements,
                            std::size_t num_elements,
                            T* heap,
                            int* size,
                            std::size_t node_size,
                            int* locks,
                            std::size_t* p_buffer_size,
                            int lowest_level_start,
                            Compare compare)
{
  extern __shared__ int s[];

  const shared_memory_layout<T> shmem = get_shared_memory_layout<T>(s, blockDim.x, node_size);

  // We push as many elements as possible as full nodes,
  // then deal with the remaining elements as a partial insertion
  // below
  cg::thread_block g = cg::this_thread_block();
  for (std::size_t i = blockIdx.x * node_size; i + node_size <= num_elements;
       i += gridDim.x * node_size) {
    push_single_node(
      g, elements + i, heap, size, node_size, locks, lowest_level_start, shmem, compare);
  }

  // We only need one block for partial insertion
  if (blockIdx.x != 0) { return; }

  // If node_size does not divide num_elements, there are some leftover
  // elements for which we must perform a partial insertion
  const std::size_t first_not_inserted = (num_elements / node_size) * node_size;

  if (first_not_inserted < num_elements) {
    const std::size_t p_ins_size = num_elements - first_not_inserted;
    push_partial_node(g,
                      elements + first_not_inserted,
                      p_ins_size,
                      heap,
                      size,
                      node_size,
                      locks,
                      p_buffer_size,
                      lowest_level_start,
                      shmem,
                      compare);
  }
}

/**
 * Remove num_elements elements from the heap and place them
 * in elements
 *
 * Each thread block pops whole nodes of node_size elements, striding over
 * the output by gridDim.x nodes; block 0 then removes the remaining
 * num_elements % node_size elements as a partial node. The kernel uses
 * dynamically sized shared memory (`extern __shared__`).
 *
 * @param elements The array of elements to write the removed elements to
 * @param num_elements The number of elements to remove
 * @param heap The array of pairs that stores the heap itself
 * @param size Pointer to the number of nodes currently in the heap
 * @param node_size Size of the nodes in the heap
 * @param locks Array of locks, one for each node in the heap
 * @param p_buffer_size Number of pairs in the heap's partial buffer
 * @param lowest_level_start The first index of the heaps lowest layer
 * @param node_capacity The capacity of the heap in nodes
 * @param compare Comparison operator ordering the elements in the heap
 */
template <typename OutputIt, typename T, typename Compare>
__global__ void pop_kernel(OutputIt elements,
                           std::size_t num_elements,
                           T* heap,
                           int* size,
                           std::size_t node_size,
                           int* locks,
                           std::size_t* p_buffer_size,
                           int lowest_level_start,
                           int node_capacity,
                           Compare compare)
{
  extern __shared__ int s[];

  const shared_memory_layout<T> shmem = get_shared_memory_layout<T>(s, blockDim.x, node_size);

  cg::thread_block g = cg::this_thread_block();
  for (std::size_t i = blockIdx.x; i < num_elements / node_size; i += gridDim.x) {
    pop_single_node(g,
                    elements + i * node_size,
                    heap,
                    size,
                    node_size,
                    locks,
                    p_buffer_size,
                    lowest_level_start,
                    node_capacity,
                    shmem,
                    compare);
  }

  // We only need one block for partial deletion
  if (blockIdx.x != 0) { return; }

  // If node_size does not divide num_elements, there are some leftover
  // elements for which we must perform a partial deletion
  const std::size_t first_not_removed = (num_elements / node_size) * node_size;

  if (first_not_removed < num_elements) {
    const std::size_t p_del_size = num_elements - first_not_removed;
    pop_partial_node(g,
                     elements + first_not_removed,
                     p_del_size,
                     heap,
                     size,
                     node_size,
                     locks,
                     p_buffer_size,
                     lowest_level_start,
                     node_capacity,
                     shmem,
                     compare);
  }
}

}  // namespace detail

}  // namespace cuco
diff --git a/include/cuco/priority_queue.cuh b/include/cuco/priority_queue.cuh
new file mode 100644
index 000000000..a7c0d3a1a
--- /dev/null
+++ b/include/cuco/priority_queue.cuh
@@ -0,0 +1,266 @@
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include +#include + +namespace cuco { + +/** + * @brief A GPU-accelerated priority queue of elements of type T + * + * Allows for multiple concurrent insertions as well as multiple concurrent + * deletions + * + * Current limitations: + * - Does not support insertion and deletion at the same time + * - The implementation of the priority queue is based on + * https://arxiv.org/pdf/1906.06504.pdf, which provides a way to allow + * concurrent insertion and deletion, so this could be added later if useful + * - Capacity is fixed and the queue does not automatically resize + * - Deletion from the queue is much slower than insertion into the queue + * due to congestion at the underlying heap's root node + * + * The queue supports two operations: + * `push`: Add elements into the queue + * `pop`: Remove the element(s) that are ordered first by `Compare` + * (e.g. the lowest elements for a less-than comparison) + * + * The priority queue supports bulk host-side operations and more fine-grained + * device-side operations. + * + * The host-side bulk operations `push` and `pop` allow an arbitrary number of + * elements to be pushed to or popped from the queue. + * + * The device-side operations allow a cooperative group to push or pop from + * device code. These device-side + * operations are invoked with a trivially-copyable device view, + * `device_mutable_view`, which can be obtained with the host function + * `get_mutable_device_view` and passed to the device.
+ * + * @tparam T Type of the elements stored in the queue + * @tparam Compare Comparison operator used to order the elements in the queue + * @tparam Allocator Allocator defining how memory is allocated internally + */ +template , + typename Allocator = cuco::cuda_allocator> +class priority_queue { + using int_allocator_type = typename std::allocator_traits::rebind_alloc; + + using t_allocator_type = typename std::allocator_traits::rebind_alloc; + + using size_t_allocator_type = typename std::allocator_traits::rebind_alloc; + + public: + /** + * @brief Construct a priority queue + * + * @param initial_capacity The number of elements the priority queue can hold + * @param alloc Allocator used for allocating device storage + * @param stream Stream used for constructing the priority queue + */ + priority_queue(std::size_t initial_capacity, + Allocator const& alloc = Allocator{}, + cudaStream_t stream = 0); + + /** + * @brief Push elements into the priority queue + * + * @tparam InputIt Device accessible input iterator whose `value_type` + * can be converted to T + * @param first Beginning of the sequence of elements + * @param last End of the sequence of elements + * @param stream The stream in which the underlying device operations will be + * executed + */ + template + void push(InputIt first, InputIt last, cudaStream_t stream = 0); + + /** + * @brief Remove a sequence of the lowest elements ordered by Compare + * + * @tparam OutputIt Device accessible output iterator whose `value_type` + * can be converted to T + * @param first Beginning of the sequence of output elements + * @param last End of the sequence of output elements + * @param stream The stream in which the underlying GPU operations will be + * run + */ + template + void pop(OutputIt first, OutputIt last, cudaStream_t stream = 0); + + /* + * @brief Return the amount of shared memory required for operations + * on the queue with a thread block size of block_size + * + * @param block_size Size of the 
blocks to calculate storage for + * @return The amount of temporary storage required in bytes + */ + int get_shmem_size(int const block_size) const + { + int intersection_bytes = 2 * (block_size + 1) * sizeof(int); + int node_bytes = node_size_ * sizeof(T); + return intersection_bytes + 2 * node_bytes; + } + + /** + * @brief Destroys the queue and frees its contents + */ + ~priority_queue(); + + class device_mutable_view { + public: + using value_type = T; + /** + * @brief Push elements into the priority queue + * + * @tparam CG Cooperative Group type + * @tparam InputIt Device accessible iterator whose `value_type` + * is convertible to T + * @param g The cooperative group that will perform the operation + * @param first The beginning of the sequence of elements to insert + * @param last The end of the sequence of elements to insert + * @param temp_storage Pointer to a contiguous section of memory + * large enough to hold get_shmem_size(g.size()) bytes + */ + template + __device__ void push(CG const& g, InputIt first, InputIt last, void* temp_storage); + + /** + * @brief Pop elements from the priority queue + * + * @tparam CG Cooperative Group type + * @tparam OutputIt Device accessible iterator whose `value_type` + * is convertible to T + * @param g The cooperative group that will perform the operation + * @param first The beginning of the sequence of elements to output into + * @param last The end of the sequence of elements to output into + * @param temp_storage Pointer to a contiguous section of memory + * large enough to hold get_shmem_size(g.size()) bytes + */ + template + __device__ void pop(CG const& g, OutputIt first, OutputIt last, void* temp_storage); + + /* + * @brief Return the amount of temporary storage required for operations + * on the queue with a cooperative group size of block_size + * + * @param block_size Size of the cooperative groups to calculate storage for + * @return The amount of temporary storage required in bytes + */ + __device__ int 
get_shmem_size(int block_size) const + { + int intersection_bytes = 2 * (block_size + 1) * sizeof(int); + int node_bytes = node_size_ * sizeof(T); + return intersection_bytes + 2 * node_bytes; + } + + __host__ __device__ device_mutable_view(std::size_t node_size, + T* d_heap, + int* d_size, + std::size_t* d_p_buffer_size, + int* d_locks, + int lowest_level_start, + int node_capacity, + Compare const& compare) + : node_size_(node_size), + d_heap_(d_heap), + d_size_(d_size), + d_p_buffer_size_(d_p_buffer_size), + d_locks_(d_locks), + lowest_level_start_(lowest_level_start), + node_capacity_(node_capacity), + compare_(compare) + { + } + + private: + std::size_t node_size_; ///< Size of the heap's nodes (i.e. number of T's + /// in each node) + int lowest_level_start_; ///< Index in `d_heap_` of the first node in the + /// heap's lowest level + int node_capacity_; ///< Capacity of the heap in nodes + + T* d_heap_; ///< Pointer to an array of nodes, the 0th node + /// being the heap's partial buffer, and nodes + /// 1..(node_capacity_) being the heap, where + /// the 1st node is the root + int* d_size_; ///< Number of nodes currently in the heap + std::size_t* d_p_buffer_size_; ///< Number of elements currently in the + /// partial buffer + int* d_locks_; ///< Array of locks where `d_locks_[i]` is the + /// lock for the node starting at + /// d_heap_[node_size * i]` + Compare compare_{}; ///< Comparator used to order the elements in the queue + }; + + /* + * @brief Returns a trivially-copyable class that can be used to perform + * insertion and deletion of single nodes in device code with + * cooperative groups + * + * @return A device view + */ + device_mutable_view get_mutable_device_view() const noexcept + { + return device_mutable_view(node_size_, + d_heap_, + d_size_, + d_p_buffer_size_, + d_locks_, + lowest_level_start_, + node_capacity_, + compare_); + } + + private: + std::size_t node_size_; ///< Size of the heap's nodes (i.e. 
number of T's + /// in each node) + int lowest_level_start_; ///< Index in `d_heap_` of the first node in the + /// heap's lowest level + int node_capacity_; ///< Capacity of the heap in nodes + + T* d_heap_; ///< Pointer to an array of nodes, the 0th node + /// being the heap's partial buffer, and nodes + /// 1..(node_capacity_) being the heap, where the + /// 1st node is the root + int* d_size_; ///< Number of nodes currently in the heap + std::size_t* d_p_buffer_size_; ///< Number of elements currently in the + /// partial buffer + int* d_locks_; ///< Array of locks where `d_locks_[i]` is the + /// lock for the node starting at + /// d_heap_[node_size * i]` + + int_allocator_type int_allocator_; ///< Allocator used to allocated ints + /// for example, the lock array + t_allocator_type t_allocator_; ///< Allocator used to allocate T's + /// and therefore nodes + size_t_allocator_type size_t_allocator_; ///< Allocator used to allocate + /// size_t's, e.g. d_p_buffer_size_ + + Compare compare_{}; ///< Comparator used to order the elements in the queue +}; + +} // namespace cuco + +#include diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e9d256ce1..b5d417837 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -67,6 +67,11 @@ foreach(arch IN LISTS CMAKE_CUDA_ARCHITECTURES) endif() endforeach() +################################################################################################### +# - priority_queue tests -------------------------------------------------------------------------- +ConfigureTest(PRIORITY_QUEUE_TEST + priority_queue/priority_queue_test.cu) + ################################################################################################### # - dynamic_map tests ----------------------------------------------------------------------------- ConfigureTest(DYNAMIC_MAP_TEST @@ -81,3 +86,4 @@ ConfigureTest(STATIC_MULTIMAP_TEST static_multimap/multiplicity_test.cu static_multimap/non_match_test.cu 
static_multimap/pair_function_test.cu)
+
diff --git a/tests/priority_queue/priority_queue_test.cu b/tests/priority_queue/priority_queue_test.cu
new file mode 100644
index 000000000..84d3353bc
--- /dev/null
+++ b/tests/priority_queue/priority_queue_test.cu
@@ -0,0 +1,413 @@
/*
 * Copyright (c) 2021-2022, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

// NOTE(review): the include targets and every angle-bracketed template list in this file were
// missing from the reviewed text. The includes below are the headers the visible code needs;
// template lists are reconstructed from usage. Spots where the text does not determine the
// original type are marked individually - confirm against the original test file.
#include <cuco/priority_queue.cuh>

#include <thrust/device_vector.h>
#include <thrust/functional.h>
#include <thrust/host_vector.h>

#include <catch2/catch.hpp>

#include <cooperative_groups.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <map>
#include <random>
#include <vector>

using namespace cuco;
namespace cg = cooperative_groups;

/// Simple key-value aggregate used to exercise the queue with non-scalar elements.
template <typename K, typename V>
struct kv_pair {
  K first;
  V second;
};

template <typename K, typename V>
bool __host__ __device__ operator==(const kv_pair<K, V>& a, const kv_pair<K, V>& b)
{
  return a.first == b.first && a.second == b.second;
}

/// Lexicographic order (key, then value); needed so kv_pair can be a std::map key.
template <typename K, typename V>
bool __host__ __device__ operator<(const kv_pair<K, V>& a, const kv_pair<K, V>& b)
{
  if (a.first == b.first) {
    return a.second < b.second;
  } else {
    return a.first < b.first;
  }
}

/// Orders pair-like elements by key only.
template <typename T>
struct kv_less {
  __host__ __device__ bool operator()(const T& a, const T& b) const { return a.first < b.first; }
};

/// Count how many times each distinct element occurs in `a`.
template <typename T>
std::map<T, size_t> construct_count_map(const std::vector<T>& a)
{
  std::map<T, size_t> result;

  // operator[] value-initializes a missing entry to 0, so one lookup suffices.
  for (const T& e : a) {
    ++result[e];
  }

  return result;
}

/**
 * Check that `top_n` is a valid result of popping top_n.size() elements from a
 * queue that contained `elements`, when ordered by `Compare`.
 */
template <typename T, typename Compare>
bool is_valid_top_n(std::vector<T> top_n, std::vector<T> elements)
{
  const auto top_n_map    = construct_count_map(top_n);
  const auto elements_map = construct_count_map(elements);

  // 1. Check that the count of each element in the top n is less than or
  //    equal to the count of that element overall in the queue
  for (const auto& pair : top_n_map) {
    const auto found = elements_map.find(pair.first);
    if (found == elements_map.end() || found->second < pair.second) { return false; }
  }

  // 2. Check that each element in the top N is not ordered
  //    after the ith element of the sorted list of elements
  std::sort(elements.begin(), elements.end(), Compare{});

  std::sort(top_n.begin(), top_n.end(), Compare{});

  for (std::size_t i = 0; i < top_n.size(); i++) {
    const T max = elements[i];
    const T e   = top_n[i];
    if (Compare{}(max, e)) { return false; }
  }

  return true;
}

template <typename T>
static void generate_element(T& e, std::mt19937& gen)
{
  e = static_cast<T>(gen());
}

template <typename K, typename V>
void generate_element(kv_pair<K, V>& e, std::mt19937& gen)
{
  generate_element(e.first, gen);
  generate_element(e.second, gen);
}

/// Generate `num_keys` pseudo-random elements seeded from std::random_device.
template <typename T>
static std::vector<T> generate_elements(size_t num_keys)
{
  std::random_device rd;
  std::mt19937 gen{rd()};

  std::vector<T> result(num_keys);

  for (auto& e : result) {
    generate_element(e, gen);
  }

  return result;
}

/// Copy `v` to the device, push it into `pq` and wait for completion.
template <typename T, typename Compare>
static void insert_to_queue(priority_queue<T, Compare>& pq, const std::vector<T>& v)
{
  const thrust::device_vector<T> d_v(v);

  pq.push(d_v.begin(), d_v.end());

  // Surface asynchronous kernel failures here instead of as a wrong result later.
  REQUIRE(cudaDeviceSynchronize() == cudaSuccess);
}

/// Pop `n` elements from `pq` and return them on the host.
template <typename T, typename Compare>
static std::vector<T> pop_from_queue(priority_queue<T, Compare>& pq, size_t n)
{
  thrust::device_vector<T> d_popped(n);

  pq.pop(d_popped.begin(), d_popped.end());

  REQUIRE(cudaDeviceSynchronize() == cudaSuccess);

  const thrust::host_vector<T> h_popped(d_popped);

  return std::vector<T>(h_popped.begin(), h_popped.end());
}

// Insert elements into the queue and check that they are
// all returned when removed from the queue
template <typename T, typename Compare>
bool test_insertion_and_deletion(priority_queue<T, Compare>& pq,
                                 const std::vector<T>& elements,
                                 size_t n)
{
  insert_to_queue(pq, elements);

  const auto popped_elements = pop_from_queue(pq, n);

  return is_valid_top_n<T, Compare>(popped_elements, elements);
}

TEST_CASE("Single uint32_t element", "")
{
  priority_queue<uint32_t> pq(1);

  const std::vector<uint32_t> els = {1};

  REQUIRE(test_insertion_and_deletion(pq, els, 1));
}

TEST_CASE("New node created on partial insertion")
{
  const size_t insertion_size = 600;
  const size_t num_elements   = insertion_size * 2;
  // NOTE(review): element/comparator types of this test were not recoverable from the reviewed
  // text; uint32_t with thrust::less matches the neighbouring tests - confirm.
  using T       = uint32_t;
  using Compare = thrust::less<T>;

  priority_queue<T, Compare> pq(num_elements);

  std::vector<T> els = generate_elements<T>(num_elements);

  std::vector<T> first_insertion(els.begin(), els.begin() + insertion_size);

  std::vector<T> second_insertion(els.begin() + insertion_size, els.end());

  insert_to_queue(pq, first_insertion);

  insert_to_queue(pq, second_insertion);

  const auto popped_elements = pop_from_queue(pq, insertion_size);

  REQUIRE(is_valid_top_n<T, Compare>(popped_elements, els));
}

TEST_CASE("Insert, delete, insert, delete", "")
{
  const size_t first_insertion_size  = 100'000;
  const size_t first_deletion_size   = 10'000;
  const size_t second_insertion_size = 20'000;
  const size_t second_deletion_size  = 50'000;
  using T                            = uint32_t;
  using Compare                      = thrust::less<T>;

  priority_queue<T, Compare> pq(first_insertion_size + second_insertion_size);

  auto first_insertion_els = generate_elements<T>(first_insertion_size);

  const auto second_insertion_els = generate_elements<T>(second_insertion_size);

  insert_to_queue(pq, first_insertion_els);

  const auto first_popped_elements = pop_from_queue(pq, first_deletion_size);

  insert_to_queue(pq, second_insertion_els);

  const auto second_popped_elements = pop_from_queue(pq, second_deletion_size);

  // What must still be in the queue before the second pop: everything from
  // the first insertion except its first_deletion_size smallest elements,
  // plus the whole second insertion.
  std::vector<T> remaining_elements;

  std::sort(first_insertion_els.begin(), first_insertion_els.end(), Compare{});

  remaining_elements.insert(remaining_elements.end(),
                            first_insertion_els.begin() + first_deletion_size,
                            first_insertion_els.end());

  remaining_elements.insert(
    remaining_elements.end(), second_insertion_els.begin(), second_insertion_els.end());

  REQUIRE((is_valid_top_n<T, Compare>(first_popped_elements, first_insertion_els) &&
           is_valid_top_n<T, Compare>(second_popped_elements, remaining_elements)));
}

TEST_CASE("Insertion and deletion on different streams", "")
{
  const size_t insertion_size = 100'000;
  const size_t deletion_size  = 10'000;
  using T                     = uint32_t;
  using Compare               = thrust::less<T>;

  const auto elements = generate_elements<T>(insertion_size * 2);
  const thrust::device_vector<T> insertion1(elements.begin(), elements.begin() + insertion_size);
  const thrust::device_vector<T> insertion2(elements.begin() + insertion_size, elements.end());

  priority_queue<T, Compare> pq(insertion_size * 2);

  cudaStream_t stream1, stream2;

  REQUIRE(cudaStreamCreate(&stream1) == cudaSuccess);
  REQUIRE(cudaStreamCreate(&stream2) == cudaSuccess);

  pq.push(insertion1.begin(), insertion1.end(), stream1);
  pq.push(insertion2.begin(), insertion2.end(), stream2);

  REQUIRE(cudaStreamSynchronize(stream1) == cudaSuccess);
  REQUIRE(cudaStreamSynchronize(stream2) == cudaSuccess);

  thrust::device_vector<T> deletion1(deletion_size);
  thrust::device_vector<T> deletion2(deletion_size);

  pq.pop(deletion1.begin(), deletion1.end(), stream1);
  pq.pop(deletion2.begin(), deletion2.end(), stream2);

  REQUIRE(cudaStreamSynchronize(stream1) == cudaSuccess);
  REQUIRE(cudaStreamSynchronize(stream2) == cudaSuccess);

  const thrust::host_vector<T> h_deletion1(deletion1);
  const thrust::host_vector<T> h_deletion2(deletion2);

  std::vector<T> popped_elements(h_deletion1.begin(), h_deletion1.end());

  popped_elements.insert(popped_elements.end(), h_deletion2.begin(), h_deletion2.end());

  REQUIRE(is_valid_top_n<T, Compare>(popped_elements, elements));

  cudaStreamDestroy(stream1);
  cudaStreamDestroy(stream2);
}

/// Push [begin, end) through the device API using the whole thread block as the group.
/// Requires dynamic shared memory of view.get_shmem_size(blockDim.x) bytes.
template <typename View, typename InputIt>
__global__ void device_api_insert(View view, InputIt begin, InputIt end)
{
  extern __shared__ int shmem[];
  cg::thread_block g = cg::this_thread_block();
  view.push(g, begin, end, shmem);
}

/// Pop into [begin, end) through the device API using the whole thread block as the group.
/// Requires dynamic shared memory of view.get_shmem_size(blockDim.x) bytes.
template <typename View, typename OutputIt>
__global__ void device_api_delete(View view, OutputIt begin, OutputIt end)
{
  extern __shared__ int shmem[];
  cg::thread_block g = cg::this_thread_block();
  view.pop(g, begin, end, shmem);
}

TEST_CASE("Insertion and deletion with Device API", "")
{
  const size_t insertion_size = 2000;
  const size_t deletion_size  = 1000;
  using T                     = uint32_t;
  using Compare               = thrust::less<T>;

  const auto els = generate_elements<T>(insertion_size);

  const thrust::device_vector<T> d_els(els);

  priority_queue<T, Compare> pq(insertion_size);

  const int block_size = 32;
  device_api_insert<<<1, block_size, pq.get_shmem_size(block_size)>>>(
    pq.get_mutable_device_view(), d_els.begin(), d_els.end());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  REQUIRE(cudaDeviceSynchronize() == cudaSuccess);

  thrust::device_vector<T> d_pop_result(deletion_size);

  device_api_delete<<<1, block_size, pq.get_shmem_size(block_size)>>>(
    pq.get_mutable_device_view(), d_pop_result.begin(), d_pop_result.end());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  REQUIRE(cudaDeviceSynchronize() == cudaSuccess);

  const thrust::host_vector<T> h_pop_result(d_pop_result);
  const std::vector<T> pop_result(h_pop_result.begin(), h_pop_result.end());

  REQUIRE(is_valid_top_n<T, Compare>(pop_result, els));
}

TEST_CASE("Concurrent insertion and deletion with Device API", "")
{
  const size_t insertion_size = 1000;
  const size_t deletion_size  = 500;
  const int block_size        = 32;
  using T                     = uint32_t;
  using Compare               = thrust::less<T>;

  const auto els = generate_elements<T>(insertion_size * 2);

  const thrust::device_vector<T> insertion1(els.begin(), els.begin() + insertion_size);
  const thrust::device_vector<T> insertion2(els.begin() + insertion_size, els.end());

  priority_queue<T, Compare> pq(insertion_size * 2);

  cudaStream_t stream1, stream2;

  REQUIRE(cudaStreamCreate(&stream1) == cudaSuccess);
  REQUIRE(cudaStreamCreate(&stream2) == cudaSuccess);

  device_api_insert<<<1, block_size, pq.get_shmem_size(block_size), stream1>>>(
    pq.get_mutable_device_view(), insertion1.begin(), insertion1.end());

  device_api_insert<<<1, block_size, pq.get_shmem_size(block_size), stream2>>>(
    pq.get_mutable_device_view(), insertion2.begin(), insertion2.end());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  REQUIRE(cudaStreamSynchronize(stream1) == cudaSuccess);
  REQUIRE(cudaStreamSynchronize(stream2) == cudaSuccess);

  thrust::device_vector<T> d_deletion1(deletion_size);
  thrust::device_vector<T> d_deletion2(deletion_size);

  device_api_delete<<<1, block_size, pq.get_shmem_size(block_size), stream1>>>(
    pq.get_mutable_device_view(), d_deletion1.begin(), d_deletion1.end());

  device_api_delete<<<1, block_size, pq.get_shmem_size(block_size), stream2>>>(
    pq.get_mutable_device_view(), d_deletion2.begin(), d_deletion2.end());
  REQUIRE(cudaGetLastError() == cudaSuccess);

  REQUIRE(cudaStreamSynchronize(stream1) == cudaSuccess);
  REQUIRE(cudaStreamSynchronize(stream2) == cudaSuccess);

  const thrust::host_vector<T> h_deletion1(d_deletion1);
  const thrust::host_vector<T> h_deletion2(d_deletion2);

  std::vector<T> result(h_deletion1.begin(), h_deletion1.end());
  result.insert(result.end(), h_deletion2.begin(), h_deletion2.end());

  REQUIRE(is_valid_top_n<T, Compare>(result, els));

  cudaStreamDestroy(stream1);
  cudaStreamDestroy(stream2);
}

// NOTE(review): the key/value types of the kv_pair instantiations below were not recoverable
// from the reviewed text. 32-bit pairs are assumed, with a 64-bit variant for the second of
// the two kv_pair entries in the 10'000 group - confirm against the original list.
TEMPLATE_TEST_CASE_SIG(
  "N deletions are correct",
  "",
  ((typename T, typename Compare, size_t N, size_t NumKeys), T, Compare, N, NumKeys),
  (uint32_t, thrust::less<uint32_t>, 100, 10'000'000),
  (uint64_t, thrust::less<uint64_t>, 100, 10'000'000),
  (kv_pair<uint32_t, uint32_t>, kv_less<kv_pair<uint32_t, uint32_t>>, 100, 10'000'000),
  (uint32_t, thrust::less<uint32_t>, 10'000, 10'000'000),
  (uint64_t, thrust::less<uint64_t>, 10'000, 10'000'000),
  (uint64_t, thrust::greater<uint64_t>, 10'000, 10'000'000),
  (kv_pair<uint32_t, uint32_t>, kv_less<kv_pair<uint32_t, uint32_t>>, 10'000, 10'000'000),
  (kv_pair<uint64_t, uint64_t>, kv_less<kv_pair<uint64_t, uint64_t>>, 10'000, 10'000'000),
  (uint32_t, thrust::less<uint32_t>, 10'000'000, 10'000'000),
  (uint64_t, thrust::less<uint64_t>, 10'000'000, 10'000'000),
  (kv_pair<uint32_t, uint32_t>, kv_less<kv_pair<uint32_t, uint32_t>>, 10'000'000, 10'000'000))
{
  priority_queue<T, Compare> pq(NumKeys);

  const auto els = generate_elements<T>(NumKeys);

  REQUIRE(test_insertion_and_deletion(pq, els, N));
}