diff --git a/.github/copy-pr-bot.yaml b/.github/copy-pr-bot.yaml new file mode 100644 index 000000000..895ba83ee --- /dev/null +++ b/.github/copy-pr-bot.yaml @@ -0,0 +1,4 @@ +# Configuration file for `copy-pr-bot` GitHub App +# https://docs.gha-runners.nvidia.com/apps/copy-pr-bot/ + +enabled: true diff --git a/.github/ops-bot.yaml b/.github/ops-bot.yaml index 2d1444c59..9a0b41550 100644 --- a/.github/ops-bot.yaml +++ b/.github/ops-bot.yaml @@ -5,5 +5,4 @@ auto_merger: true branch_checker: true label_checker: true release_drafter: true -copy_prs: true recently_updated: true diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 15180c9f0..46450bc96 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -38,7 +38,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -49,12 +49,12 @@ jobs: if: github.ref_type == 'branch' needs: [python-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.10 with: arch: "amd64" branch: ${{ inputs.branch }} build_type: ${{ inputs.build_type || 'branch' }} - container_image: "rapidsai/ci:latest" + container_image: "rapidsai/ci-conda:latest" date: ${{ inputs.date }} node_type: "gpu-v100-latest-1" run_script: "ci/build_docs.sh" @@ -62,7 +62,7 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -70,7 +70,7 @@ jobs: sha: ${{ inputs.sha }} wheel-build-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -80,7 +80,7 @@ jobs: wheel-publish-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 2109630b1..28c87a1e6 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -21,57 +21,57 @@ jobs: - wheel-build-pylibwholegraph - wheel-test-pylibwholegraph secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@branch-23.10 checks: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@branch-23.10 with: enable_check_generated_files: false conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@branch-23.10 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.10 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@branch-23.10 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.10 with: build_type: pull-request docs-build: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/custom-job.yaml@branch-23.10 with: build_type: pull-request arch: "amd64" - container_image: "rapidsai/ci:latest" + container_image: "rapidsai/ci-conda:latest" run_script: "ci/build_docs.sh" wheel-build-pylibwholegraph: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: pull-request script: ci/build_wheel.sh wheel-test-pylibwholegraph: needs: wheel-build-pylibwholegraph secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: pull-request script: ci/test_wheel.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 1abb5e881..183a29b35 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -16,7 +16,7 @@ on: jobs: conda-cpp-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@branch-23.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -24,7 +24,7 @@ jobs: sha: ${{ inputs.sha }} conda-pytorch-tests: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@branch-23.10 with: build_type: nightly branch: ${{ inputs.branch }} @@ -32,7 +32,7 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibwholegraph: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.08 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: nightly branch: ${{ inputs.branch }} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 51ae0d97f..6943ae3b0 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: scripts ) - repo: https://github.com/pre-commit/mirrors-clang-format - rev: v16.0.0 + rev: v16.0.6 hooks: - id: clang-format exclude: | diff --git a/CHANGELOG.md b/CHANGELOG.md index b399b807b..32c4b47e3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,19 @@ +# wholegraph 23.10.00 (11 Oct 2023) + +## 🐛 Bug Fixes + +- Update all versions to 23.10 ([#71](https://github.com/rapidsai/wholegraph/pull/71)) [@raydouglass](https://github.com/raydouglass) +- Use `conda mambabuild` not `mamba mambabuild` ([#67](https://github.com/rapidsai/wholegraph/pull/67)) [@bdice](https://github.com/bdice) + +## 🛠️ Improvements + +- Update image names ([#70](https://github.com/rapidsai/wholegraph/pull/70)) [@AyodeAwe](https://github.com/AyodeAwe) +- Update to clang 16.0.6. ([#68](https://github.com/rapidsai/wholegraph/pull/68)) [@bdice](https://github.com/bdice) +- Simplify wheel build scripts and allow alphas of RAPIDS dependencies ([#66](https://github.com/rapidsai/wholegraph/pull/66)) [@divyegala](https://github.com/divyegala) +- Fix docs build and slightly optimize ([#63](https://github.com/rapidsai/wholegraph/pull/63)) [@dongxuy04](https://github.com/dongxuy04) +- Use `copy-pr-bot` ([#60](https://github.com/rapidsai/wholegraph/pull/60)) [@ajschmidt8](https://github.com/ajschmidt8) +- PR: Use top-k from RAFT ([#53](https://github.com/rapidsai/wholegraph/pull/53)) [@chuangz0](https://github.com/chuangz0) + # wholegraph 23.08.00 (9 Aug 2023) ## 🚨 Breaking Changes diff --git a/build.sh b/build.sh index ddf0d46a3..535988f54 100755 --- a/build.sh +++ b/build.sh @@ -212,6 +212,8 @@ if hasArg clean; then find ${REPODIR}/python/pylibwholegraph -name "*.cpython*.so" -type f -delete # remove docs build + rm -rf ${REPODIR}/cpp/html + rm -rf ${REPODIR}/cpp/xml cd ${REPODIR}/docs/wholegraph make BUILDDIR=${DOCS_BUILD_DIR} clean rm -rf ${REPODIR}/docs/wholegraph/_xml @@ -303,13 +305,13 @@ if hasArg docs; then ${CMAKE_GENERATOR_OPTION} \ ${CMAKE_VERBOSE_OPTION} fi - cd ${LIBWHOLEGRAPH_BUILD_DIR} + cd ${REPODIR}/cpp cmake --build "${LIBWHOLEGRAPH_BUILD_DIR}" -j${PARALLEL_LEVEL} --target doxygen ${VERBOSE_FLAG} mkdir -p ${REPODIR}/docs/wholegraph/_html/doxygen_docs/libwholegraph/html - mv ${LIBWHOLEGRAPH_BUILD_DIR}/html/* ${REPODIR}/docs/wholegraph/_html/doxygen_docs/libwholegraph/html + mv ${REPODIR}/cpp/html/* ${REPODIR}/docs/wholegraph/_html/doxygen_docs/libwholegraph/html mkdir -p ${REPODIR}/docs/wholegraph/_xml # _xml is used for sphinx breathe project - mv ${LIBWHOLEGRAPH_BUILD_DIR}/xml/* "${REPODIR}/docs/wholegraph/_xml" + mv ${REPODIR}/cpp/xml/* "${REPODIR}/docs/wholegraph/_xml" cd ${REPODIR}/docs/wholegraph PYTHONPATH=${REPODIR}/python/pylibwholegraph:${PYTHONPATH} make BUILDDIR=${DOCS_BUILD_DIR} html mv ${REPODIR}/docs/wholegraph/_html/doxygen_docs ${REPODIR}/docs/wholegraph/${DOCS_BUILD_DIR}/html/ diff --git a/ci/build_cpp.sh b/ci/build_cpp.sh index 45e668af1..e290374fd 100755 --- a/ci/build_cpp.sh +++ b/ci/build_cpp.sh @@ -11,6 +11,6 @@ rapids-print-env rapids-logger "Begin cpp build" -rapids-mamba-retry mambabuild conda/recipes/libwholegraph +rapids-conda-retry mambabuild conda/recipes/libwholegraph rapids-upload-conda-to-s3 cpp diff --git a/ci/build_docs.sh b/ci/build_docs.sh index d9951c890..d7b08a87b 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -23,7 +23,7 @@ rapids-logger "Downloading artifacts from previous jobs" CPP_CHANNEL=$(rapids-download-conda-from-s3 cpp) PYTHON_CHANNEL=$(rapids-download-conda-from-s3 python) -export RAPIDS_VERSION_NUMBER="23.08" +export RAPIDS_VERSION_NUMBER="23.10" export RAPIDS_DOCS_DIR="$(mktemp -d)" rapids-mamba-retry install \ diff --git a/ci/build_python.sh b/ci/build_python.sh index e316c94d0..e4382400e 100755 --- a/ci/build_python.sh +++ b/ci/build_python.sh @@ -22,7 +22,7 @@ rapids-logger "Begin py build" # TODO: Remove `--no-test` flags once importing on a CPU # node works correctly rapids-logger "Begin pylibwholegraph build" -rapids-mamba-retry mambabuild \ +rapids-conda-retry mambabuild \ --no-test \ --channel "${CPP_CHANNEL}" \ conda/recipes/pylibwholegraph diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh index 5f6556734..6e2c9f73c 100755 --- a/ci/build_wheel.sh +++ b/ci/build_wheel.sh @@ -1,6 +1,11 @@ #!/bin/bash # Copyright (c) 2023, NVIDIA CORPORATION. +set -euo pipefail + +package_name="pylibwholegraph" +package_dir="python/pylibwholegraph" + source rapids-configure-sccache source rapids-date-string @@ -10,11 +15,26 @@ version_override="$(rapids-pip-wheel-version ${RAPIDS_DATE_STRING})" RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" -ci/release/apply_wheel_modifications.sh ${version_override} "-${RAPIDS_PY_CUDA_SUFFIX}" -echo "The package name and/or version was modified in the package source. The git diff is:" -git diff +# This is the version of the suffix with a preceding hyphen. It's used +# everywhere except in the final wheel name. +PACKAGE_CUDA_SUFFIX="-${RAPIDS_PY_CUDA_SUFFIX}" + +# Patch project metadata files to include the CUDA version suffix and version override. +pyproject_file="${package_dir}/pyproject.toml" + +sed -i "s/^version = .*/version = \"${version_override}\"/g" ${pyproject_file} +sed -i "s/name = \"${package_name}\"/name = \"${package_name}${PACKAGE_CUDA_SUFFIX}\"/g" ${pyproject_file} + +# For nightlies we want to ensure that we're pulling in alphas as well. The +# easiest way to do so is to augment the spec with a constraint containing a +# min alpha version that doesn't affect the version bounds but does allow usage +# of alpha versions for that dependency without --pre +alpha_spec='' +if ! rapids-is-release-build; then + alpha_spec=',>=0.0.0a0' +fi -cd python/pylibwholegraph +cd "${package_dir}" # Hardcode the output dir SKBUILD_CONFIGURE_OPTIONS="-DDETECT_CONDA_ENV=OFF -DBUILD_SHARED_LIBS=OFF -DCMAKE_MESSAGE_LOG_LEVEL=VERBOSE -DCUDA_STATIC_RUNTIME=ON -DWHOLEGRAPH_BUILD_WHEELS=ON" \ @@ -23,4 +43,4 @@ SKBUILD_CONFIGURE_OPTIONS="-DDETECT_CONDA_ENV=OFF -DBUILD_SHARED_LIBS=OFF -DCMAK mkdir -p final_dist python -m auditwheel repair --exclude libcuda.so.1 -w final_dist dist/* -RAPIDS_PY_WHEEL_NAME="pylibwholegraph_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 final_dist +RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 final_dist diff --git a/ci/release/apply_wheel_modifications.sh b/ci/release/apply_wheel_modifications.sh deleted file mode 100755 index fd36d38f4..000000000 --- a/ci/release/apply_wheel_modifications.sh +++ /dev/null @@ -1,14 +0,0 @@ -#!/bin/bash -# Copyright (c) 2023, NVIDIA CORPORATION. -# -# Usage: bash apply_wheel_modifications.sh - -VERSION=${1} -CUDA_SUFFIX=${2} - -# setup.py updates -sed -i "s/^version = .*/version = \"${VERSION}\"/g" \ - python/pylibwholegraph/pyproject.toml - -# pyproject.toml cuda suffixes -sed -i "s/name = \"pylibwholegraph\"/name = \"pylibwholegraph${CUDA_SUFFIX}\"/g" python/pylibwholegraph/pyproject.toml diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index ec637bbda..8f4c468c8 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -23,8 +23,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==23.8.* -- librmm==23.8.* +- libraft-headers==23.10.* +- librmm==23.10.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/conda/environments/all_cuda-120_arch-x86_64.yaml b/conda/environments/all_cuda-120_arch-x86_64.yaml index ba4ae96eb..1532684f7 100644 --- a/conda/environments/all_cuda-120_arch-x86_64.yaml +++ b/conda/environments/all_cuda-120_arch-x86_64.yaml @@ -25,8 +25,8 @@ dependencies: - graphviz - ipykernel - ipython -- libraft-headers==23.8.* -- librmm==23.8.* +- libraft-headers==23.10.* +- librmm==23.10.* - nanobind>=0.2.0 - nbsphinx - nccl diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index a072f6d0b..de2cd28d4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -14,7 +14,7 @@ # limitations under the License. #============================================================================= -set(RAPIDS_VERSION "23.08") +set(RAPIDS_VERSION "23.10") set(WHOLEGRAPH_VERSION "${RAPIDS_VERSION}.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) @@ -292,7 +292,7 @@ rapids_export( find_package(Doxygen 1.8.11) if(Doxygen_FOUND) add_custom_command(OUTPUT WHOLEGRAPH_DOXYGEN - WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR} + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} COMMAND doxygen Doxyfile VERBATIM) diff --git a/cpp/Doxyfile b/cpp/Doxyfile index 60ee8334e..ed303cb92 100644 --- a/cpp/Doxyfile +++ b/cpp/Doxyfile @@ -38,7 +38,7 @@ PROJECT_NAME = "WholeGraph C API" # could be handy for archiving the generated documentation or if some version # control system is used. -PROJECT_NUMBER = 23.08 +PROJECT_NUMBER = 23.10 # Using the PROJECT_BRIEF tag one can provide an optional one line description # for a project that appears at the top of each page and should give viewer a diff --git a/cpp/src/wholegraph_ops/block_radix_topk.cuh b/cpp/src/wholegraph_ops/block_radix_topk.cuh deleted file mode 100644 index 624c07510..000000000 --- a/cpp/src/wholegraph_ops/block_radix_topk.cuh +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace wholegraph_ops { - -template -class BlockRadixTopKGlobalMemory { - static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), - "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); - static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); - using KeyTraits = cub::Traits; - using UnsignedBits = typename KeyTraits::UnsignedBits; - using BlockScanT = cub::BlockScan; - static constexpr int RADIX_SIZE = (1 << RADIX_BITS); - static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; - using BinBlockLoad = cub::BlockLoad; - using BinBlockStore = cub::BlockStore; - struct _TempStorage { - typename BlockScanT::TempStorage scan_storage; - union { - typename BinBlockLoad::TempStorage load_storage; - typename BinBlockStore::TempStorage store_storage; - } load_store; - union { - int shared_bins[RADIX_SIZE]; - }; - int share_target_k; - int share_bucket_id; - }; - - public: - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - __device__ __forceinline__ BlockRadixTopKGlobalMemory(TempStorage& temp_storage) - : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; - __device__ __forceinline__ void radixTopKGetThreshold( - const KeyT* data, int k, int size, KeyT& topK, bool& topk_is_unique) - { - assert(k < size && k > 0); - int target_k = k; - UnsignedBits key_pattern = 0; - int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; - for (; digit_pos >= 0; digit_pos -= RADIX_BITS) { - UpdateSharedBins(data, size, digit_pos, key_pattern); - InclusiveScanBins(); - UpdateTopK(digit_pos, target_k, key_pattern); - if (target_k == 0) break; - } - if (target_k == 0) { - key_pattern -= 1; - topk_is_unique = true; - } else { - topk_is_unique = false; - } - if (GREATER) key_pattern = ~key_pattern; - UnsignedBits topK_unsigned = KeyTraits::TwiddleOut(key_pattern); - topK = reinterpret_cast(topK_unsigned); - } - - private: - __device__ __forceinline__ void UpdateSharedBins(const KeyT* key, - int size, - int digit_pos, - UnsignedBits key_pattern) - { - for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { - temp_storage_.shared_bins[id] = 0; - } - cub::CTA_SYNC(); - UnsignedBits key_mask = ((UnsignedBits)(-1)) << ((UnsignedBits)(digit_pos + RADIX_BITS)); -#pragma unroll - for (int idx = tid_; idx < size; idx += BLOCK_SIZE) { - KeyT key_data = key[idx]; - UnsignedBits twiddled_data = KeyTraits::TwiddleIn(reinterpret_cast(key_data)); - if (GREATER) twiddled_data = ~twiddled_data; - UnsignedBits digit_in_radix = cub::BFE(twiddled_data, digit_pos, RADIX_BITS); - if ((twiddled_data & key_mask) == (key_pattern & key_mask)) { - atomicAdd(&temp_storage_.shared_bins[digit_in_radix], 1); - } - } - cub::CTA_SYNC(); - } - __device__ __forceinline__ void InclusiveScanBins() - { - int items[SCAN_ITEMS_PER_THREAD]; - BinBlockLoad(temp_storage_.load_store.load_storage) - .Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); - cub::CTA_SYNC(); - BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); - cub::CTA_SYNC(); - BinBlockStore(temp_storage_.load_store.store_storage) - .Store(temp_storage_.shared_bins, items, RADIX_SIZE); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void UpdateTopK(int digit_pos, - int& target_k, - UnsignedBits& target_pattern) - { - for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) { - int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1]; - int cur_count = temp_storage_.shared_bins[idx]; - if (prev_count <= target_k && cur_count > target_k) { - temp_storage_.share_target_k = target_k - prev_count; - temp_storage_.share_bucket_id = idx; - } - } - cub::CTA_SYNC(); - target_k = temp_storage_.share_target_k; - int target_bucket_id = temp_storage_.share_bucket_id; - UnsignedBits key_segment = ((UnsignedBits)target_bucket_id) << ((UnsignedBits)digit_pos); - target_pattern |= key_segment; - } - _TempStorage& temp_storage_; - int tid_; -}; - -template -class BlockRadixTopKRegister { - static_assert(cub::PowerOfTwo::VALUE && (RADIX_BITS <= (sizeof(KeyT) * 8)), - "RADIX_BITS should be power of 2, and <= (sizeof(KeyT) * 8)"); - static_assert(cub::PowerOfTwo::VALUE, "BLOCK_SIZE should be power of 2"); - using KeyTraits = cub::Traits; - using UnsignedBits = typename KeyTraits::UnsignedBits; - using BlockScanT = cub::BlockScan; - static constexpr int RADIX_SIZE = (1 << RADIX_BITS); - static constexpr bool KEYS_ONLY = std::is_same::value; - static constexpr int SCAN_ITEMS_PER_THREAD = (RADIX_SIZE + BLOCK_SIZE - 1) / BLOCK_SIZE; - using BinBlockLoad = cub::BlockLoad; - using BinBlockStore = cub::BlockStore; - using BlockExchangeKey = cub::BlockExchange; - using BlockExchangeValue = cub::BlockExchange; - - using _ExchangeKeyTempStorage = typename BlockExchangeKey::TempStorage; - using _ExchangeValueTempStorage = typename BlockExchangeValue::TempStorage; - typedef union ExchangeKeyTempStorageType { - _ExchangeKeyTempStorage key_storage; - } ExchKeyTempStorageType; - typedef union ExchangeKeyValueTempStorageType { - _ExchangeKeyTempStorage key_storage; - _ExchangeValueTempStorage value_storage; - } ExchKeyValueTempStorageType; - using _ExchangeType = - typename std::conditional::type; - - struct _TempStorage { - typename BlockScanT::TempStorage scan_storage; - union { - typename BinBlockLoad::TempStorage load_storage; - typename BinBlockStore::TempStorage store_storage; - } load_store; - union { - int shared_bins[RADIX_SIZE]; - _ExchangeType exchange_storage; - }; - int share_target_k; - int share_bucket_id; - int share_prev_count; - }; - - public: - struct TempStorage : cub::Uninitialized<_TempStorage> {}; - __device__ __forceinline__ BlockRadixTopKRegister(TempStorage& temp_storage) - : temp_storage_{temp_storage.Alias()}, tid_(threadIdx.x){}; - __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - if (k == valid_count) return; - TopKGenRank(keys, k, valid_count); - int is_valid[ITEMS_PER_THREAD]; - GenValidArray(is_valid, k); - BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged( - keys, keys, ranks_, is_valid); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void radixTopKToStriped(KeyT (&keys)[ITEMS_PER_THREAD], - ValueT (&values)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - if (k == valid_count) return; - TopKGenRank(keys, k, valid_count); - int is_valid[ITEMS_PER_THREAD]; - GenValidArray(is_valid, k); - BlockExchangeKey{temp_storage_.exchange_storage.key_storage}.ScatterToStripedFlagged( - keys, keys, ranks_, is_valid); - cub::CTA_SYNC(); - BlockExchangeValue{temp_storage_.exchange_storage.value_storage}.ScatterToStripedFlagged( - values, values, ranks_, is_valid); - cub::CTA_SYNC(); - } - - private: - __device__ __forceinline__ void TopKGenRank(KeyT (&keys)[ITEMS_PER_THREAD], - const int k, - const int valid_count) - { - assert(k <= BLOCK_SIZE * ITEMS_PER_THREAD); - assert(k <= valid_count); - UnsignedBits(&unsigned_keys)[ITEMS_PER_THREAD] = - reinterpret_cast(keys); - search_mask_ = 0; - top_k_mask_ = 0; - -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - int idx = KEY * BLOCK_SIZE + tid_; - unsigned_keys[KEY] = KeyTraits::TwiddleIn(unsigned_keys[KEY]); - if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; - if (idx < valid_count) search_mask_ |= (1U << KEY); - } - - int target_k = k; - int prefix_k = 0; - - for (int digit_pos = sizeof(KeyT) * 8 - RADIX_BITS; digit_pos >= 0; digit_pos -= RADIX_BITS) { - UpdateSharedBins(unsigned_keys, digit_pos, prefix_k); - InclusiveScanBins(); - UpdateTopK(unsigned_keys, digit_pos, target_k, prefix_k, digit_pos == 0); - if (target_k == 0) break; - } - -#pragma unroll - for (int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if (GREATER) unsigned_keys[KEY] = ~unsigned_keys[KEY]; - unsigned_keys[KEY] = KeyTraits::TwiddleOut(unsigned_keys[KEY]); - } - } - __device__ __forceinline__ void GenValidArray(int (&is_valid)[ITEMS_PER_THREAD], int k) - { -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if ((top_k_mask_ & (1U << KEY)) && ranks_[KEY] < k) { - is_valid[KEY] = 1; - } else { - is_valid[KEY] = 0; - } - } - } - __device__ __forceinline__ void UpdateSharedBins(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], - int digit_pos, - int prefix_k) - { - for (int id = tid_; id < RADIX_SIZE; id += BLOCK_SIZE) { - temp_storage_.shared_bins[id] = 0; - } - cub::CTA_SYNC(); -// #define USE_MATCH -#ifdef USE_MATCH - int lane_mask = cub::LaneMaskLt(); -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - bool is_search = search_mask_ & (1U << KEY); - int bucket_idx = -1; - if (is_search) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - bucket_idx = (int)digit_in_radix; - } - int warp_match_mask = __match_any_sync(0xffffffff, bucket_idx); - int same_count = __popc(warp_match_mask); - int idx_in_same_bucket = __popc(warp_match_mask & lane_mask); - int same_bucket_root_lane = __ffs(warp_match_mask) - 1; - int same_bucket_start_idx; - if (idx_in_same_bucket == 0 && is_search) { - same_bucket_start_idx = atomicAdd(&temp_storage_.shared_bins[bucket_idx], same_count); - } - same_bucket_start_idx = - __shfl_sync(0xffffffff, same_bucket_start_idx, same_bucket_root_lane, 32); - if (is_search) { ranks_[KEY] = same_bucket_start_idx + idx_in_same_bucket + prefix_k; } - } -#else -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - bool is_search = search_mask_ & (1U << KEY); - int bucket_idx = -1; - if (is_search) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - bucket_idx = (int)digit_in_radix; - ranks_[KEY] = atomicAdd(&temp_storage_.shared_bins[bucket_idx], 1) + prefix_k; - } - } -#endif - cub::CTA_SYNC(); - } - __device__ __forceinline__ void InclusiveScanBins() - { - int items[SCAN_ITEMS_PER_THREAD]; - BinBlockLoad(temp_storage_.load_store.load_storage) - .Load(temp_storage_.shared_bins, items, RADIX_SIZE, 0); - cub::CTA_SYNC(); - BlockScanT(temp_storage_.scan_storage).InclusiveSum(items, items); - cub::CTA_SYNC(); - BinBlockStore(temp_storage_.load_store.store_storage) - .Store(temp_storage_.shared_bins, items, RADIX_SIZE); - cub::CTA_SYNC(); - } - __device__ __forceinline__ void UpdateTopK(UnsignedBits (&unsigned_keys)[ITEMS_PER_THREAD], - int digit_pos, - int& target_k, - int& prefix_k, - bool mark_equal) - { - for (int idx = tid_; (idx < RADIX_SIZE); idx += BLOCK_SIZE) { - int prev_count = (idx == 0) ? 0 : temp_storage_.shared_bins[idx - 1]; - int cur_count = temp_storage_.shared_bins[idx]; - if (prev_count <= target_k && cur_count > target_k) { - temp_storage_.share_target_k = target_k - prev_count; - temp_storage_.share_bucket_id = idx; - temp_storage_.share_prev_count = prev_count; - } - } - cub::CTA_SYNC(); - target_k = temp_storage_.share_target_k; - prefix_k += temp_storage_.share_prev_count; - int target_bucket_id = temp_storage_.share_bucket_id; -#pragma unroll - for (unsigned int KEY = 0; KEY < ITEMS_PER_THREAD; KEY++) { - if (search_mask_ & (1U << KEY)) { - UnsignedBits digit_in_radix = - cub::BFE(unsigned_keys[KEY], digit_pos, RADIX_BITS); - if (digit_in_radix < target_bucket_id) { - top_k_mask_ |= (1U << KEY); - search_mask_ &= ~(1U << KEY); - } else if (digit_in_radix > target_bucket_id) { - search_mask_ &= ~(1U << KEY); - } else { - if (mark_equal) top_k_mask_ |= (1U << KEY); - } - if (digit_in_radix <= target_bucket_id) { - int prev_count = - (digit_in_radix == 0) ? 0 : temp_storage_.shared_bins[digit_in_radix - 1]; - ranks_[KEY] += prev_count; - } - } - } - cub::CTA_SYNC(); - } - - _TempStorage& temp_storage_; - int tid_; - int ranks_[ITEMS_PER_THREAD]; - unsigned int search_mask_; - unsigned int top_k_mask_; -}; - -} // namespace wholegraph_ops diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index 22a97fd19..a2915cd00 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -14,22 +14,26 @@ * limitations under the License. */ #pragma once +#include #include +#include +#include +#include #include #include +#include "raft/matrix/detail/select_warpsort.cuh" +#include "raft/util/cuda_dev_essentials.cuh" +#include "wholememory_ops/output_memory_handle.hpp" +#include "wholememory_ops/raft_random.cuh" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" #include #include #include #include #include -#include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" -#include "wholememory_ops/temp_memory_handle.hpp" -#include "wholememory_ops/thrust_allocator.hpp" - -#include "block_radix_topk.cuh" #include "cuda_macros.hpp" #include "error.hpp" #include "sample_comm.cuh" @@ -54,16 +58,14 @@ __device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PC } template -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_large_kernel( + unsigned int BLOCK_SIZE> +__launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_kernel( wholememory_gref_t wm_csr_row_ptr, wholememory_array_description_t wm_csr_row_ptr_desc, wholememory_gref_t wm_csr_col_ptr, @@ -74,18 +76,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const int input_node_count, const int max_sample_count, unsigned long long random_seed, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, const int* target_neighbor_offset, - WMIdType* output, - int* src_lid, - int64_t* out_edge_gid, - WeightKeyType* weight_keys_buff) + WeightKeyType* output_weighted_keys, + NeighborIdxType* output_idxs, + bool need_random = true) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -93,13 +91,57 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; int neighbor_count = (int)(end - start); + if (neighbor_count <= max_sample_count) { need_random = false; } + + PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + int output_offset = target_neighbor_offset[input_idx]; + output_weighted_keys += output_offset; + output_idxs += output_offset; + for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { + WeightType thread_weight = csr_weight_ptr_gen[start + id]; + output_weighted_keys[id] = + need_random ? static_cast(gen_key_from_weight(thread_weight, rng)) + : (static_cast(thread_weight)); + output_idxs[id] = static_cast(id); + } +} + +template +__launch_bounds__(BLOCK_SIZE) __global__ + void weighted_sample_select_k_kernel(wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + const NeighborIdxType* sorted_idxs, + const int* target_neighbor_offset, + WMIdType* output, + LocalIdType* src_lid, + int64_t* out_edge_gid) +{ + int input_idx = blockIdx.x; + if (input_idx >= input_node_count) return; + wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); + wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); + IdType nid = input_nodes[input_idx]; + int64_t start = csr_row_ptr_gen[nid]; + int64_t end = csr_row_ptr_gen[nid + 1]; + int neighbor_count = (int)(end - start); + + int offset = sample_offset[input_idx]; - WeightKeyType* weight_keys_local_buff = weight_keys_buff + target_neighbor_offset[input_idx]; - int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { - int neighbor_idx = sample_id; - int original_neighbor_idx = neighbor_idx; + int original_neighbor_idx = sample_id; IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; output[offset + sample_id] = gid; if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; @@ -108,83 +150,14 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen } return; } - - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { - WeightType thread_weight = csr_weight_ptr_gen[start + id]; - weight_keys_local_buff[id] = - NeedRandom ? static_cast(gen_key_from_weight(thread_weight, rng)) - : (static_cast(thread_weight)); - } - - __syncthreads(); - - WeightKeyType topk_val; - bool topk_is_unique; - - using BlockRadixSelectT = - std::conditional_t, - BlockRadixTopKGlobalMemory>; - __shared__ typename BlockRadixSelectT::TempStorage share_storage; - - BlockRadixSelectT{share_storage}.radixTopKGetThreshold( - weight_keys_local_buff, max_sample_count, neighbor_count, topk_val, topk_is_unique); - __shared__ int cnt; - - if (threadIdx.x == 0) { cnt = 0; } - __syncthreads(); - - for (int i = threadIdx.x; i < max_sample_count; i += BLOCK_SIZE) { - if (src_lid) src_lid[offset + i] = (LocalIdType)input_idx; - } - - // We use atomicAdd 1 operations instead of binaryScan to calculate the write - // index, since we do not need to keep the relative positions of element. - - if (topk_is_unique) { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key <= topk_val) : (key >= topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } - } else { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key < topk_val) : (key > topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } - __syncthreads(); - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = (key == topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - if (write_index >= max_sample_count) break; - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } + int neighbor_offset = target_neighbor_offset[input_idx]; + for (int sample_id = threadIdx.x; sample_id < max_sample_count; sample_id += BLOCK_SIZE) { + int original_neighbor_idx = sorted_idxs[neighbor_offset + sample_id]; + IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; + output[offset + sample_id] = gid; + if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (out_edge_gid) + out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); } } @@ -216,21 +189,30 @@ __global__ void get_sample_count_and_neighbor_count_without_replacement_kernel( } } +// to avoid queue.store() store keys or values in output. +struct null_store_t {}; +struct null_store_op { + template + constexpr auto operator()(const Type& in, UnusedArgs...) const + { + return null_store_t{}; + } +}; + // A-RES algorithmn // https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res -// max_sample_count should <=(BLOCK_SIZE*ITEMS_PER_THREAD*/4) otherwise,need to -// change the template parameters of BlockRadixTopK. -template class WarpSortClass, + int Capacity, + typename IdType, typename LocalIdType, typename WeightType, + typename NeighborIdxType, typename WMIdType, typename WMOffsetType, typename WMWeightType, - unsigned int ITEMS_PER_THREAD, - unsigned int BLOCK_SIZE, - bool NeedRandom = true, - bool Ascending = false> -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_kernel( + bool NEED_RANDOM = true, + bool ASCENDING = false> +__launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_kernel( wholememory_gref_t wm_csr_row_ptr, wholememory_array_description_t wm_csr_row_ptr_desc, wholememory_gref_t wm_csr_col_ptr, @@ -244,13 +226,12 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, - int* src_lid, + LocalIdType* src_lid, int64_t* out_edge_gid) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; - int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - + int gidx = threadIdx.x + blockIdx.x * blockDim.x; wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -258,86 +239,153 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen IdType nid = input_nodes[input_idx]; int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; - int neighbor_count = (int)(end - start); + int neighbor_count = static_cast(end - start); int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { - for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { + for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += blockDim.x) { int neighbor_idx = sample_id; int original_neighbor_idx = neighbor_idx; IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; output[offset + sample_id] = gid; - if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (src_lid) src_lid[offset + sample_id] = input_idx; if (out_edge_gid) out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); } return; } else { - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - - float weight_keys[ITEMS_PER_THREAD]; - int neighbor_idxs[ITEMS_PER_THREAD]; - - using BlockRadixTopKT = - std::conditional_t, - BlockRadixTopKRegister>; - - __shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage; - - const int tx = threadIdx.x; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = BLOCK_SIZE * i + tx; + extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; + using bq_t = raft::matrix::detail::select::warpsort:: + block_sort; + + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(max_sample_count, warp_smem); + PCGenerator rng(random_seed, static_cast(gidx), static_cast(0)); + const int per_thread_lim = neighbor_count + raft::laneId(); + for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) { + WeightType weight_key = + WarpSortClass::kDummy; if (idx < neighbor_count) { WeightType thread_weight = csr_weight_ptr_gen[start + idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = idx; + weight_key = NEED_RANDOM ? gen_key_from_weight(thread_weight, rng) : thread_weight; } + queue.add(weight_key, idx); } - const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD)) - ? neighbor_count - : (BLOCK_SIZE * ITEMS_PER_THREAD); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, valid_count); + queue.done(smem_buf_bytes); + __syncthreads(); - const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count; - - for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE; idx_offset < neighbor_count; - idx_offset += stride) { -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int local_idx = BLOCK_SIZE * i + tx - max_sample_count; - // [0,BLOCK_SIZE*ITEMS_PER_THREAD-max_sample_count) - int target_idx = idx_offset + local_idx; - if (local_idx >= 0 && target_idx < neighbor_count) { - WeightType thread_weight = csr_weight_ptr_gen[start + target_idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = target_idx; - } + NeighborIdxType* smem_topk_idx = reinterpret_cast(smem_buf_bytes); + queue.store(static_cast(nullptr), smem_topk_idx, null_store_op{}); + __syncthreads(); + for (int idx = threadIdx.x; idx < max_sample_count; idx += blockDim.x) { + NeighborIdxType local_original_idx = static_cast(smem_topk_idx[idx]); + if (src_lid) { src_lid[offset + idx] = static_cast(input_idx); } + output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; + if (out_edge_gid) { + out_edge_gid[offset + idx] = static_cast(start + local_original_idx); } - const int iter_valid_count = ((neighbor_count - idx_offset) >= stride) - ? (BLOCK_SIZE * ITEMS_PER_THREAD) - : (max_sample_count + neighbor_count - idx_offset); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, iter_valid_count); - __syncthreads(); } -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = i * BLOCK_SIZE + tx; - if (idx < max_sample_count) { - if (src_lid) src_lid[offset + idx] = (LocalIdType)input_idx; - LocalIdType local_original_idx = neighbor_idxs[i]; - output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + idx] = static_cast(start + local_original_idx); - } + }; +} + +template