diff --git a/build.sh b/build.sh index feb2d7256e..02eed05813 100755 --- a/build.sh +++ b/build.sh @@ -18,8 +18,8 @@ ARGS=$* # scripts, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests template bench-prims bench-ann clean --uninstall -v -g -n --compile-lib --compile-static-lib --allgpuarch --no-nvtx --cpu-only --show_depr_warn --incl-cache-stats --time -h" -HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench-prims=] [--limit-bench-ann=] [--build-metrics=] +VALIDARGS="clean libraft pylibraft raft-dask docs tests template bench-prims clean --uninstall -v -g -n --compile-lib --compile-static-lib --allgpuarch --no-nvtx --cpu-only --show_depr_warn --incl-cache-stats --time -h" +HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench-prims=] [--build-metrics=] where is: clean - remove all existing build artifacts and configuration (start over) libraft - build the raft C++ code only. Also builds the C-wrapper library @@ -29,7 +29,6 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool= is: @@ -42,7 +41,6 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool==1.8.2 -- c-compiler -- clang-tools=16.0.6 -- clang==16.0.6 -- cmake>=3.26.4,!=3.30.0 -- cuda-nvtx=11.8 -- cuda-profiler-api=11.8.86 -- cuda-version=11.8 -- cudatoolkit -- cxx-compiler -- cython>=3.0.0 -- gcc_linux-aarch64=11.* -- glog>=0.6.0 -- h5py>=3.8.0 -- hnswlib=0.7.0 -- libcublas-dev=11.11.3.6 -- libcublas=11.11.3.6 -- libcurand-dev=10.3.0.86 -- libcurand=10.3.0.86 -- libcusolver-dev=11.4.1.48 -- libcusolver=11.4.1.48 -- libcusparse-dev=11.7.5.86 -- libcusparse=11.7.5.86 -- libucxx==0.41.*,>=0.0.0a0 -- matplotlib -- nccl>=2.19 -- ninja -- nlohmann_json>=3.11.2 -- nvcc_linux-aarch64=11.8 -- openblas -- pandas -- pyyaml -- rapids-build-backend>=0.3.0,<0.4.0.dev0 -- rmm==24.12.*,>=0.0.0a0 -- scikit-build-core>=0.10.0 -- sysroot_linux-aarch64==2.17 -name: bench_ann_cuda-118_arch-aarch64 diff --git a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml deleted file mode 100644 index 1b62c492cf..0000000000 --- a/conda/environments/bench_ann_cuda-118_arch-x86_64.yaml +++ /dev/null @@ -1,46 +0,0 @@ -# This file is generated by `rapids-dependency-file-generator`. -# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev -- conda-forge -- nvidia -dependencies: -- benchmark>=1.8.2 -- c-compiler -- clang-tools=16.0.6 -- clang==16.0.6 -- cmake>=3.26.4,!=3.30.0 -- cuda-nvtx=11.8 -- cuda-profiler-api=11.8.86 -- cuda-version=11.8 -- cudatoolkit -- cxx-compiler -- cython>=3.0.0 -- gcc_linux-64=11.* -- glog>=0.6.0 -- h5py>=3.8.0 -- hnswlib=0.7.0 -- libcublas-dev=11.11.3.6 -- libcublas=11.11.3.6 -- libcurand-dev=10.3.0.86 -- libcurand=10.3.0.86 -- libcusolver-dev=11.4.1.48 -- libcusolver=11.4.1.48 -- libcusparse-dev=11.7.5.86 -- libcusparse=11.7.5.86 -- libucxx==0.41.*,>=0.0.0a0 -- matplotlib -- nccl>=2.19 -- ninja -- nlohmann_json>=3.11.2 -- nvcc_linux-64=11.8 -- openblas -- pandas -- pyyaml -- rapids-build-backend>=0.3.0,<0.4.0.dev0 -- rmm==24.12.*,>=0.0.0a0 -- scikit-build-core>=0.10.0 -- sysroot_linux-64==2.17 -name: bench_ann_cuda-118_arch-x86_64 diff --git a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml b/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml deleted file mode 100644 index 54d67f462a..0000000000 --- a/conda/environments/bench_ann_cuda-120_arch-aarch64.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# This file is generated by `rapids-dependency-file-generator`. -# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev -- conda-forge -- nvidia -dependencies: -- benchmark>=1.8.2 -- c-compiler -- clang-tools=16.0.6 -- clang==16.0.6 -- cmake>=3.26.4,!=3.30.0 -- cuda-cudart-dev -- cuda-nvcc -- cuda-nvtx-dev -- cuda-profiler-api -- cuda-version=12.0 -- cxx-compiler -- cython>=3.0.0 -- gcc_linux-aarch64=11.* -- glog>=0.6.0 -- h5py>=3.8.0 -- hnswlib=0.7.0 -- libcublas-dev -- libcurand-dev -- libcusolver-dev -- libcusparse-dev -- libucxx==0.41.*,>=0.0.0a0 -- matplotlib -- nccl>=2.19 -- ninja -- nlohmann_json>=3.11.2 -- openblas -- pandas -- pyyaml -- rapids-build-backend>=0.3.0,<0.4.0.dev0 -- rmm==24.12.*,>=0.0.0a0 -- scikit-build-core>=0.10.0 -- sysroot_linux-aarch64==2.17 -name: bench_ann_cuda-120_arch-aarch64 diff --git a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml b/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml deleted file mode 100644 index 4f39378047..0000000000 --- a/conda/environments/bench_ann_cuda-120_arch-x86_64.yaml +++ /dev/null @@ -1,42 +0,0 @@ -# This file is generated by `rapids-dependency-file-generator`. -# To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. -channels: -- rapidsai -- rapidsai-nightly -- dask/label/dev -- conda-forge -- nvidia -dependencies: -- benchmark>=1.8.2 -- c-compiler -- clang-tools=16.0.6 -- clang==16.0.6 -- cmake>=3.26.4,!=3.30.0 -- cuda-cudart-dev -- cuda-nvcc -- cuda-nvtx-dev -- cuda-profiler-api -- cuda-version=12.0 -- cxx-compiler -- cython>=3.0.0 -- gcc_linux-64=11.* -- glog>=0.6.0 -- h5py>=3.8.0 -- hnswlib=0.7.0 -- libcublas-dev -- libcurand-dev -- libcusolver-dev -- libcusparse-dev -- libucxx==0.41.*,>=0.0.0a0 -- matplotlib -- nccl>=2.19 -- ninja -- nlohmann_json>=3.11.2 -- openblas -- pandas -- pyyaml -- rapids-build-backend>=0.3.0,<0.4.0.dev0 -- rmm==24.12.*,>=0.0.0a0 -- scikit-build-core>=0.10.0 -- sysroot_linux-64==2.17 -name: bench_ann_cuda-120_arch-x86_64 diff --git a/conda/recipes/raft-ann-bench-cpu/build.sh b/conda/recipes/raft-ann-bench-cpu/build.sh deleted file mode 100644 index 4462d5124b..0000000000 --- a/conda/recipes/raft-ann-bench-cpu/build.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) 2023, NVIDIA CORPORATION. - -./build.sh bench-ann --cpu-only --no-nvtx --build-metrics=bench_ann_cpu --incl-cache-stats -cmake --install cpp/build --component ann_bench diff --git a/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml b/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml deleted file mode 100644 index ed6f708e14..0000000000 --- a/conda/recipes/raft-ann-bench-cpu/conda_build_config.yaml +++ /dev/null @@ -1,29 +0,0 @@ -c_compiler_version: - - 11 - -cxx_compiler_version: - - 11 - -c_stdlib: - - sysroot - -c_stdlib_version: - - "2.17" - -cmake_version: - - ">=3.26.4,!=3.30.0" - -glog_version: - - ">=0.6.0" - -h5py_version: - - ">=3.8.0" - -nlohmann_json_version: - - ">=3.11.2" - -spdlog_version: - - ">=1.14.1,<1.15" - -fmt_version: - - ">=11.0.2,<12" diff --git a/conda/recipes/raft-ann-bench-cpu/meta.yaml b/conda/recipes/raft-ann-bench-cpu/meta.yaml deleted file mode 100644 index 94f7102726..0000000000 --- a/conda/recipes/raft-ann-bench-cpu/meta.yaml +++ /dev/null @@ -1,69 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. - -# Usage: -# conda build . -c conda-forge -c nvidia -c rapidsai -{% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') + environ.get('VERSION_SUFFIX', '') %} -{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} -{% set py_version = environ['CONDA_PY'] %} -{% set cuda_version = '.'.join(environ['RAPIDS_CUDA_VERSION'].split('.')[:2]) %} -{% set date_string = environ['RAPIDS_DATE_STRING'] %} - -package: - name: raft-ann-bench-cpu - version: {{ version }} - script: build.sh - -source: - path: ../../.. - -build: - script_env: - - AWS_ACCESS_KEY_ID - - AWS_SECRET_ACCESS_KEY - - AWS_SESSION_TOKEN - - CMAKE_C_COMPILER_LAUNCHER - - CMAKE_CUDA_COMPILER_LAUNCHER - - CMAKE_CXX_COMPILER_LAUNCHER - - CMAKE_GENERATOR - - PARALLEL_LEVEL - - RAPIDS_ARTIFACTS_DIR - - SCCACHE_BUCKET - - SCCACHE_IDLE_TIMEOUT - - SCCACHE_REGION - - SCCACHE_S3_KEY_PREFIX=libraft-aarch64 # [aarch64] - - SCCACHE_S3_KEY_PREFIX=libraft-linux64 # [linux64] - - SCCACHE_S3_USE_SSL - number: {{ GIT_DESCRIBE_NUMBER }} - string: py{{ py_version }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} - -requirements: - build: - - {{ compiler('c') }} - - {{ compiler('cxx') }} - - cmake {{ cmake_version }} - - ninja - - {{ stdlib("c") }} - - host: - - glog {{ glog_version }} - - matplotlib - - nlohmann_json {{ nlohmann_json_version }} - - spdlog {{ spdlog_version }} - - fmt {{ fmt_version }} - - python - - pyyaml - - pandas - - rapids-build-backend>=0.3.0,<0.4.0.dev0 - - run: - - glog {{ glog_version }} - - h5py {{ h5py_version }} - - matplotlib - - python - - pyyaml - - pandas - - benchmark -about: - home: https://rapids.ai/ - license: Apache-2.0 - summary: RAFT ANN CPU benchmarks diff --git a/conda/recipes/raft-ann-bench/build.sh b/conda/recipes/raft-ann-bench/build.sh deleted file mode 100644 index 00078792a1..0000000000 --- a/conda/recipes/raft-ann-bench/build.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/usr/bin/env bash -# Copyright (c) 2023, NVIDIA CORPORATION. - -./build.sh bench-ann --allgpuarch --no-nvtx --build-metrics=bench_ann --incl-cache-stats -cmake --install cpp/build --component ann_bench diff --git a/conda/recipes/raft-ann-bench/conda_build_config.yaml b/conda/recipes/raft-ann-bench/conda_build_config.yaml deleted file mode 100644 index 47bd730daf..0000000000 --- a/conda/recipes/raft-ann-bench/conda_build_config.yaml +++ /dev/null @@ -1,70 +0,0 @@ -c_compiler_version: - - 11 - -cxx_compiler_version: - - 11 - -cuda_compiler: - - cuda-nvcc - -cuda11_compiler: - - nvcc - -c_stdlib: - - sysroot - -c_stdlib_version: - - "2.17" - -cmake_version: - - ">=3.26.4,!=3.30.0" - -nccl_version: - - ">=2.19" - -glog_version: - - ">=0.6.0" - -h5py_version: - - ">=3.8.0" - -nlohmann_json_version: - - ">=3.11.2" - -# The CTK libraries below are missing from the conda-forge::cudatoolkit package -# for CUDA 11. The "*_host_*" version specifiers correspond to `11.8` packages -# and the "*_run_*" version specifiers correspond to `11.x` packages. - -cuda11_libcublas_host_version: - - "=11.11.3.6" - -cuda11_libcublas_run_version: - - ">=11.5.2.43,<12.0.0" - -cuda11_libcurand_host_version: - - "=10.3.0.86" - -cuda11_libcurand_run_version: - - ">=10.2.5.43,<10.3.1" - -cuda11_libcusolver_host_version: - - "=11.4.1.48" - -cuda11_libcusolver_run_version: - - ">=11.2.0.43,<11.4.2" - -cuda11_libcusparse_host_version: - - "=11.7.5.86" - -cuda11_libcusparse_run_version: - - ">=11.6.0.43,<12.0.0" - -# `cuda-profiler-api` only has `11.8.0` and `12.0.0` packages for all -# architectures. The "*_host_*" version specifiers correspond to `11.8` packages and the -# "*_run_*" version specifiers correspond to `11.x` packages. - -cuda11_cuda_profiler_api_host_version: - - "=11.8.86" - -cuda11_cuda_profiler_api_run_version: - - ">=11.4.240,<12" diff --git a/conda/recipes/raft-ann-bench/meta.yaml b/conda/recipes/raft-ann-bench/meta.yaml deleted file mode 100644 index d6aeb5f860..0000000000 --- a/conda/recipes/raft-ann-bench/meta.yaml +++ /dev/null @@ -1,110 +0,0 @@ -# Copyright (c) 2022-2024, NVIDIA CORPORATION. - -# Usage: -# conda build . -c conda-forge -c nvidia -c rapidsai -{% set version = environ['RAPIDS_PACKAGE_VERSION'].lstrip('v') + environ.get('VERSION_SUFFIX', '') %} -{% set minor_version = version.split('.')[0] + '.' + version.split('.')[1] %} -{% set py_version = environ['CONDA_PY'] %} -{% set cuda_version = '.'.join(environ['RAPIDS_CUDA_VERSION'].split('.')[:2]) %} -{% set cuda_major = cuda_version.split('.')[0] %} -{% set date_string = environ['RAPIDS_DATE_STRING'] %} - -package: - name: raft-ann-bench - version: {{ version }} - script: build.sh - -source: - path: ../../.. - -build: - script_env: - - AWS_ACCESS_KEY_ID - - AWS_SECRET_ACCESS_KEY - - AWS_SESSION_TOKEN - - CMAKE_C_COMPILER_LAUNCHER - - CMAKE_CUDA_COMPILER_LAUNCHER - - CMAKE_CXX_COMPILER_LAUNCHER - - CMAKE_GENERATOR - - PARALLEL_LEVEL - - RAPIDS_ARTIFACTS_DIR - - SCCACHE_BUCKET - - SCCACHE_IDLE_TIMEOUT - - SCCACHE_REGION - - SCCACHE_S3_KEY_PREFIX=libraft-aarch64 # [aarch64] - - SCCACHE_S3_KEY_PREFIX=libraft-linux64 # [linux64] - - SCCACHE_S3_USE_SSL - number: {{ GIT_DESCRIBE_NUMBER }} - string: cuda{{ cuda_major }}_py{{ py_version }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} - ignore_run_exports_from: - {% if cuda_major == "11" %} - - {{ compiler('cuda11') }} - {% else %} - - {{ compiler('cuda') }} - - cuda-cudart-dev - - libcublas-dev - {% endif %} - -requirements: - build: - - {{ compiler('c') }} - - {{ compiler('cxx') }} - {% if cuda_major == "11" %} - - {{ compiler('cuda11') }} ={{ cuda_version }} - {% else %} - - {{ compiler('cuda') }} - {% endif %} - - cuda-version ={{ cuda_version }} - - cmake {{ cmake_version }} - - ninja - - {{ stdlib("c") }} - - host: - - python - - libraft {{ version }} - - cuda-version ={{ cuda_version }} - {% if cuda_major == "11" %} - - cuda-profiler-api {{ cuda11_cuda_profiler_api_run_version }} - - libcublas {{ cuda11_libcublas_host_version }} - - libcublas-dev {{ cuda11_libcublas_host_version }} - {% else %} - - cuda-cudart-dev - - cuda-profiler-api - - libcublas-dev - {% endif %} - - glog {{ glog_version }} - - nlohmann_json {{ nlohmann_json_version }} - - h5py {{ h5py_version }} - - benchmark - - matplotlib - - python - - pandas - - pyyaml - # rmm is needed to determine if package is gpu-enabled - - rmm ={{ minor_version }} - - rapids-build-backend>=0.3.0,<0.4.0.dev0 - - run: - - python - - libraft {{ version }} - - {{ pin_compatible('cuda-version', max_pin='x', min_pin='x') }} - {% if cuda_major == "11" %} - - cudatoolkit - {% else %} - - cuda-cudart - - libcublas - {% endif %} - - glog {{ glog_version }} - - h5py {{ h5py_version }} - - benchmark - - glog {{ glog_version }} - - matplotlib - - python - - pandas - - pyyaml - # rmm is needed to determine if package is gpu-enabled - - rmm ={{ minor_version }} -about: - home: https://rapids.ai/ - license: Apache-2.0 - summary: RAFT ANN GPU and CPU benchmarks diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index d7eeb60b27..fcd6383cb6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -54,7 +54,6 @@ option(BUILD_SHARED_LIBS "Build raft shared libraries" ON) option(BUILD_TESTS "Build raft unit-tests" ON) option(BUILD_PRIMS_BENCH "Build raft C++ benchmark tests" OFF) option(BUILD_ANN_BENCH "Build raft ann benchmarks" OFF) -option(BUILD_CAGRA_HNSWLIB "Build CAGRA+hnswlib interface" ON) option(CUDA_ENABLE_KERNELINFO "Enable kernel resource usage info" OFF) option(CUDA_ENABLE_LINEINFO "Enable the -lineinfo option for nvcc (useful for cuda-memcheck / profiler)" OFF @@ -105,12 +104,8 @@ message(VERBOSE "RAFT: Disable OpenMP: ${DISABLE_OPENMP}") message(VERBOSE "RAFT: Enable kernel resource usage info: ${CUDA_ENABLE_KERNELINFO}") message(VERBOSE "RAFT: Enable lineinfo in nvcc: ${CUDA_ENABLE_LINEINFO}") message(VERBOSE "RAFT: Enable nvtx markers: ${RAFT_NVTX}") -message(VERBOSE - "RAFT: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTIME}" -) -message(VERBOSE - "RAFT: Statically link the CUDA math libraries: ${CUDA_STATIC_MATH_LIBRARIES}" -) +message(VERBOSE "RAFT: Statically link the CUDA runtime: ${CUDA_STATIC_RUNTIME}") +message(VERBOSE "RAFT: Statically link the CUDA math libraries: ${CUDA_STATIC_MATH_LIBRARIES}") # Set RMM logging level set(RMM_LOGGING_LEVEL @@ -198,10 +193,6 @@ if(BUILD_PRIMS_BENCH OR BUILD_ANN_BENCH) rapids_cpm_gbench(BUILD_STATIC) endif() -if(BUILD_CAGRA_HNSWLIB) - include(cmake/thirdparty/get_hnswlib.cmake) -endif() - # ################################################################################################## # * raft --------------------------------------------------------------------- add_library(raft INTERFACE) @@ -210,9 +201,6 @@ add_library(raft::raft ALIAS raft) target_include_directories( raft INTERFACE "$" "$" ) -if(BUILD_CAGRA_HNSWLIB) - target_link_libraries(raft INTERFACE hnswlib::hnswlib) -endif() if(NOT BUILD_CPU_ONLY) # Keep RAFT as lightweight as possible. Only CUDA libs and rmm should be used in global target. @@ -300,277 +288,11 @@ if(RAFT_COMPILE_LIBRARY) add_library( raft_objs OBJECT src/core/logger.cpp - src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_dice_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_dice_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu - src/distance/detail/pairwise_matrix/dispatch_rbf.cu - src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu - src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu - src/distance/distance.cu - src/distance/fused_l2_nn.cu - src/distance/fused_distance_nn.cu src/linalg/detail/coalesced_reduction.cu - src/matrix/detail/select_k_double_int64_t.cu - src/matrix/detail/select_k_double_uint32_t.cu - src/matrix/detail/select_k_float_int64_t.cu - src/matrix/detail/select_k_float_uint32_t.cu - src/matrix/detail/select_k_float_int32.cu - src/matrix/detail/select_k_half_int64_t.cu - src/matrix/detail/select_k_half_uint32_t.cu - src/neighbors/ball_cover.cu - src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu - src/neighbors/brute_force_knn_int64_t_float_int64_t.cu - src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu - src/neighbors/brute_force_knn_int_float_int.cu - src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu - src/neighbors/brute_force_knn_index_float.cu - src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu - src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_multi_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_int8_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_uint8_uint32_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_float_uint64_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim128_t8_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim256_t16_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim512_t32_8pq_4subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_2subd_half.cu - src/neighbors/detail/cagra/q_search_single_cta_half_uint64_dim1024_t32_8pq_4subd_half.cu - src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu - src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu - src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu - src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu - src/neighbors/detail/ivf_flat_search.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu - src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu - src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu - src/neighbors/detail/refine_host_float_float.cpp - src/neighbors/detail/refine_host_half_float.cpp - src/neighbors/detail/refine_host_int8_t_float.cpp - src/neighbors/detail/refine_host_uint8_t_float.cpp - src/neighbors/ivf_flat_build_float_int64_t.cu - src/neighbors/ivf_flat_build_int8_t_int64_t.cu - src/neighbors/ivf_flat_build_uint8_t_int64_t.cu - src/neighbors/ivf_flat_extend_float_int64_t.cu - src/neighbors/ivf_flat_extend_int8_t_int64_t.cu - src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu - src/neighbors/ivf_flat_search_float_int64_t.cu - src/neighbors/ivf_flat_search_int8_t_int64_t.cu - src/neighbors/ivf_flat_search_uint8_t_int64_t.cu - src/neighbors/ivfpq_build_float_int64_t.cu - src/neighbors/ivfpq_build_half_int64_t.cu - src/neighbors/ivfpq_build_int8_t_int64_t.cu - src/neighbors/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/ivfpq_extend_float_int64_t.cu - src/neighbors/ivfpq_extend_half_int64_t.cu - src/neighbors/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_half_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/refine_float_float.cu - src/neighbors/refine_half_float.cu - src/neighbors/refine_int8_t_float.cu - src/neighbors/refine_uint8_t_float.cu - src/raft_runtime/cluster/cluster_cost.cuh - src/raft_runtime/cluster/cluster_cost_double.cu - src/raft_runtime/cluster/cluster_cost_float.cu - src/raft_runtime/cluster/kmeans_fit_double.cu - src/raft_runtime/cluster/kmeans_fit_float.cu - src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu - src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu - src/raft_runtime/cluster/update_centroids.cuh - src/raft_runtime/cluster/update_centroids_double.cu - src/raft_runtime/cluster/update_centroids_float.cu - src/raft_runtime/distance/fused_distance_min_arg.cu - src/raft_runtime/distance/fused_l2_min_arg.cu - src/raft_runtime/distance/pairwise_distance.cu - src/raft_runtime/matrix/select_k_float_int64_t.cu - src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu - src/raft_runtime/neighbors/cagra_build.cu - src/raft_runtime/neighbors/cagra_search.cu - src/raft_runtime/neighbors/cagra_serialize.cu - src/raft_runtime/neighbors/eps_neighborhood.cu - $<$:src/raft_runtime/neighbors/hnsw.cpp> - src/raft_runtime/neighbors/ivf_flat_build.cu - src/raft_runtime/neighbors/ivf_flat_search.cu - src/raft_runtime/neighbors/ivf_flat_serialize.cu - src/raft_runtime/neighbors/ivfpq_build.cu - src/raft_runtime/neighbors/ivfpq_deserialize.cu - src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu - src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu - src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu - src/raft_runtime/neighbors/ivfpq_serialize.cu - src/raft_runtime/neighbors/refine_d_int64_t_float.cu - src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu - src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu - src/raft_runtime/neighbors/refine_h_int64_t_float.cu - src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu - src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu src/raft_runtime/random/rmat_rectangular_generator_int_double.cu src/raft_runtime/random/rmat_rectangular_generator_int_float.cu - src/spatial/knn/detail/ball_cover/registers_eps_pass_euclidean.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu - src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu - src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu - src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu - src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu - src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu ) set_target_properties( raft_objs @@ -847,10 +569,3 @@ endif() if(BUILD_PRIMS_BENCH) add_subdirectory(bench/prims/) endif() - -# ################################################################################################## -# * build ann benchmark executable ----------------------------------------------- - -if(BUILD_ANN_BENCH) - add_subdirectory(bench/ann/) -endif() diff --git a/cpp/bench/ann/CMakeLists.txt b/cpp/bench/ann/CMakeLists.txt deleted file mode 100644 index 35df378438..0000000000 --- a/cpp/bench/ann/CMakeLists.txt +++ /dev/null @@ -1,349 +0,0 @@ -# ============================================================================= -# Copyright (c) 2023-2024, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except -# in compliance with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software distributed under the License -# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express -# or implied. See the License for the specific language governing permissions and limitations under -# the License. -# ============================================================================= - -list(APPEND CMAKE_MODULE_PATH "${RAFT_SOURCE_DIR}") - -# ################################################################################################## -# * benchmark options ------------------------------------------------------------------------------ - -option(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT "Include faiss' brute-force knn algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT "Include faiss' ivf flat algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ "Include faiss' ivf pq algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT "Include faiss' cpu brute-force algorithm in benchmark" ON) - -option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT "Include faiss' cpu ivf flat algorithm in benchmark" - ON -) -option(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ "Include faiss' cpu ivf pq algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT "Include raft's ivf flat algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ "Include raft's ivf pq algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_CAGRA "Include raft's CAGRA in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE "Include raft's brute force knn in benchmark" ON) -option(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB "Include raft's CAGRA in benchmark" ON) -option(RAFT_ANN_BENCH_USE_HNSWLIB "Include hnsw algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_USE_GGNN "Include ggnn algorithm in benchmark" ON) -option(RAFT_ANN_BENCH_SINGLE_EXE - "Make a single executable with benchmark as shared library modules" OFF -) - -# ################################################################################################## -# * Process options ---------------------------------------------------------- - -find_package(Threads REQUIRED) - -set(RAFT_ANN_BENCH_USE_FAISS ON) -set(RAFT_FAISS_ENABLE_GPU ON) -set(RAFT_USE_FAISS_STATIC ON) - -if(BUILD_CPU_ONLY) - - # Include necessary logging dependencies - include(cmake/thirdparty/get_fmt) - include(cmake/thirdparty/get_spdlog) - set(RAFT_FAISS_ENABLE_GPU OFF) - set(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT OFF) - set(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ OFF) - set(RAFT_ANN_BENCH_USE_RAFT_CAGRA OFF) - set(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE OFF) - set(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB OFF) - set(RAFT_ANN_BENCH_USE_GGNN OFF) -endif() - -set(RAFT_ANN_BENCH_USE_RAFT OFF) -if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ - OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE - OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT - OR RAFT_ANN_BENCH_USE_RAFT_CAGRA - OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB -) - set(RAFT_ANN_BENCH_USE_RAFT ON) -endif() - -# ################################################################################################## -# * Fetch requirements ------------------------------------------------------------- - -if(RAFT_ANN_BENCH_USE_HNSWLIB OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) - include(cmake/thirdparty/get_hnswlib) -endif() - -include(cmake/thirdparty/get_nlohmann_json) - -if(RAFT_ANN_BENCH_USE_GGNN) - include(cmake/thirdparty/get_ggnn) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS) - include(cmake/thirdparty/get_faiss) -endif() - -# ################################################################################################## -# * Enable NVTX if available - -# Note: ANN_BENCH wrappers have extra NVTX code not related to raft::nvtx.They track gbench -# benchmark cases and iterations. This is to make limited NVTX available to all algos, not just -# raft. -if(TARGET CUDA::nvtx3) - set(_CMAKE_REQUIRED_INCLUDES_ORIG ${CMAKE_REQUIRED_INCLUDES}) - get_target_property(CMAKE_REQUIRED_INCLUDES CUDA::nvtx3 INTERFACE_INCLUDE_DIRECTORIES) - unset(NVTX3_HEADERS_FOUND CACHE) - # Check the headers explicitly to make sure the cpu-only build succeeds - CHECK_INCLUDE_FILE_CXX(nvtx3/nvToolsExt.h NVTX3_HEADERS_FOUND) - set(CMAKE_REQUIRED_INCLUDES ${_CMAKE_REQUIRED_INCLUDES_ORIG}) -endif() - -# ################################################################################################## -# * Configure tests function------------------------------------------------------------- - -function(ConfigureAnnBench) - - set(oneValueArgs NAME) - set(multiValueArgs PATH LINKS CXXFLAGS) - - if(NOT BUILD_CPU_ONLY) - set(GPU_BUILD ON) - endif() - - cmake_parse_arguments( - ConfigureAnnBench "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} - ) - - set(BENCH_NAME ${ConfigureAnnBench_NAME}_ANN_BENCH) - - if(RAFT_ANN_BENCH_SINGLE_EXE) - add_library(${BENCH_NAME} SHARED ${ConfigureAnnBench_PATH}) - string(TOLOWER ${BENCH_NAME} BENCH_LIB_NAME) - set_target_properties(${BENCH_NAME} PROPERTIES OUTPUT_NAME ${BENCH_LIB_NAME}) - add_dependencies(${BENCH_NAME} ANN_BENCH) - else() - add_executable(${BENCH_NAME} ${ConfigureAnnBench_PATH}) - target_compile_definitions( - ${BENCH_NAME} PRIVATE ANN_BENCH_BUILD_MAIN - $<$:ANN_BENCH_NVTX3_HEADERS_FOUND> - ) - target_link_libraries( - ${BENCH_NAME} PRIVATE benchmark::benchmark $<$:CUDA::nvtx3> - ) - endif() - - target_link_libraries( - ${BENCH_NAME} - PRIVATE raft::raft - nlohmann_json::nlohmann_json - ${ConfigureAnnBench_LINKS} - Threads::Threads - $<$:${RAFT_CTK_MATH_DEPENDENCIES}> - $ - $ - $<$:fmt::fmt-header-only> - $<$:spdlog::spdlog_header_only> - ) - - set_target_properties( - ${BENCH_NAME} - PROPERTIES # set target compile options - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CUDA_STANDARD 17 - CUDA_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON - INTERFACE_POSITION_INDEPENDENT_CODE ON - BUILD_RPATH "\$ORIGIN" - INSTALL_RPATH "\$ORIGIN" - ) - - set(${ConfigureAnnBench_CXXFLAGS} ${RAFT_CXX_FLAGS} ${ConfigureAnnBench_CXXFLAGS}) - - target_compile_options( - ${BENCH_NAME} PRIVATE "$<$:${ConfigureAnnBench_CXXFLAGS}>" - "$<$:${RAFT_CUDA_FLAGS}>" - ) - - if(RAFT_ANN_BENCH_USE_${ConfigureAnnBench_NAME}) - target_compile_definitions( - ${BENCH_NAME} - PUBLIC - RAFT_ANN_BENCH_USE_${ConfigureAnnBench_NAME}=RAFT_ANN_BENCH_USE_${ConfigureAnnBench_NAME} - ) - endif() - - target_include_directories( - ${BENCH_NAME} - PUBLIC "$" - PRIVATE ${ConfigureAnnBench_INCLUDES} - ) - - install( - TARGETS ${BENCH_NAME} - COMPONENT ann_bench - DESTINATION bin/ann - ) -endfunction() - -# ################################################################################################## -# * Configure tests------------------------------------------------------------- - -if(RAFT_ANN_BENCH_USE_HNSWLIB) - ConfigureAnnBench( - NAME HNSWLIB PATH src/hnswlib/hnswlib_benchmark.cpp LINKS hnswlib::hnswlib - ) - -endif() - -if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) - ConfigureAnnBench( - NAME - RAFT_IVF_PQ - PATH - src/raft/raft_benchmark.cu - src/raft/raft_ivf_pq.cu - LINKS - raft::compiled - ) -endif() - -if(RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT) - ConfigureAnnBench( - NAME - RAFT_IVF_FLAT - PATH - src/raft/raft_benchmark.cu - src/raft/raft_ivf_flat.cu - LINKS - raft::compiled - ) -endif() - -if(RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE) - ConfigureAnnBench( - NAME RAFT_BRUTE_FORCE PATH src/raft/raft_benchmark.cu LINKS raft::compiled - ) -endif() - -if(RAFT_ANN_BENCH_USE_RAFT_CAGRA) - ConfigureAnnBench( - NAME - RAFT_CAGRA - PATH - src/raft/raft_benchmark.cu - src/raft/raft_cagra_float.cu - src/raft/raft_cagra_half.cu - src/raft/raft_cagra_int8_t.cu - src/raft/raft_cagra_uint8_t.cu - LINKS - raft::compiled - ) -endif() - -if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) - ConfigureAnnBench( - NAME RAFT_CAGRA_HNSWLIB PATH src/raft/raft_cagra_hnswlib.cu LINKS raft::compiled - hnswlib::hnswlib - ) -endif() - -message("RAFT_FAISS_TARGETS: ${RAFT_FAISS_TARGETS}") -message("CUDAToolkit_LIBRARY_DIR: ${CUDAToolkit_LIBRARY_DIR}") -if(RAFT_ANN_BENCH_USE_FAISS_CPU_FLAT) - ConfigureAnnBench( - NAME FAISS_CPU_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS - ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_FLAT) - ConfigureAnnBench( - NAME FAISS_CPU_IVF_FLAT PATH src/faiss/faiss_cpu_benchmark.cpp LINKS - ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS_CPU_IVF_PQ) - ConfigureAnnBench( - NAME FAISS_CPU_IVF_PQ PATH src/faiss/faiss_cpu_benchmark.cpp LINKS - ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_FLAT AND RAFT_FAISS_ENABLE_GPU) - ConfigureAnnBench( - NAME FAISS_GPU_IVF_FLAT PATH src/faiss/faiss_gpu_benchmark.cu LINKS - ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS_GPU_IVF_PQ AND RAFT_FAISS_ENABLE_GPU) - ConfigureAnnBench( - NAME FAISS_GPU_IVF_PQ PATH src/faiss/faiss_gpu_benchmark.cu LINKS - ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_FAISS_GPU_FLAT AND RAFT_FAISS_ENABLE_GPU) - ConfigureAnnBench( - NAME FAISS_GPU_FLAT PATH src/faiss/faiss_gpu_benchmark.cu LINKS ${RAFT_FAISS_TARGETS} - ) -endif() - -if(RAFT_ANN_BENCH_USE_GGNN) - include(cmake/thirdparty/get_glog) - ConfigureAnnBench(NAME GGNN PATH src/ggnn/ggnn_benchmark.cu LINKS glog::glog ggnn::ggnn) -endif() - -# ################################################################################################## -# * Dynamically-loading ANN_BENCH executable ------------------------------------------------------- -if(RAFT_ANN_BENCH_SINGLE_EXE) - add_executable(ANN_BENCH src/common/benchmark.cpp) - - target_include_directories(ANN_BENCH PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) - - target_link_libraries( - ANN_BENCH - PRIVATE raft::raft - nlohmann_json::nlohmann_json - benchmark::benchmark - dl - -static-libgcc - fmt::fmt-header-only - spdlog::spdlog_header_only - -static-libstdc++ - $<$:CUDA::nvtx3> - ) - set_target_properties( - ANN_BENCH - PROPERTIES # set target compile options - CXX_STANDARD 17 - CXX_STANDARD_REQUIRED ON - CUDA_STANDARD 17 - CUDA_STANDARD_REQUIRED ON - POSITION_INDEPENDENT_CODE ON - INTERFACE_POSITION_INDEPENDENT_CODE ON - BUILD_RPATH "\$ORIGIN" - INSTALL_RPATH "\$ORIGIN" - ) - target_compile_definitions( - ANN_BENCH - PRIVATE - $<$:ANN_BENCH_LINK_CUDART="libcudart.so.${CUDAToolkit_VERSION_MAJOR}.${CUDAToolkit_VERSION_MINOR}.${CUDAToolkit_VERSION_PATCH}"> - $<$:ANN_BENCH_NVTX3_HEADERS_FOUND> - ) - - target_link_options(ANN_BENCH PRIVATE -export-dynamic) - - install( - TARGETS ANN_BENCH - COMPONENT ann_bench - DESTINATION bin/ann - EXCLUDE_FROM_ALL - ) -endif() diff --git a/cpp/bench/ann/README.md b/cpp/bench/ann/README.md deleted file mode 100644 index 1a8af2e448..0000000000 --- a/cpp/bench/ann/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# RAFT CUDA ANN Benchmarks - -Please see the [ANN Benchmarks](https://docs.rapids.ai/api/raft/stable/cuda_ann_benchmarks.html) section of the RAFT documentation for instructions on building and using the ANN benchmarks. \ No newline at end of file diff --git a/cpp/bench/ann/src/common/ann_types.hpp b/cpp/bench/ann/src/common/ann_types.hpp deleted file mode 100644 index b010063dee..0000000000 --- a/cpp/bench/ann/src/common/ann_types.hpp +++ /dev/null @@ -1,168 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "cuda_stub.hpp" // cudaStream_t - -#include -#include -#include -#include - -namespace raft::bench::ann { - -enum Objective { - THROUGHPUT, // See how many vectors we can push through - LATENCY // See how fast we can push a vector through -}; - -enum class MemoryType { - Host, - HostMmap, - Device, -}; - -enum class Metric { - kInnerProduct, - kEuclidean, -}; - -inline auto parse_metric(const std::string& metric_str) -> Metric -{ - if (metric_str == "inner_product") { - return raft::bench::ann::Metric::kInnerProduct; - } else if (metric_str == "euclidean") { - return raft::bench::ann::Metric::kEuclidean; - } else { - throw std::runtime_error("invalid metric: '" + metric_str + "'"); - } -} - -inline auto parse_memory_type(const std::string& memory_type) -> MemoryType -{ - if (memory_type == "host") { - return MemoryType::Host; - } else if (memory_type == "mmap") { - return MemoryType::HostMmap; - } else if (memory_type == "device") { - return MemoryType::Device; - } else { - throw std::runtime_error("invalid memory type: '" + memory_type + "'"); - } -} - -struct AlgoProperty { - MemoryType dataset_memory_type; - // neighbors/distances should have same memory type as queries - MemoryType query_memory_type; -}; - -class AnnBase { - public: - using index_type = size_t; - - inline AnnBase(Metric metric, int dim) : metric_(metric), dim_(dim) {} - virtual ~AnnBase() noexcept = default; - - protected: - Metric metric_; - int dim_; -}; - -/** - * The GPU-based algorithms, which do not perform CPU synchronization at the end of their build or - * search methods, must implement this interface. - * - * The `cuda_timer` / `cuda_lap` from `util.hpp` uses this stream to record GPU times with events - * and, if necessary, also synchronize (via events) between iterations. - * - * If the algo does not implement this interface, GPU timings are disabled. - */ -class AnnGPU { - public: - /** - * Return the main cuda stream for this algorithm. - * If any work is done in multiple streams, they should synchornize with the main stream at the - * end. - */ - [[nodiscard]] virtual auto get_sync_stream() const noexcept -> cudaStream_t = 0; - /** - * By default a GPU algorithm uses a fixed stream to order GPU operations. - * However, an algorithm may need to synchronize with the host at the end of its execution. - * In that case, also synchronizing with a benchmark event would put it at disadvantage. - * - * We can disable event sync by passing `false` here - * - ONLY IF THE ALGORITHM HAS PRODUCED ITS OUTPUT BY THE TIME IT SYNCHRONIZES WITH CPU. - */ - [[nodiscard]] virtual auto uses_stream() const noexcept -> bool { return true; } - virtual ~AnnGPU() noexcept = default; -}; - -template -class ANN : public AnnBase { - public: - struct AnnSearchParam { - Objective metric_objective = Objective::LATENCY; - virtual ~AnnSearchParam() = default; - [[nodiscard]] virtual auto needs_dataset() const -> bool { return false; }; - }; - - inline ANN(Metric metric, int dim) : AnnBase(metric, dim) {} - virtual ~ANN() noexcept override = default; - - virtual void build(const T* dataset, size_t nrow) = 0; - - virtual void set_search_param(const AnnSearchParam& param) = 0; - // TODO: this assumes that an algorithm can always return k results. - // This is not always possible. - virtual void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const = 0; - - virtual void save(const std::string& file) const = 0; - virtual void load(const std::string& file) = 0; - - virtual AlgoProperty get_preference() const = 0; - - // Some algorithms don't save the building dataset in their indices. - // So they should be given the access to that dataset during searching. - // The advantage of this way is that index has smaller size - // and many indices can share one dataset. - // - // SearchParam::needs_dataset() of such algorithm should be true, - // and set_search_dataset() should save the passed-in pointer somewhere. - // The client code should call set_search_dataset() before searching, - // and should not release dataset before searching is finished. - virtual void set_search_dataset(const T* /*dataset*/, size_t /*nrow*/){}; - - /** - * Make a shallow copy of the ANN wrapper that shares the resources and ensures thread-safe access - * to them. */ - virtual auto copy() -> std::unique_ptr> = 0; -}; - -} // namespace raft::bench::ann - -#define REGISTER_ALGO_INSTANCE(DataT) \ - template auto raft::bench::ann::create_algo( \ - const std::string&, const std::string&, int, const nlohmann::json&, const std::vector&) \ - ->std::unique_ptr>; \ - template auto raft::bench::ann::create_search_param(const std::string&, \ - const nlohmann::json&) \ - ->std::unique_ptr::AnnSearchParam>; diff --git a/cpp/bench/ann/src/common/benchmark.cpp b/cpp/bench/ann/src/common/benchmark.cpp deleted file mode 100644 index 5510abf42f..0000000000 --- a/cpp/bench/ann/src/common/benchmark.cpp +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// clang-format off -#include "cuda_stub.hpp" // must go first -// clang-format on - -#include "ann_types.hpp" - -#include -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include - -namespace raft::bench::ann { - -struct lib_handle { - void* handle{nullptr}; - explicit lib_handle(const std::string& name) - { - handle = dlopen(name.c_str(), RTLD_LAZY | RTLD_LOCAL); - if (handle == nullptr) { - auto error_msg = "Failed to load " + name; - auto err = dlerror(); - if (err != nullptr && err[0] != '\0') { error_msg += ": " + std::string(err); } - throw std::runtime_error(error_msg); - } - } - ~lib_handle() noexcept - { - if (handle != nullptr) { dlclose(handle); } - } -}; - -auto load_lib(const std::string& algo) -> void* -{ - static std::unordered_map libs{}; - auto found = libs.find(algo); - - if (found != libs.end()) { return found->second.handle; } - auto lib_name = "lib" + algo + "_ann_bench.so"; - return libs.emplace(algo, lib_name).first->second.handle; -} - -auto get_fun_name(void* addr) -> std::string -{ - Dl_info dl_info; - if (dladdr(addr, &dl_info) != 0) { - if (dl_info.dli_sname != nullptr && dl_info.dli_sname[0] != '\0') { - return std::string{dl_info.dli_sname}; - } - } - throw std::logic_error("Failed to find out name of the looked up function"); -} - -template -auto create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -> std::unique_ptr> -{ - static auto fname = get_fun_name(reinterpret_cast(&create_algo)); - auto handle = load_lib(algo); - auto fun_addr = dlsym(handle, fname.c_str()); - if (fun_addr == nullptr) { - throw std::runtime_error("Couldn't load the create_algo function (" + algo + ")"); - } - auto fun = reinterpret_cast)>(fun_addr); - return fun(algo, distance, dim, conf, dev_list); -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - static auto fname = get_fun_name(reinterpret_cast(&create_search_param)); - auto handle = load_lib(algo); - auto fun_addr = dlsym(handle, fname.c_str()); - if (fun_addr == nullptr) { - throw std::runtime_error("Couldn't load the create_search_param function (" + algo + ")"); - } - auto fun = reinterpret_cast)>(fun_addr); - return fun(algo, conf); -} - -}; // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#include "benchmark.hpp" - -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } diff --git a/cpp/bench/ann/src/common/benchmark.hpp b/cpp/bench/ann/src/common/benchmark.hpp deleted file mode 100644 index 185d54a0a3..0000000000 --- a/cpp/bench/ann/src/common/benchmark.hpp +++ /dev/null @@ -1,736 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "ann_types.hpp" -#include "conf.hpp" -#include "dataset.hpp" -#include "util.hpp" - -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -static inline std::unique_ptr current_algo{nullptr}; -static inline std::unique_ptr current_algo_props{nullptr}; - -using kv_series = std::vector>>; - -inline auto apply_overrides(const std::vector& configs, - const kv_series& overrides, - std::size_t override_idx = 0) -> std::vector -{ - std::vector results{}; - if (override_idx >= overrides.size()) { - auto n = configs.size(); - for (size_t i = 0; i < n; i++) { - auto c = configs[i]; - c["override_suffix"] = n > 1 ? "/" + std::to_string(i) : ""; - results.push_back(c); - } - return results; - } - auto rec_configs = apply_overrides(configs, overrides, override_idx + 1); - auto [key, vals] = overrides[override_idx]; - auto n = vals.size(); - for (size_t i = 0; i < n; i++) { - const auto& val = vals[i]; - for (auto rc : rec_configs) { - if (n > 1) { - rc["override_suffix"] = - static_cast(rc["override_suffix"]) + "/" + std::to_string(i); - } - rc[key] = val; - results.push_back(rc); - } - } - return results; -} - -inline auto apply_overrides(const nlohmann::json& config, - const kv_series& overrides, - std::size_t override_idx = 0) -{ - return apply_overrides(std::vector{config}, overrides, 0); -} - -inline void dump_parameters(::benchmark::State& state, nlohmann::json params) -{ - std::string label = ""; - bool label_empty = true; - for (auto& [key, val] : params.items()) { - if (val.is_number()) { - state.counters.insert({{key, val}}); - } else if (val.is_boolean()) { - state.counters.insert({{key, val ? 1.0 : 0.0}}); - } else { - auto kv = key + "=" + val.dump(); - if (label_empty) { - label = kv; - } else { - label += "#" + kv; - } - label_empty = false; - } - } - if (!label_empty) { state.SetLabel(label); } -} - -inline auto parse_algo_property(AlgoProperty prop, const nlohmann::json& conf) -> AlgoProperty -{ - if (conf.contains("dataset_memory_type")) { - prop.dataset_memory_type = parse_memory_type(conf.at("dataset_memory_type")); - } - if (conf.contains("query_memory_type")) { - prop.query_memory_type = parse_memory_type(conf.at("query_memory_type")); - } - return prop; -}; - -template -void bench_build(::benchmark::State& state, - std::shared_ptr> dataset, - Configuration::Index index, - bool force_overwrite) -{ - // NB: these two thread-local vars can be used within algo wrappers - raft::bench::ann::benchmark_thread_id = state.thread_index(); - raft::bench::ann::benchmark_n_threads = state.threads(); - dump_parameters(state, index.build_param); - if (file_exists(index.file)) { - if (force_overwrite) { - log_info("Overwriting file: %s", index.file.c_str()); - } else { - return state.SkipWithMessage( - "Index file already exists (use --force to overwrite the index)."); - } - } - - std::unique_ptr> algo; - try { - algo = ann::create_algo( - index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); - } catch (const std::exception& e) { - return state.SkipWithError("Failed to create an algo: " + std::string(e.what())); - } - - const auto algo_property = parse_algo_property(algo->get_preference(), index.build_param); - - const T* base_set = dataset->base_set(algo_property.dataset_memory_type); - std::size_t index_size = dataset->base_set_size(); - - cuda_timer gpu_timer{algo}; - { - nvtx_case nvtx{state.name()}; - for (auto _ : state) { - [[maybe_unused]] auto ntx_lap = nvtx.lap(); - [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); - try { - algo->build(base_set, index_size); - } catch (const std::exception& e) { - state.SkipWithError(std::string(e.what())); - } - } - } - if (gpu_timer.active()) { - state.counters.insert({"GPU", {gpu_timer.total_time(), benchmark::Counter::kAvgIterations}}); - } - state.counters.insert({{"index_size", index_size}}); - - if (state.skipped()) { return; } - make_sure_parent_dir_exists(index.file); - algo->save(index.file); -} - -template -void bench_search(::benchmark::State& state, - Configuration::Index index, - std::size_t search_param_ix, - std::shared_ptr> dataset, - Objective metric_objective) -{ - // NB: these two thread-local vars can be used within algo wrappers - raft::bench::ann::benchmark_thread_id = state.thread_index(); - raft::bench::ann::benchmark_n_threads = state.threads(); - std::size_t queries_processed = 0; - - const auto& sp_json = index.search_params[search_param_ix]; - - if (state.thread_index() == 0) { dump_parameters(state, sp_json); } - - // NB: `k` and `n_queries` are guaranteed to be populated in conf.cpp - const std::uint32_t k = sp_json["k"]; - // Amount of data processes in one go - const std::size_t n_queries = sp_json["n_queries"]; - // Round down the query data to a multiple of the batch size to loop over full batches of data - const std::size_t query_set_size = (dataset->query_set_size() / n_queries) * n_queries; - - if (dataset->query_set_size() < n_queries) { - std::stringstream msg; - msg << "Not enough queries in benchmark set. Expected " << n_queries << ", actual " - << dataset->query_set_size(); - state.SkipWithError(msg.str()); - return; - } - - // Each thread start from a different offset, so that the queries that they process do not - // overlap. - std::ptrdiff_t batch_offset = (state.thread_index() * n_queries) % query_set_size; - std::ptrdiff_t queries_stride = state.threads() * n_queries; - // Output is saved into a contiguous buffer (separate buffers for each thread). - std::ptrdiff_t out_offset = 0; - - const T* query_set = nullptr; - - if (!file_exists(index.file)) { - state.SkipWithError("Index file is missing. Run the benchmark in the build mode first."); - return; - } - - /** - * Make sure the first thread loads the algo and dataset - */ - progress_barrier load_barrier{}; - if (load_barrier.arrive(1) == 0) { - // algo is static to cache it between close search runs to save time on index loading - static std::string index_file = ""; - if (index.file != index_file) { - current_algo.reset(); - index_file = index.file; - } - - std::unique_ptr::AnnSearchParam> search_param; - ANN* algo; - try { - if (!current_algo || (algo = dynamic_cast*>(current_algo.get())) == nullptr) { - auto ualgo = ann::create_algo( - index.algo, dataset->distance(), dataset->dim(), index.build_param, index.dev_list); - algo = ualgo.get(); - algo->load(index_file); - current_algo = std::move(ualgo); - } - search_param = ann::create_search_param(index.algo, sp_json); - search_param->metric_objective = metric_objective; - } catch (const std::exception& e) { - state.SkipWithError("Failed to create an algo: " + std::string(e.what())); - return; - } - - current_algo_props = std::make_unique( - std::move(parse_algo_property(algo->get_preference(), sp_json))); - - if (search_param->needs_dataset()) { - try { - algo->set_search_dataset(dataset->base_set(current_algo_props->dataset_memory_type), - dataset->base_set_size()); - } catch (const std::exception& ex) { - state.SkipWithError("The algorithm '" + index.name + - "' requires the base set, but it's not available. " + - "Exception: " + std::string(ex.what())); - return; - } - } - try { - algo->set_search_param(*search_param); - } catch (const std::exception& ex) { - state.SkipWithError("An error occurred setting search parameters: " + std::string(ex.what())); - return; - } - - query_set = dataset->query_set(current_algo_props->query_memory_type); - load_barrier.arrive(state.threads()); - } else { - // All other threads will wait for the first thread to initialize the algo. - load_barrier.wait(state.threads() * 2); - // gbench ensures that all threads are synchronized at the start of the benchmark loop. - // We are accessing shared variables (like current_algo, current_algo_probs) before the - // benchmark loop, therefore the synchronization here is necessary. - } - query_set = dataset->query_set(current_algo_props->query_memory_type); - - /** - * Each thread will manage its own outputs - */ - using index_type = AnnBase::index_type; - constexpr size_t kAlignResultBuf = 64; - size_t result_elem_count = k * query_set_size; - result_elem_count = - ((result_elem_count + kAlignResultBuf - 1) / kAlignResultBuf) * kAlignResultBuf; - auto& result_buf = - get_result_buffer_from_global_pool(result_elem_count * (sizeof(float) + sizeof(index_type))); - auto* neighbors_ptr = - reinterpret_cast(result_buf.data(current_algo_props->query_memory_type)); - auto* distances_ptr = reinterpret_cast(neighbors_ptr + result_elem_count); - - { - nvtx_case nvtx{state.name()}; - - std::unique_ptr> algo{nullptr}; - try { - dynamic_cast*>(current_algo.get())->copy().swap(algo); - } catch (const std::exception& e) { - state.SkipWithError("Algo::copy: " + std::string(e.what())); - return; - } - // Initialize with algo, so that the timer.lap() object can sync with algo::get_sync_stream() - cuda_timer gpu_timer{algo}; - auto start = std::chrono::high_resolution_clock::now(); - for (auto _ : state) { - [[maybe_unused]] auto ntx_lap = nvtx.lap(); - [[maybe_unused]] auto gpu_lap = gpu_timer.lap(); - try { - algo->search(query_set + batch_offset * dataset->dim(), - n_queries, - k, - neighbors_ptr + out_offset * k, - distances_ptr + out_offset * k); - } catch (const std::exception& e) { - state.SkipWithError("Benchmark loop: " + std::string(e.what())); - break; - } - - // advance to the next batch - batch_offset = (batch_offset + queries_stride) % query_set_size; - out_offset = (out_offset + n_queries) % query_set_size; - - queries_processed += n_queries; - } - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast>(end - start).count(); - if (state.thread_index() == 0) { state.counters.insert({{"end_to_end", duration}}); } - state.counters.insert({"Latency", {duration, benchmark::Counter::kAvgIterations}}); - - if (gpu_timer.active()) { - state.counters.insert({"GPU", {gpu_timer.total_time(), benchmark::Counter::kAvgIterations}}); - } - } - - state.SetItemsProcessed(queries_processed); - - // This will be the total number of queries across all threads - state.counters.insert({{"total_queries", queries_processed}}); - - if (state.skipped()) { return; } - - // Each thread calculates recall on their partition of queries. - // evaluate recall - if (dataset->max_k() >= k) { - const std::int32_t* gt = dataset->gt_set(); - const std::uint32_t max_k = dataset->max_k(); - result_buf.transfer_data(MemoryType::Host, current_algo_props->query_memory_type); - auto* neighbors_host = reinterpret_cast(result_buf.data(MemoryType::Host)); - std::size_t rows = std::min(queries_processed, query_set_size); - std::size_t match_count = 0; - std::size_t total_count = rows * static_cast(k); - - // We go through the groundtruth with same stride as the benchmark loop. - size_t out_offset = 0; - size_t batch_offset = (state.thread_index() * n_queries) % query_set_size; - while (out_offset < rows) { - for (std::size_t i = 0; i < n_queries; i++) { - size_t i_orig_idx = batch_offset + i; - size_t i_out_idx = out_offset + i; - if (i_out_idx < rows) { - for (std::uint32_t j = 0; j < k; j++) { - auto act_idx = std::int32_t(neighbors_host[i_out_idx * k + j]); - for (std::uint32_t l = 0; l < k; l++) { - auto exp_idx = gt[i_orig_idx * max_k + l]; - if (act_idx == exp_idx) { - match_count++; - break; - } - } - } - } - } - out_offset += n_queries; - batch_offset = (batch_offset + queries_stride) % query_set_size; - } - double actual_recall = static_cast(match_count) / static_cast(total_count); - state.counters.insert({"Recall", {actual_recall, benchmark::Counter::kAvgThreads}}); - } -} - -inline void printf_usage() -{ - ::benchmark::PrintDefaultHelp(); - fprintf(stdout, - " [--build|--search] \n" - " [--force]\n" - " [--data_prefix=]\n" - " [--index_prefix=]\n" - " [--override_kv=]\n" - " [--mode=\n" - " [--threads=min[:max]]\n" - " .json\n" - "\n" - "Note the non-standard benchmark parameters:\n" - " --build: build mode, will build index\n" - " --search: search mode, will search using the built index\n" - " one and only one of --build and --search should be specified\n" - " --force: force overwriting existing index files\n" - " --data_prefix=:" - " prepend to dataset file paths specified in the .json (default = " - "'data/').\n" - " --index_prefix=:" - " prepend to index file paths specified in the .json (default = " - "'index/').\n" - " --override_kv=:" - " override a build/search key one or more times multiplying the number of configurations;" - " you can use this parameter multiple times to get the Cartesian product of benchmark" - " configs.\n" - " --mode=" - " run the benchmarks in latency (accumulate times spent in each batch) or " - " throughput (pipeline batches and measure end-to-end) mode\n" - " --threads=min[:max] specify the number threads to use for throughput benchmark." - " Power of 2 values between 'min' and 'max' will be used. If only 'min' is specified," - " then a single test is run with 'min' threads. By default min=1, max=.\n"); -} - -template -void register_build(std::shared_ptr> dataset, - std::vector indices, - bool force_overwrite) -{ - for (auto index : indices) { - auto suf = static_cast(index.build_param["override_suffix"]); - auto file_suf = suf; - index.build_param.erase("override_suffix"); - std::replace(file_suf.begin(), file_suf.end(), '/', '-'); - index.file += file_suf; - auto* b = ::benchmark::RegisterBenchmark( - index.name + suf, bench_build, dataset, index, force_overwrite); - b->Unit(benchmark::kSecond); - b->MeasureProcessCPUTime(); - b->UseRealTime(); - } -} - -template -void register_search(std::shared_ptr> dataset, - std::vector indices, - Objective metric_objective, - const std::vector& threads) -{ - for (auto index : indices) { - for (std::size_t i = 0; i < index.search_params.size(); i++) { - auto suf = static_cast(index.search_params[i]["override_suffix"]); - index.search_params[i].erase("override_suffix"); - - auto* b = ::benchmark::RegisterBenchmark( - index.name + suf, bench_search, index, i, dataset, metric_objective) - ->Unit(benchmark::kMillisecond) - /** - * The following are important for getting accuracy QPS measurements on both CPU - * and GPU These make sure that - * - `end_to_end` ~ (`Time` * `Iterations`) - * - `items_per_second` ~ (`total_queries` / `end_to_end`) - * - Throughput = `items_per_second` - */ - ->MeasureProcessCPUTime() - ->UseRealTime(); - if (metric_objective == Objective::THROUGHPUT) { - if (index.algo.find("faiss_gpu") != std::string::npos) { - log_warn( - "FAISS GPU does not work in throughput mode because the underlying " - "StandardGpuResources object is not thread-safe. This will cause unexpected results"); - } - b->ThreadRange(threads[0], threads[1]); - } - } - } -} - -template -void dispatch_benchmark(const Configuration& conf, - bool force_overwrite, - bool build_mode, - bool search_mode, - std::string data_prefix, - std::string index_prefix, - kv_series override_kv, - Objective metric_objective, - const std::vector& threads) -{ - if (cudart.found()) { - for (auto [key, value] : cuda_info()) { - ::benchmark::AddCustomContext(key, value); - } - } - const auto dataset_conf = conf.get_dataset_conf(); - auto base_file = combine_path(data_prefix, dataset_conf.base_file); - auto query_file = combine_path(data_prefix, dataset_conf.query_file); - auto gt_file = dataset_conf.groundtruth_neighbors_file; - if (gt_file.has_value()) { gt_file.emplace(combine_path(data_prefix, gt_file.value())); } - auto dataset = std::make_shared>(dataset_conf.name, - base_file, - dataset_conf.subset_first_row, - dataset_conf.subset_size, - query_file, - dataset_conf.distance, - gt_file); - ::benchmark::AddCustomContext("dataset", dataset_conf.name); - ::benchmark::AddCustomContext("distance", dataset_conf.distance); - std::vector indices = conf.get_indices(); - if (build_mode) { - if (file_exists(base_file)) { - log_info("Using the dataset file '%s'", base_file.c_str()); - ::benchmark::AddCustomContext("n_records", std::to_string(dataset->base_set_size())); - ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); - } else { - log_warn("Dataset file '%s' does not exist; benchmarking index building is impossible.", - base_file.c_str()); - } - std::vector more_indices{}; - for (auto& index : indices) { - for (auto param : apply_overrides(index.build_param, override_kv)) { - auto modified_index = index; - modified_index.build_param = param; - modified_index.file = combine_path(index_prefix, modified_index.file); - more_indices.push_back(modified_index); - } - } - register_build(dataset, more_indices, force_overwrite); - } else if (search_mode) { - if (file_exists(query_file)) { - log_info("Using the query file '%s'", query_file.c_str()); - ::benchmark::AddCustomContext("max_n_queries", std::to_string(dataset->query_set_size())); - ::benchmark::AddCustomContext("dim", std::to_string(dataset->dim())); - if (gt_file.has_value()) { - if (file_exists(*gt_file)) { - log_info("Using the ground truth file '%s'", gt_file->c_str()); - ::benchmark::AddCustomContext("max_k", std::to_string(dataset->max_k())); - } else { - log_warn("Ground truth file '%s' does not exist; the recall won't be reported.", - gt_file->c_str()); - } - } else { - log_warn( - "Ground truth file is not provided; the recall won't be reported. NB: use " - "the 'groundtruth_neighbors_file' alongside the 'query_file' key to specify the " - "path to " - "the ground truth in your conf.json."); - } - } else { - log_warn("Query file '%s' does not exist; benchmarking search is impossible.", - query_file.c_str()); - } - for (auto& index : indices) { - index.search_params = apply_overrides(index.search_params, override_kv); - index.file = combine_path(index_prefix, index.file); - } - register_search(dataset, indices, metric_objective, threads); - } -} - -inline auto parse_bool_flag(const char* arg, const char* pat, bool& result) -> bool -{ - if (strcmp(arg, pat) == 0) { - result = true; - return true; - } - return false; -} - -inline auto parse_string_flag(const char* arg, const char* pat, std::string& result) -> bool -{ - auto n = strlen(pat); - if (strncmp(pat, arg, strlen(pat)) == 0) { - result = arg + n + 1; - return true; - } - return false; -} - -inline auto run_main(int argc, char** argv) -> int -{ - bool force_overwrite = false; - bool build_mode = false; - bool search_mode = false; - std::string data_prefix = "data"; - std::string index_prefix = "index"; - std::string new_override_kv = ""; - std::string mode = "latency"; - std::string threads_arg_txt = ""; - std::vector threads = {1, -1}; // min_thread, max_thread - std::string log_level_str = ""; - int raft_log_level = raft::logger::get(RAFT_NAME).get_level(); - kv_series override_kv{}; - - char arg0_default[] = "benchmark"; // NOLINT - char* args_default = arg0_default; - if (!argv) { - argc = 1; - argv = &args_default; - } - if (argc == 1) { - printf_usage(); - return -1; - } - - char* conf_path = argv[--argc]; - std::ifstream conf_stream(conf_path); - - for (int i = 1; i < argc; i++) { - if (parse_bool_flag(argv[i], "--force", force_overwrite) || - parse_bool_flag(argv[i], "--build", build_mode) || - parse_bool_flag(argv[i], "--search", search_mode) || - parse_string_flag(argv[i], "--data_prefix", data_prefix) || - parse_string_flag(argv[i], "--index_prefix", index_prefix) || - parse_string_flag(argv[i], "--mode", mode) || - parse_string_flag(argv[i], "--override_kv", new_override_kv) || - parse_string_flag(argv[i], "--threads", threads_arg_txt) || - parse_string_flag(argv[i], "--raft_log_level", log_level_str)) { - if (!log_level_str.empty()) { - raft_log_level = std::stoi(log_level_str); - log_level_str = ""; - } - if (!threads_arg_txt.empty()) { - auto threads_arg = split(threads_arg_txt, ':'); - threads[0] = std::stoi(threads_arg[0]); - if (threads_arg.size() > 1) { - threads[1] = std::stoi(threads_arg[1]); - } else { - threads[1] = threads[0]; - } - threads_arg_txt = ""; - } - if (!new_override_kv.empty()) { - auto kvv = split(new_override_kv, ':'); - auto key = kvv[0]; - std::vector vals{}; - for (std::size_t j = 1; j < kvv.size(); j++) { - vals.push_back(nlohmann::json::parse(kvv[j])); - } - override_kv.emplace_back(key, vals); - new_override_kv = ""; - } - for (int j = i; j < argc - 1; j++) { - argv[j] = argv[j + 1]; - } - argc--; - i--; - } - } - - raft::logger::get(RAFT_NAME).set_level(raft_log_level); - - Objective metric_objective = Objective::LATENCY; - if (mode == "throughput") { metric_objective = Objective::THROUGHPUT; } - - int max_threads = - (metric_objective == Objective::THROUGHPUT) ? std::thread::hardware_concurrency() : 1; - if (threads[1] == -1) threads[1] = max_threads; - - if (metric_objective == Objective::LATENCY) { - if (threads[0] != 1 || threads[1] != 1) { - log_warn("Latency mode enabled. Overriding threads arg, running with single thread."); - threads = {1, 1}; - } - } - - if (build_mode == search_mode) { - log_error("One and only one of --build and --search should be specified"); - printf_usage(); - return -1; - } - - if (!conf_stream) { - log_error("Can't open configuration file: %s", conf_path); - return -1; - } - - if (cudart.needed() && !cudart.found()) { - log_warn("cudart library is not found, GPU-based indices won't work."); - } - - Configuration conf(conf_stream); - std::string dtype = conf.get_dataset_conf().dtype; - - if (dtype == "float") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective, - threads); - } else if (dtype == "half") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective, - threads); - } else if (dtype == "uint8") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective, - threads); - } else if (dtype == "int8") { - dispatch_benchmark(conf, - force_overwrite, - build_mode, - search_mode, - data_prefix, - index_prefix, - override_kv, - metric_objective, - threads); - } else { - log_error("datatype '%s' is not supported", dtype.c_str()); - return -1; - } - - ::benchmark::Initialize(&argc, argv, printf_usage); - if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return -1; - ::benchmark::RunSpecifiedBenchmarks(); - ::benchmark::Shutdown(); - // Release a possibly cached ANN object, so that it cannot be alive longer than the handle - // to a shared library it depends on (dynamic benchmark executable). - current_algo.reset(); - current_algo_props.reset(); - reset_global_device_resources(); - return 0; -} -}; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/conf.hpp b/cpp/bench/ann/src/common/conf.hpp deleted file mode 100644 index 92ba86c6cf..0000000000 --- a/cpp/bench/ann/src/common/conf.hpp +++ /dev/null @@ -1,161 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "util.hpp" - -#include -#include -#include -#include -#include - -#define JSON_DIAGNOSTICS 1 -#include - -namespace raft::bench::ann { - -class Configuration { - public: - struct Index { - std::string name; - std::string algo; - nlohmann::json build_param; - std::string file; - std::vector dev_list; - - int batch_size; - int k; - std::vector search_params; - }; - - struct DatasetConf { - std::string name; - std::string base_file; - // use only a subset of base_file, - // the range of rows is [subset_first_row, subset_first_row + subset_size) - // however, subset_size = 0 means using all rows after subset_first_row - // that is, the subset is [subset_first_row, #rows in base_file) - size_t subset_first_row{0}; - size_t subset_size{0}; - std::string query_file; - std::string distance; - std::optional groundtruth_neighbors_file{std::nullopt}; - - // data type of input dataset, possible values ["float", "int8", "uint8"] - std::string dtype; - }; - - explicit inline Configuration(std::istream& conf_stream) - { - // to enable comments in json - auto conf = nlohmann::json::parse(conf_stream, nullptr, true, true); - - parse_dataset_(conf.at("dataset")); - parse_index_(conf.at("index"), conf.at("search_basic_param")); - } - - [[nodiscard]] inline auto get_dataset_conf() const -> DatasetConf { return dataset_conf_; } - [[nodiscard]] inline auto get_indices() const -> std::vector { return indices_; }; - - private: - inline void parse_dataset_(const nlohmann::json& conf) - { - dataset_conf_.name = conf.at("name"); - dataset_conf_.base_file = conf.at("base_file"); - dataset_conf_.query_file = conf.at("query_file"); - dataset_conf_.distance = conf.at("distance"); - - if (conf.contains("groundtruth_neighbors_file")) { - dataset_conf_.groundtruth_neighbors_file = conf.at("groundtruth_neighbors_file"); - } - if (conf.contains("subset_first_row")) { - dataset_conf_.subset_first_row = conf.at("subset_first_row"); - } - if (conf.contains("subset_size")) { dataset_conf_.subset_size = conf.at("subset_size"); } - - if (conf.contains("dtype")) { - dataset_conf_.dtype = conf.at("dtype"); - } else { - auto filename = dataset_conf_.base_file; - if (filename.size() > 6 && filename.compare(filename.size() - 6, 6, "f16bin") == 0) { - dataset_conf_.dtype = "half"; - } else if (filename.size() > 9 && - filename.compare(filename.size() - 9, 9, "fp16.fbin") == 0) { - dataset_conf_.dtype = "half"; - } else if (filename.size() > 4 && filename.compare(filename.size() - 4, 4, "fbin") == 0) { - dataset_conf_.dtype = "float"; - } else if (filename.size() > 5 && filename.compare(filename.size() - 5, 5, "u8bin") == 0) { - dataset_conf_.dtype = "uint8"; - } else if (filename.size() > 5 && filename.compare(filename.size() - 5, 5, "i8bin") == 0) { - dataset_conf_.dtype = "int8"; - } else { - log_error("Could not determine data type of the dataset %s", filename.c_str()); - } - } - } - inline void parse_index_(const nlohmann::json& index_conf, - const nlohmann::json& search_basic_conf) - { - const int batch_size = search_basic_conf.at("batch_size"); - const int k = search_basic_conf.at("k"); - - for (const auto& conf : index_conf) { - Index index; - index.name = conf.at("name"); - index.algo = conf.at("algo"); - index.build_param = conf.at("build_param"); - index.file = conf.at("file"); - index.batch_size = batch_size; - index.k = k; - - if (conf.contains("multigpu")) { - for (auto it : conf.at("multigpu")) { - index.dev_list.push_back(it); - } - if (index.dev_list.empty()) { throw std::runtime_error("dev_list shouln't be empty!"); } - index.dev_list.shrink_to_fit(); - index.build_param["multigpu"] = conf["multigpu"]; - } - - for (auto param : conf.at("search_params")) { - /* ### Special parameters for backward compatibility ### - - - Local values of `k` and `n_queries` take priority. - - The legacy "batch_size" renamed to `n_queries`. - - Basic search params are used otherwise. - */ - if (!param.contains("k")) { param["k"] = k; } - if (!param.contains("n_queries")) { - if (param.contains("batch_size")) { - param["n_queries"] = param["batch_size"]; - param.erase("batch_size"); - } else { - param["n_queries"] = batch_size; - } - } - index.search_params.push_back(param); - } - - indices_.push_back(index); - } - } - - DatasetConf dataset_conf_; - std::vector indices_; -}; - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp b/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp deleted file mode 100644 index 27be26dfe9..0000000000 --- a/cpp/bench/ann/src/common/cuda_huge_page_resource.hpp +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include -#include - -#include - -#include -#include - -namespace raft::mr { -/** - * @brief `device_memory_resource` derived class that uses mmap to allocate memory. - * This class enables memory allocation using huge pages. - * It is assumed that the allocated memory is directly accessible on device. This currently only - * works on GH systems. - * - * TODO(tfeher): consider improving or removing this helper once we made progress with - * https://github.com/rapidsai/raft/issues/1819 - */ -class cuda_huge_page_resource final : public rmm::mr::device_memory_resource { - public: - cuda_huge_page_resource() = default; - ~cuda_huge_page_resource() override = default; - cuda_huge_page_resource(cuda_huge_page_resource const&) = default; - cuda_huge_page_resource(cuda_huge_page_resource&&) = default; - cuda_huge_page_resource& operator=(cuda_huge_page_resource const&) = default; - cuda_huge_page_resource& operator=(cuda_huge_page_resource&&) = default; - - private: - /** - * @brief Allocates memory of size at least `bytes` using cudaMalloc. - * - * The returned pointer has at least 256B alignment. - * - * @note Stream argument is ignored - * - * @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled - * - * @param bytes The size, in bytes, of the allocation - * @return void* Pointer to the newly allocated memory - */ - void* do_allocate(std::size_t bytes, rmm::cuda_stream_view) override - { - void* _addr{nullptr}; - _addr = mmap(NULL, bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS, -1, 0); - if (_addr == MAP_FAILED) { RAFT_FAIL("huge_page_resource::MAP FAILED"); } - if (madvise(_addr, bytes, MADV_HUGEPAGE) == -1) { - munmap(_addr, bytes); - RAFT_FAIL("huge_page_resource::madvise MADV_HUGEPAGE"); - } - memset(_addr, 0, bytes); - return _addr; - } - - /** - * @brief Deallocate memory pointed to by \p p. - * - * @note Stream argument is ignored. - * - * @throws Nothing. - * - * @param p Pointer to be deallocated - */ - void do_deallocate(void* ptr, std::size_t size, rmm::cuda_stream_view) override - { - if (munmap(ptr, size) == -1) { RAFT_FAIL("huge_page_resource::munmap"); } - } - - /** - * @brief Compare this resource to another. - * - * Two cuda_huge_page_resources always compare equal, because they can each - * deallocate memory allocated by the other. - * - * @throws Nothing. - * - * @param other The other resource to compare to - * @return true If the two resources are equivalent - * @return false If the two resources are not equal - */ - [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override - { - return dynamic_cast(&other) != nullptr; - } -}; -} // namespace raft::mr diff --git a/cpp/bench/ann/src/common/cuda_pinned_resource.hpp b/cpp/bench/ann/src/common/cuda_pinned_resource.hpp deleted file mode 100644 index 3256fc293c..0000000000 --- a/cpp/bench/ann/src/common/cuda_pinned_resource.hpp +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -#include - -namespace raft::mr { -/** - * @brief `device_memory_resource` derived class that uses cudaMallocHost/Free for - * allocation/deallocation. - * - * This is almost the same as rmm::mr::host::pinned_memory_resource, but it has - * device_memory_resource as base class. Pinned memory can be accessed from device, - * and using this allocator we can create device_mdarray backed by pinned allocator. - * - * TODO(tfeher): it would be preferred to just rely on the existing allocator from rmm - * (pinned_memory_resource), but that is incompatible with the container_policy class - * for device matrix, because the latter expects a device_memory_resource. We shall - * revise this once we progress with Issue https://github.com/rapidsai/raft/issues/1819 - */ -class cuda_pinned_resource final : public rmm::mr::device_memory_resource { - public: - cuda_pinned_resource() = default; - ~cuda_pinned_resource() override = default; - cuda_pinned_resource(cuda_pinned_resource const&) = default; - cuda_pinned_resource(cuda_pinned_resource&&) = default; - cuda_pinned_resource& operator=(cuda_pinned_resource const&) = default; - cuda_pinned_resource& operator=(cuda_pinned_resource&&) = default; - - private: - /** - * @brief Allocates memory of size at least `bytes` using cudaMalloc. - * - * The returned pointer has at least 256B alignment. - * - * @note Stream argument is ignored - * - * @throws `rmm::bad_alloc` if the requested allocation could not be fulfilled - * - * @param bytes The size, in bytes, of the allocation - * @return void* Pointer to the newly allocated memory - */ - void* do_allocate(std::size_t bytes, rmm::cuda_stream_view) override - { - void* ptr{nullptr}; - RMM_CUDA_TRY_ALLOC(cudaMallocHost(&ptr, bytes)); - return ptr; - } - - /** - * @brief Deallocate memory pointed to by \p p. - * - * @note Stream argument is ignored. - * - * @throws Nothing. - * - * @param p Pointer to be deallocated - */ - void do_deallocate(void* ptr, std::size_t, rmm::cuda_stream_view) override - { - RMM_ASSERT_CUDA_SUCCESS(cudaFreeHost(ptr)); - } - - /** - * @brief Compare this resource to another. - * - * Two cuda_pinned_resources always compare equal, because they can each - * deallocate memory allocated by the other. - * - * @throws Nothing. - * - * @param other The other resource to compare to - * @return true If the two resources are equivalent - * @return false If the two resources are not equal - */ - [[nodiscard]] bool do_is_equal(device_memory_resource const& other) const noexcept override - { - return dynamic_cast(&other) != nullptr; - } -}; -} // namespace raft::mr diff --git a/cpp/bench/ann/src/common/cuda_stub.hpp b/cpp/bench/ann/src/common/cuda_stub.hpp deleted file mode 100644 index 5ed138a86d..0000000000 --- a/cpp/bench/ann/src/common/cuda_stub.hpp +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* -The content of this header is governed by two preprocessor definitions: - - - BUILD_CPU_ONLY - whether none of the CUDA functions are used. - - ANN_BENCH_LINK_CUDART - dynamically link against this string if defined. - -___________________________________________________________________________________ -|BUILD_CPU_ONLY | ANN_BENCH_LINK_CUDART | cudart | cuda_runtime_api.h | -| | | found | needed | included | -|---------------|-----------------------|-----------|---------|--------------------| -| ON | | false | false | NO | -| ON | "cudart.so.xx.xx" | false | false | NO | -| OFF | | true | true | YES | -| OFF | "cudart.so.xx.xx" | | true | YES | ------------------------------------------------------------------------------------- -*/ - -#pragma once - -#ifndef BUILD_CPU_ONLY -#include -#include -#ifdef ANN_BENCH_LINK_CUDART -#include - -#include -#endif -#else -#include - -typedef void* cudaStream_t; -typedef void* cudaEvent_t; -typedef uint16_t half; -#endif - -namespace raft::bench::ann { - -struct cuda_lib_handle { - void* handle{nullptr}; - explicit cuda_lib_handle() - { -#ifdef ANN_BENCH_LINK_CUDART - constexpr int kFlags = RTLD_NOW | RTLD_GLOBAL | RTLD_DEEPBIND | RTLD_NODELETE; - // The full name of the linked cudart library 'cudart.so.MAJOR.MINOR.PATCH' - char libname[] = ANN_BENCH_LINK_CUDART; // NOLINT - handle = dlopen(ANN_BENCH_LINK_CUDART, kFlags); - if (handle != nullptr) { return; } - // try strip the PATCH - auto p = strrchr(libname, '.'); - p[0] = 0; - handle = dlopen(libname, kFlags); - if (handle != nullptr) { return; } - // try set the MINOR version to 0 - p = strrchr(libname, '.'); - p[1] = '0'; - p[2] = 0; - handle = dlopen(libname, kFlags); - if (handle != nullptr) { return; } - // try strip the MINOR - p[0] = 0; - handle = dlopen(libname, kFlags); - if (handle != nullptr) { return; } - // try strip the MAJOR - p = strrchr(libname, '.'); - p[0] = 0; - handle = dlopen(libname, kFlags); -#endif - } - ~cuda_lib_handle() noexcept - { -#ifdef ANN_BENCH_LINK_CUDART - if (handle != nullptr) { dlclose(handle); } -#endif - } - - template - auto sym(const char* name) -> Symbol - { -#ifdef ANN_BENCH_LINK_CUDART - return reinterpret_cast(dlsym(handle, name)); -#else - return nullptr; -#endif - } - - /** Whether this is NOT a cpu-only package. */ - [[nodiscard]] constexpr inline auto needed() const -> bool - { -#if defined(BUILD_CPU_ONLY) - return false; -#else - return true; -#endif - } - - /** CUDA found, either at compile time or at runtime. */ - [[nodiscard]] inline auto found() const -> bool - { -#if defined(BUILD_CPU_ONLY) - return false; -#elif defined(ANN_BENCH_LINK_CUDART) - return handle != nullptr; -#else - return true; -#endif - } -}; - -static inline cuda_lib_handle cudart{}; - -#ifdef ANN_BENCH_LINK_CUDART -namespace stub { - -[[gnu::weak, gnu::noinline]] cudaError_t cudaMemcpy(void* dst, - const void* src, - size_t count, - enum cudaMemcpyKind kind) -{ - return cudaSuccess; -} - -[[gnu::weak, gnu::noinline]] cudaError_t cudaMalloc(void** ptr, size_t size) -{ - *ptr = nullptr; - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaMemset(void* devPtr, int value, size_t count) -{ - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaFree(void* devPtr) { return cudaSuccess; } -[[gnu::weak, gnu::noinline]] cudaError_t cudaStreamCreate(cudaStream_t* pStream) -{ - *pStream = 0; - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaStreamCreateWithFlags(cudaStream_t* pStream, - unsigned int flags) -{ - *pStream = 0; - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaStreamDestroy(cudaStream_t pStream) -{ - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaDeviceSynchronize() { return cudaSuccess; } - -[[gnu::weak, gnu::noinline]] cudaError_t cudaStreamSynchronize(cudaStream_t pStream) -{ - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaEventCreate(cudaEvent_t* event) -{ - *event = 0; - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaEventRecord(cudaEvent_t event, cudaStream_t stream) -{ - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaEventSynchronize(cudaEvent_t event) -{ - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaEventElapsedTime(float* ms, - cudaEvent_t start, - cudaEvent_t end) -{ - *ms = 0; - return cudaSuccess; -} -[[gnu::weak, gnu::noinline]] cudaError_t cudaEventDestroy(cudaEvent_t event) { return cudaSuccess; } -[[gnu::weak, gnu::noinline]] cudaError_t cudaGetDevice(int* device) -{ - *device = 0; - return cudaSuccess; -}; -[[gnu::weak, gnu::noinline]] cudaError_t cudaDriverGetVersion(int* driver) -{ - *driver = 0; - return cudaSuccess; -}; -[[gnu::weak, gnu::noinline]] cudaError_t cudaRuntimeGetVersion(int* runtime) -{ - *runtime = 0; - return cudaSuccess; -}; -[[gnu::weak, gnu::noinline]] cudaError_t cudaGetDeviceProperties(struct cudaDeviceProp* prop, - int device) -{ - *prop = cudaDeviceProp{}; - return cudaSuccess; -} - -} // namespace stub - -#define RAFT_DECLARE_CUDART(fun) \ - static inline decltype(&stub::fun) fun = \ - cudart.found() ? cudart.sym(#fun) : &stub::fun - -RAFT_DECLARE_CUDART(cudaMemcpy); -RAFT_DECLARE_CUDART(cudaMalloc); -RAFT_DECLARE_CUDART(cudaMemset); -RAFT_DECLARE_CUDART(cudaFree); -RAFT_DECLARE_CUDART(cudaStreamCreate); -RAFT_DECLARE_CUDART(cudaStreamCreateWithFlags); -RAFT_DECLARE_CUDART(cudaStreamDestroy); -RAFT_DECLARE_CUDART(cudaDeviceSynchronize); -RAFT_DECLARE_CUDART(cudaStreamSynchronize); -RAFT_DECLARE_CUDART(cudaEventCreate); -RAFT_DECLARE_CUDART(cudaEventRecord); -RAFT_DECLARE_CUDART(cudaEventSynchronize); -RAFT_DECLARE_CUDART(cudaEventElapsedTime); -RAFT_DECLARE_CUDART(cudaEventDestroy); -RAFT_DECLARE_CUDART(cudaGetDevice); -RAFT_DECLARE_CUDART(cudaDriverGetVersion); -RAFT_DECLARE_CUDART(cudaRuntimeGetVersion); -RAFT_DECLARE_CUDART(cudaGetDeviceProperties); - -#undef RAFT_DECLARE_CUDART -#endif - -}; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/dataset.hpp b/cpp/bench/ann/src/common/dataset.hpp deleted file mode 100644 index 8fcff77d3c..0000000000 --- a/cpp/bench/ann/src/common/dataset.hpp +++ /dev/null @@ -1,495 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "util.hpp" - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -// http://big-ann-benchmarks.com/index.html: -// binary format that starts with 8 bytes of data consisting of num_points(uint32_t) -// num_dimensions(uint32) followed by num_pts x num_dimensions x sizeof(type) bytes of -// data stored one vector after another. -// Data files will have suffixes .fbin, .u8bin, and .i8bin to represent float32, uint8 -// and int8 type data. -// As extensions for this benchmark, half and int data files will have suffixes .f16bin -// and .ibin, respectively. -template -class BinFile { - public: - BinFile(const std::string& file, - const std::string& mode, - uint32_t subset_first_row = 0, - uint32_t subset_size = 0); - ~BinFile() - { - if (mapped_ptr_ != nullptr) { unmap(); } - if (fp_ != nullptr) { fclose(fp_); } - } - BinFile(const BinFile&) = delete; - BinFile& operator=(const BinFile&) = delete; - - void get_shape(size_t* nrows, int* ndims) const - { - assert(read_mode_); - if (!fp_) { open_file_(); } - *nrows = nrows_; - *ndims = ndims_; - } - - void read(T* data) const - { - assert(read_mode_); - if (!fp_) { open_file_(); } - size_t total = static_cast(nrows_) * ndims_; - if (fread(data, sizeof(T), total, fp_) != total) { - throw std::runtime_error("fread() BinFile " + file_ + " failed"); - } - } - - void write(const T* data, uint32_t nrows, uint32_t ndims) - { - assert(!read_mode_); - if (!fp_) { open_file_(); } - if (fwrite(&nrows, sizeof(uint32_t), 1, fp_) != 1) { - throw std::runtime_error("fwrite() BinFile " + file_ + " failed"); - } - if (fwrite(&ndims, sizeof(uint32_t), 1, fp_) != 1) { - throw std::runtime_error("fwrite() BinFile " + file_ + " failed"); - } - - size_t total = static_cast(nrows) * ndims; - if (fwrite(data, sizeof(T), total, fp_) != total) { - throw std::runtime_error("fwrite() BinFile " + file_ + " failed"); - } - } - - T* map() const - { - assert(read_mode_); - if (!fp_) { open_file_(); } - int fid = fileno(fp_); - mapped_ptr_ = mmap(nullptr, file_size_, PROT_READ, MAP_PRIVATE, fid, 0); - if (mapped_ptr_ == MAP_FAILED) { - mapped_ptr_ = nullptr; - throw std::runtime_error("mmap error: Value of errno " + std::to_string(errno) + ", " + - std::string(strerror(errno))); - } - return reinterpret_cast(reinterpret_cast(mapped_ptr_) + 2 * sizeof(uint32_t) + - subset_first_row_ * ndims_ * sizeof(T)); - } - - void unmap() const - { - if (munmap(mapped_ptr_, file_size_) == -1) { - throw std::runtime_error("munmap error: " + std::string(strerror(errno))); - } - } - - private: - void check_suffix_(); - void open_file_() const; - - std::string file_; - bool read_mode_; - uint32_t subset_first_row_; - uint32_t subset_size_; - - mutable FILE* fp_{nullptr}; - mutable uint32_t nrows_; - mutable uint32_t ndims_; - mutable size_t file_size_; - mutable void* mapped_ptr_{nullptr}; -}; - -template -BinFile::BinFile(const std::string& file, - const std::string& mode, - uint32_t subset_first_row, - uint32_t subset_size) - : file_(file), - read_mode_(mode == "r"), - subset_first_row_(subset_first_row), - subset_size_(subset_size), - fp_(nullptr) -{ - check_suffix_(); - - if (!read_mode_) { - if (mode == "w") { - if (subset_first_row != 0) { - throw std::runtime_error("subset_first_row should be zero for write mode"); - } - if (subset_size != 0) { - throw std::runtime_error("subset_size should be zero for write mode"); - } - } else { - throw std::runtime_error("BinFile's mode must be either 'r' or 'w': " + file_); - } - } -} - -template -void BinFile::open_file_() const -{ - fp_ = fopen(file_.c_str(), read_mode_ ? "r" : "w"); - if (!fp_) { throw std::runtime_error("open BinFile failed: " + file_); } - - if (read_mode_) { - struct stat statbuf; - if (stat(file_.c_str(), &statbuf) != 0) { throw std::runtime_error("stat() failed: " + file_); } - file_size_ = statbuf.st_size; - - uint32_t header[2]; - if (fread(header, sizeof(uint32_t), 2, fp_) != 2) { - throw std::runtime_error("read header of BinFile failed: " + file_); - } - nrows_ = header[0]; - ndims_ = header[1]; - - size_t expected_file_size = - 2 * sizeof(uint32_t) + static_cast(nrows_) * ndims_ * sizeof(T); - if (file_size_ != expected_file_size) { - throw std::runtime_error("expected file size of " + file_ + " is " + - std::to_string(expected_file_size) + ", however, actual size is " + - std::to_string(file_size_)); - } - - if (subset_first_row_ >= nrows_) { - throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row_) + - ") >= nrows (" + std::to_string(nrows_) + ")"); - } - if (subset_first_row_ + subset_size_ > nrows_) { - throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row_) + - ") + subset_size (" + std::to_string(subset_size_) + ") > nrows (" + - std::to_string(nrows_) + ")"); - } - - if (subset_first_row_) { - static_assert(sizeof(long) == 8, "fseek() don't support 64-bit offset"); - if (fseek(fp_, sizeof(T) * subset_first_row_ * ndims_, SEEK_CUR) == -1) { - throw std::runtime_error(file_ + ": fseek failed"); - } - nrows_ -= subset_first_row_; - } - if (subset_size_) { nrows_ = subset_size_; } - } -} - -template -void BinFile::check_suffix_() -{ - auto pos = file_.rfind('.'); - if (pos == std::string::npos) { - throw std::runtime_error("name of BinFile doesn't have a suffix: " + file_); - } - std::string suffix = file_.substr(pos + 1); - - if constexpr (std::is_same_v) { - if (suffix != "fbin") { - throw std::runtime_error("BinFile should has .fbin suffix: " + file_); - } - } else if constexpr (std::is_same_v) { - if (suffix != "f16bin" && suffix != "fbin") { - throw std::runtime_error("BinFile should has .f16bin suffix: " + file_); - } - } else if constexpr (std::is_same_v) { - if (suffix != "ibin") { - throw std::runtime_error("BinFile should has .ibin suffix: " + file_); - } - } else if constexpr (std::is_same_v) { - if (suffix != "u8bin") { - throw std::runtime_error("BinFile should has .u8bin suffix: " + file_); - } - } else if constexpr (std::is_same_v) { - if (suffix != "i8bin") { - throw std::runtime_error("BinFile should has .i8bin suffix: " + file_); - } - } else { - throw std::runtime_error( - "T of BinFile should be one of float, half, int, uint8_t, or int8_t"); - } -} - -template -class Dataset { - public: - Dataset(const std::string& name) : name_(name) {} - Dataset(const std::string& name, const std::string& distance) : name_(name), distance_(distance) - { - } - Dataset(const Dataset&) = delete; - Dataset& operator=(const Dataset&) = delete; - virtual ~Dataset(); - - std::string name() const { return name_; } - std::string distance() const { return distance_; } - virtual int dim() const = 0; - virtual uint32_t max_k() const = 0; - virtual size_t base_set_size() const = 0; - virtual size_t query_set_size() const = 0; - - // load data lazily, so don't pay the overhead of reading unneeded set - // e.g. don't load base set when searching - const T* base_set() const - { - if (!base_set_) { load_base_set_(); } - return base_set_; - } - - const T* query_set() const - { - if (!query_set_) { load_query_set_(); } - return query_set_; - } - - const int32_t* gt_set() const - { - if (!gt_set_) { load_gt_set_(); } - return gt_set_; - } - - const T* base_set_on_gpu() const; - const T* query_set_on_gpu() const; - const T* mapped_base_set() const; - - auto query_set(MemoryType memory_type) const -> const T* - { - switch (memory_type) { - case MemoryType::Device: return query_set_on_gpu(); - default: return query_set(); - } - } - - auto base_set(MemoryType memory_type) const -> const T* - { - switch (memory_type) { - case MemoryType::Device: return base_set_on_gpu(); - case MemoryType::Host: return base_set(); - case MemoryType::HostMmap: return mapped_base_set(); - default: return nullptr; - } - } - - protected: - virtual void load_base_set_() const = 0; - virtual void load_gt_set_() const = 0; - virtual void load_query_set_() const = 0; - virtual void map_base_set_() const = 0; - - std::string name_; - std::string distance_; - - mutable T* base_set_ = nullptr; - mutable T* query_set_ = nullptr; - mutable T* d_base_set_ = nullptr; - mutable T* d_query_set_ = nullptr; - mutable T* mapped_base_set_ = nullptr; - mutable int32_t* gt_set_ = nullptr; -}; - -template -Dataset::~Dataset() -{ - delete[] base_set_; - delete[] query_set_; - delete[] gt_set_; -#ifndef BUILD_CPU_ONLY - if (d_base_set_) { cudaFree(d_base_set_); } - if (d_query_set_) { cudaFree(d_query_set_); } -#endif -} - -template -const T* Dataset::base_set_on_gpu() const -{ -#ifndef BUILD_CPU_ONLY - if (!d_base_set_) { - base_set(); - cudaMalloc((void**)&d_base_set_, base_set_size() * dim() * sizeof(T)); - cudaMemcpy(d_base_set_, base_set_, base_set_size() * dim() * sizeof(T), cudaMemcpyHostToDevice); - } -#endif - return d_base_set_; -} - -template -const T* Dataset::query_set_on_gpu() const -{ -#ifndef BUILD_CPU_ONLY - if (!d_query_set_) { - query_set(); - cudaMalloc((void**)&d_query_set_, query_set_size() * dim() * sizeof(T)); - cudaMemcpy( - d_query_set_, query_set_, query_set_size() * dim() * sizeof(T), cudaMemcpyHostToDevice); - } -#endif - return d_query_set_; -} - -template -const T* Dataset::mapped_base_set() const -{ - if (!mapped_base_set_) { map_base_set_(); } - return mapped_base_set_; -} - -template -class BinDataset : public Dataset { - public: - BinDataset(const std::string& name, - const std::string& base_file, - size_t subset_first_row, - size_t subset_size, - const std::string& query_file, - const std::string& distance, - const std::optional& groundtruth_neighbors_file); - - int dim() const override; - uint32_t max_k() const override; - size_t base_set_size() const override; - size_t query_set_size() const override; - - private: - void load_base_set_() const override; - void load_query_set_() const override; - void load_gt_set_() const override; - void map_base_set_() const override; - - mutable int dim_ = 0; - mutable uint32_t max_k_ = 0; - mutable size_t base_set_size_ = 0; - mutable size_t query_set_size_ = 0; - - BinFile base_file_; - BinFile query_file_; - std::optional> gt_file_{std::nullopt}; -}; - -template -BinDataset::BinDataset(const std::string& name, - const std::string& base_file, - size_t subset_first_row, - size_t subset_size, - const std::string& query_file, - const std::string& distance, - const std::optional& groundtruth_neighbors_file) - : Dataset(name, distance), - base_file_(base_file, "r", subset_first_row, subset_size), - query_file_(query_file, "r") -{ - if (groundtruth_neighbors_file.has_value()) { - gt_file_.emplace(groundtruth_neighbors_file.value(), "r"); - } -} - -template -int BinDataset::dim() const -{ - if (dim_ > 0) { return dim_; } - if (base_set_size() > 0) { return dim_; } - if (query_set_size() > 0) { return dim_; } - return dim_; -} - -template -uint32_t BinDataset::max_k() const -{ - if (!this->gt_set_) { load_gt_set_(); } - return max_k_; -} - -template -size_t BinDataset::query_set_size() const -{ - if (query_set_size_ > 0) { return query_set_size_; } - int dim; - query_file_.get_shape(&query_set_size_, &dim); - if (query_set_size_ == 0) { throw std::runtime_error("Zero query set size"); } - if (dim == 0) { throw std::runtime_error("Zero query set dim"); } - if (dim_ == 0) { - dim_ = dim; - } else if (dim_ != dim) { - throw std::runtime_error("base set dim (" + std::to_string(dim_) + ") != query set dim (" + - std::to_string(dim)); - } - return query_set_size_; -} - -template -size_t BinDataset::base_set_size() const -{ - if (base_set_size_ > 0) { return base_set_size_; } - int dim; - base_file_.get_shape(&base_set_size_, &dim); - if (base_set_size_ == 0) { throw std::runtime_error("Zero base set size"); } - if (dim == 0) { throw std::runtime_error("Zero base set dim"); } - if (dim_ == 0) { - dim_ = dim; - } else if (dim_ != dim) { - throw std::runtime_error("base set dim (" + std::to_string(dim) + ") != query set dim (" + - std::to_string(dim_)); - } - return base_set_size_; -} - -template -void BinDataset::load_base_set_() const -{ - this->base_set_ = new T[base_set_size() * dim()]; - base_file_.read(this->base_set_); -} - -template -void BinDataset::load_query_set_() const -{ - this->query_set_ = new T[query_set_size() * dim()]; - query_file_.read(this->query_set_); -} - -template -void BinDataset::load_gt_set_() const -{ - if (gt_file_.has_value()) { - size_t queries; - int k; - gt_file_->get_shape(&queries, &k); - this->gt_set_ = new std::int32_t[queries * k]; - gt_file_->read(this->gt_set_); - max_k_ = k; - } -} - -template -void BinDataset::map_base_set_() const -{ - this->mapped_base_set_ = base_file_.map(); -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/common/thread_pool.hpp b/cpp/bench/ann/src/common/thread_pool.hpp deleted file mode 100644 index 4a5684ecb3..0000000000 --- a/cpp/bench/ann/src/common/thread_pool.hpp +++ /dev/null @@ -1,134 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include - -class FixedThreadPool { - public: - FixedThreadPool(int num_threads) - { - if (num_threads < 1) { - throw std::runtime_error("num_threads must >= 1"); - } else if (num_threads == 1) { - return; - } - - tasks_ = new Task_[num_threads]; - - threads_.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - threads_.emplace_back([&, i] { - auto& task = tasks_[i]; - while (true) { - std::unique_lock lock(task.mtx); - task.cv.wait(lock, - [&] { return task.has_task || finished_.load(std::memory_order_relaxed); }); - if (finished_.load(std::memory_order_relaxed)) { break; } - - task.task(); - task.has_task = false; - } - }); - } - } - - ~FixedThreadPool() - { - if (threads_.empty()) { return; } - - finished_.store(true, std::memory_order_relaxed); - for (unsigned i = 0; i < threads_.size(); ++i) { - auto& task = tasks_[i]; - std::lock_guard(task.mtx); - - task.cv.notify_one(); - threads_[i].join(); - } - - delete[] tasks_; - } - - template - void submit(Func f, IdxT len) - { - // Run functions in main thread if thread pool has no threads - if (threads_.empty()) { - for (IdxT i = 0; i < len; ++i) { - f(i); - } - return; - } - - const int num_threads = threads_.size(); - // one extra part for competition among threads - const IdxT items_per_thread = len / (num_threads + 1); - std::atomic cnt(items_per_thread * num_threads); - - // Wrap function - auto wrapped_f = [&](IdxT start, IdxT end) { - for (IdxT i = start; i < end; ++i) { - f(i); - } - - while (true) { - IdxT i = cnt.fetch_add(1, std::memory_order_relaxed); - if (i >= len) { break; } - f(i); - } - }; - - std::vector> futures; - futures.reserve(num_threads); - for (int i = 0; i < num_threads; ++i) { - IdxT start = i * items_per_thread; - auto& task = tasks_[i]; - { - std::lock_guard lock(task.mtx); - (void)lock; // stop nvcc warning - task.task = std::packaged_task([=] { wrapped_f(start, start + items_per_thread); }); - futures.push_back(task.task.get_future()); - task.has_task = true; - } - task.cv.notify_one(); - } - - for (auto& fut : futures) { - fut.wait(); - } - return; - } - - private: - struct alignas(64) Task_ { - std::mutex mtx; - std::condition_variable cv; - bool has_task = false; - std::packaged_task task; - }; - - Task_* tasks_; - std::vector threads_; - std::atomic finished_{false}; -}; diff --git a/cpp/bench/ann/src/common/util.hpp b/cpp/bench/ann/src/common/util.hpp deleted file mode 100644 index 96185c79eb..0000000000 --- a/cpp/bench/ann/src/common/util.hpp +++ /dev/null @@ -1,557 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "ann_types.hpp" -#include "cuda_stub.hpp" // cuda-related utils - -#ifdef ANN_BENCH_NVTX3_HEADERS_FOUND -#include -#endif - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -/** - * Current thread id as given by the benchmark State. - * It's populated on every call of a benchmark case. - * It's relevant in the 'throughput' mode of the search benchmarks, - * where some algorithms might want to coordinate allocation of the resources. - */ -inline thread_local int benchmark_thread_id = 0; -/** - * Total concurrent thread count as given by the benchmark State. - * It's populated on every call of a benchmark case. - * It's relevant in the 'throughput' mode of the search benchmarks, - * where some algorithms might want to coordinate allocation of the resources. - */ -inline thread_local int benchmark_n_threads = 1; - -struct cuda_timer { - private: - std::optional stream_; - cudaEvent_t start_{nullptr}; - cudaEvent_t stop_{nullptr}; - double total_time_{0}; - - template - static inline auto extract_stream(AnnT* algo) -> std::optional - { - auto gpu_ann = dynamic_cast(algo); - if (gpu_ann != nullptr && gpu_ann->uses_stream()) { - return std::make_optional(gpu_ann->get_sync_stream()); - } - return std::nullopt; - } - - public: - struct cuda_lap { - private: - cudaStream_t stream_; - cudaEvent_t start_; - cudaEvent_t stop_; - double& total_time_; - - public: - cuda_lap(cudaStream_t stream, cudaEvent_t start, cudaEvent_t stop, double& total_time) - : start_(start), stop_(stop), stream_(stream), total_time_(total_time) - { -#ifndef BUILD_CPU_ONLY - cudaEventRecord(start_, stream_); -#endif - } - cuda_lap() = delete; - - ~cuda_lap() noexcept - { -#ifndef BUILD_CPU_ONLY - cudaEventRecord(stop_, stream_); - cudaEventSynchronize(stop_); - float milliseconds = 0.0f; - cudaEventElapsedTime(&milliseconds, start_, stop_); - total_time_ += milliseconds / 1000.0; -#endif - } - }; - - explicit cuda_timer(std::optional stream) : stream_{stream} - { -#ifndef BUILD_CPU_ONLY - if (stream_.has_value()) { - cudaEventCreate(&stop_); - cudaEventCreate(&start_); - } -#endif - } - - template - explicit cuda_timer(const std::unique_ptr& algo) : cuda_timer{extract_stream(algo.get())} - { - } - - ~cuda_timer() noexcept - { -#ifndef BUILD_CPU_ONLY - if (stream_.has_value()) { - cudaStreamSynchronize(stream_.value()); - cudaEventDestroy(start_); - cudaEventDestroy(stop_); - } -#endif - } - - cuda_timer() = delete; - cuda_timer(cuda_timer const&) = delete; - cuda_timer(cuda_timer&&) = delete; - auto operator=(cuda_timer const&) -> cuda_timer& = delete; - auto operator=(cuda_timer&&) -> cuda_timer& = delete; - - [[nodiscard]] auto stream() const -> std::optional { return stream_; } - - [[nodiscard]] auto active() const -> bool { return stream_.has_value(); } - - [[nodiscard]] auto total_time() const -> double { return total_time_; } - - [[nodiscard]] auto lap(bool enabled = true) -> std::optional - { - return enabled && stream_.has_value() - ? std::make_optional(stream_.value(), start_, stop_, total_time_) - : std::nullopt; - } -}; - -#ifndef BUILD_CPU_ONLY -// ATM, rmm::stream does not support passing in flags; hence this helper type. -struct non_blocking_stream { - non_blocking_stream() { cudaStreamCreateWithFlags(&stream_, cudaStreamNonBlocking); } - ~non_blocking_stream() noexcept - { - if (stream_ != nullptr) { cudaStreamDestroy(stream_); } - } - non_blocking_stream(non_blocking_stream const&) = delete; - non_blocking_stream(non_blocking_stream&& other) noexcept { std::swap(stream_, other.stream_); } - auto operator=(non_blocking_stream const&) -> non_blocking_stream& = delete; - auto operator=(non_blocking_stream&&) -> non_blocking_stream& = delete; - [[nodiscard]] auto view() const noexcept -> cudaStream_t { return stream_; } - - private: - cudaStream_t stream_{nullptr}; -}; - -namespace detail { -inline std::vector global_stream_pool(0); -inline std::mutex gsp_mutex; -} // namespace detail -#endif - -/** - * Get a stream associated with the current benchmark thread. - * - * Note, the streams are reused between the benchmark cases. - * This makes it easier to profile and analyse multiple benchmark cases in one timeline using tools - * like nsys. - */ -inline auto get_stream_from_global_pool() -> cudaStream_t -{ -#ifndef BUILD_CPU_ONLY - std::lock_guard guard(detail::gsp_mutex); - if (int(detail::global_stream_pool.size()) < benchmark_n_threads) { - detail::global_stream_pool.resize(benchmark_n_threads); - } - return detail::global_stream_pool[benchmark_thread_id].view(); -#else - return nullptr; -#endif -} - -struct result_buffer { - explicit result_buffer(size_t size, cudaStream_t stream) : size_{size}, stream_{stream} - { - if (size_ == 0) { return; } - data_host_ = malloc(size_); -#ifndef BUILD_CPU_ONLY - cudaMallocAsync(&data_device_, size_, stream_); - cudaStreamSynchronize(stream_); -#endif - } - result_buffer() = delete; - result_buffer(result_buffer&&) = delete; - result_buffer& operator=(result_buffer&&) = delete; - result_buffer(const result_buffer&) = delete; - result_buffer& operator=(const result_buffer&) = delete; - ~result_buffer() noexcept - { - if (size_ == 0) { return; } -#ifndef BUILD_CPU_ONLY - cudaFreeAsync(data_device_, stream_); - cudaStreamSynchronize(stream_); -#endif - free(data_host_); - } - - [[nodiscard]] auto size() const noexcept { return size_; } - [[nodiscard]] auto data(ann::MemoryType loc) const noexcept - { - switch (loc) { - case MemoryType::Device: return data_device_; - default: return data_host_; - } - } - - void transfer_data(ann::MemoryType dst, ann::MemoryType src) - { - auto dst_ptr = data(dst); - auto src_ptr = data(src); - if (dst_ptr == src_ptr) { return; } -#ifndef BUILD_CPU_ONLY - cudaMemcpyAsync(dst_ptr, src_ptr, size_, cudaMemcpyDefault, stream_); - cudaStreamSynchronize(stream_); -#endif - } - - private: - size_t size_{0}; - cudaStream_t stream_ = nullptr; - void* data_host_ = nullptr; - void* data_device_ = nullptr; -}; - -namespace detail { -inline std::vector> global_result_buffer_pool(0); -inline std::mutex grp_mutex; -} // namespace detail - -/** - * Get a result buffer associated with the current benchmark thread. - * - * Note, the allocations are reused between the benchmark cases. - * This reduces the setup overhead and number of times the context is being blocked - * (this is relevant if there is a persistent kernel running across multiples benchmark cases). - */ -inline auto get_result_buffer_from_global_pool(size_t size) -> result_buffer& -{ - auto stream = get_stream_from_global_pool(); - auto& rb = [stream, size]() -> result_buffer& { - std::lock_guard guard(detail::grp_mutex); - if (static_cast(detail::global_result_buffer_pool.size()) < benchmark_n_threads) { - detail::global_result_buffer_pool.resize(benchmark_n_threads); - } - auto& rb = detail::global_result_buffer_pool[benchmark_thread_id]; - if (!rb || rb->size() < size) { rb = std::make_unique(size, stream); } - return *rb; - }(); - - memset(rb.data(MemoryType::Host), 0, size); -#ifndef BUILD_CPU_ONLY - cudaMemsetAsync(rb.data(MemoryType::Device), 0, size, stream); - cudaStreamSynchronize(stream); -#endif - return rb; -} - -/** - * Delete all streams and memory allocations in the global pool. - * It's called at the end of the `main` function - before global/static variables and cuda context - * is destroyed - to make sure they are destroyed gracefully and correctly seen by analysis tools - * such as nsys. - */ -inline void reset_global_device_resources() -{ -#ifndef BUILD_CPU_ONLY - std::lock_guard guard(detail::gsp_mutex); - detail::global_result_buffer_pool.resize(0); - detail::global_stream_pool.resize(0); -#endif -} - -inline auto cuda_info() -{ - std::vector> props; -#ifndef BUILD_CPU_ONLY - int dev, driver = 0, runtime = 0; - cudaDriverGetVersion(&driver); - cudaRuntimeGetVersion(&runtime); - - cudaDeviceProp device_prop; - cudaGetDevice(&dev); - cudaGetDeviceProperties(&device_prop, dev); - props.emplace_back("gpu_name", std::string(device_prop.name)); - props.emplace_back("gpu_sm_count", std::to_string(device_prop.multiProcessorCount)); - props.emplace_back("gpu_sm_freq", std::to_string(device_prop.clockRate * 1e3)); - props.emplace_back("gpu_mem_freq", std::to_string(device_prop.memoryClockRate * 1e3)); - props.emplace_back("gpu_mem_bus_width", std::to_string(device_prop.memoryBusWidth)); - props.emplace_back("gpu_mem_global_size", std::to_string(device_prop.totalGlobalMem)); - props.emplace_back("gpu_mem_shared_size", std::to_string(device_prop.sharedMemPerMultiprocessor)); - props.emplace_back("gpu_driver_version", - std::to_string(driver / 1000) + "." + std::to_string((driver % 100) / 10)); - props.emplace_back("gpu_runtime_version", - std::to_string(runtime / 1000) + "." + std::to_string((runtime % 100) / 10)); -#endif - return props; -} - -struct nvtx_case { -#ifdef ANN_BENCH_NVTX3_HEADERS_FOUND - private: - std::string case_name_; - std::array iter_name_{0}; - nvtxDomainHandle_t domain_; - int64_t iteration_ = 0; - nvtxEventAttributes_t case_attrib_{0}; - nvtxEventAttributes_t iter_attrib_{0}; -#endif - - public: - struct nvtx_lap { -#ifdef ANN_BENCH_NVTX3_HEADERS_FOUND - private: - nvtxDomainHandle_t domain_; - - public: - nvtx_lap(nvtxDomainHandle_t domain, nvtxEventAttributes_t* attr) : domain_(domain) - { - nvtxDomainRangePushEx(domain_, attr); - } - nvtx_lap() = delete; - ~nvtx_lap() noexcept { nvtxDomainRangePop(domain_); } -#endif - }; - -#ifdef ANN_BENCH_NVTX3_HEADERS_FOUND - explicit nvtx_case(std::string case_name) - : case_name_(std::move(case_name)), domain_(nvtxDomainCreateA("ANN benchmark")) - { - case_attrib_.version = NVTX_VERSION; - iter_attrib_.version = NVTX_VERSION; - case_attrib_.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - iter_attrib_.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE; - case_attrib_.colorType = NVTX_COLOR_ARGB; - iter_attrib_.colorType = NVTX_COLOR_ARGB; - case_attrib_.messageType = NVTX_MESSAGE_TYPE_ASCII; - iter_attrib_.messageType = NVTX_MESSAGE_TYPE_ASCII; - case_attrib_.message.ascii = case_name_.c_str(); - auto c = std::hash{}(case_name_); - case_attrib_.color = c | 0xA0A0A0; - nvtxDomainRangePushEx(domain_, &case_attrib_); - } - - ~nvtx_case() - { - nvtxDomainRangePop(domain_); - nvtxDomainDestroy(domain_); - } -#else - explicit nvtx_case(std::string) {} -#endif - - [[nodiscard]] auto lap() -> nvtx_case::nvtx_lap - { -#ifdef ANN_BENCH_NVTX3_HEADERS_FOUND - auto i = iteration_++; - uint32_t c = (i % 5); - uint32_t r = 150 + c * 20; - uint32_t g = 200 + c * 10; - uint32_t b = 220 + c * 5; - std::snprintf(iter_name_.data(), iter_name_.size(), "Lap %zd", i); - iter_attrib_.message.ascii = iter_name_.data(); - iter_attrib_.color = (r << 16) + (g << 8) + b; - return nvtx_lap{domain_, &iter_attrib_}; -#else - return nvtx_lap{}; -#endif - } -}; - -/** - * A progress tracker that allows syncing threads multiple times and resets the global - * progress once the threads are done. - */ -struct progress_barrier { - progress_barrier() = default; - ~progress_barrier() noexcept - { - { - // Lock makes sure the notified threads see the updates to `done_`. - std::unique_lock lk(mutex_); - done_.store(true, std::memory_order_relaxed); - cv_.notify_all(); - } - // This is the only place where the order of the updates to thread_progress_ and done_ is - // important. They are not guarded by the mutex, and `done_` must not be reset to `true` by - // other threads after the `total_progress_` is zero. - // Hence the default memory order (std::memory_order_seq_cst). - auto rem = total_progress_.fetch_sub(thread_progress_); - if (rem == thread_progress_) { - // the last thread to exit clears the progress state. - done_.store(false); - } - } - - /** - * Advance the progress counter by `n` and return the previous `progress` value. - * - * This can be used to track which thread arrives on the call site first. - * - * @return the previous progress counter value (before incrementing it by `n`). - */ - auto arrive(int n) - { - thread_progress_ += n; - // Lock makes sure the notified threads see the updates to `total_progress_`. - std::unique_lock lk(mutex_); - auto prev = total_progress_.fetch_add(n, std::memory_order_relaxed); - cv_.notify_all(); - return prev; - } - - /** - * Wait till the progress counter reaches `n` or finishes abnormally. - * - * @return the latest observed value of the progress counter. - */ - auto wait(int limit) - { - int cur = total_progress_.load(std::memory_order_relaxed); - if (cur >= limit) { return cur; } - auto done = done_.load(std::memory_order_relaxed); - if (done) { return cur; } - std::unique_lock lk(mutex_); - while (cur < limit && !done) { - using namespace std::chrono_literals; - cv_.wait_for(lk, 10ms); - cur = total_progress_.load(std::memory_order_relaxed); - done = done_.load(std::memory_order_relaxed); - } - return cur; - } - - private: - static inline std::atomic total_progress_; - static inline std::atomic done_; - static inline std::mutex mutex_; - static inline std::condition_variable cv_; - int thread_progress_{0}; -}; - -inline std::vector split(const std::string& s, char delimiter) -{ - std::vector tokens; - std::string token; - std::istringstream iss(s); - while (getline(iss, token, delimiter)) { - if (!token.empty()) { tokens.push_back(token); } - } - return tokens; -} - -inline bool file_exists(const std::string& filename) -{ - struct stat statbuf; - if (stat(filename.c_str(), &statbuf) != 0) { return false; } - return S_ISREG(statbuf.st_mode); -} - -inline bool dir_exists(const std::string& dir) -{ - struct stat statbuf; - if (stat(dir.c_str(), &statbuf) != 0) { return false; } - return S_ISDIR(statbuf.st_mode); -} - -inline bool create_dir(const std::string& dir) -{ - const auto path = split(dir, '/'); - - std::string cwd; - if (!dir.empty() && dir[0] == '/') { cwd += '/'; } - - for (const auto& p : path) { - cwd += p + "/"; - if (!dir_exists(cwd)) { - int ret = mkdir(cwd.c_str(), S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH); - if (ret != 0) { return false; } - } - } - return true; -} - -inline void make_sure_parent_dir_exists(const std::string& file_path) -{ - const auto pos = file_path.rfind('/'); - if (pos != std::string::npos) { - auto dir = file_path.substr(0, pos); - if (!dir_exists(dir)) { create_dir(dir); } - } -} - -inline auto combine_path(const std::string& dir, const std::string& path) -{ - std::filesystem::path p_dir(dir); - std::filesystem::path p_suf(path); - return (p_dir / p_suf).string(); -} - -template -void log_(const char* level, const Ts&... vs) -{ - char buf[20]; - std::time_t now = std::time(nullptr); - std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); - printf("%s [%s] ", buf, level); - if constexpr (sizeof...(Ts) == 1) { - printf("%s", vs...); - } else { - printf(vs...); - } - printf("\n"); - fflush(stdout); -} - -template -void log_info(Ts&&... vs) -{ - log_("info", std::forward(vs)...); -} - -template -void log_warn(Ts&&... vs) -{ - log_("warn", std::forward(vs)...); -} - -template -void log_error(Ts&&... vs) -{ - log_("error", std::forward(vs)...); -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp b/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp deleted file mode 100644 index 234b33d80a..0000000000 --- a/cpp/bench/ann/src/faiss/faiss_cpu_benchmark.cpp +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" -#include "faiss_cpu_wrapper.h" - -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -void parse_base_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissCpu::BuildParam& param) -{ - param.nlist = conf.at("nlist"); - if (conf.contains("ratio")) { param.ratio = conf.at("ratio"); } -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissCpuIVFFlat::BuildParam& param) -{ - parse_base_build_param(conf, param); -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissCpuIVFPQ::BuildParam& param) -{ - parse_base_build_param(conf, param); - param.M = conf.at("M"); - if (conf.contains("use_precomputed_table")) { - param.use_precomputed_table = conf.at("use_precomputed_table"); - } else { - param.use_precomputed_table = false; - } - if (conf.contains("bitsPerCode")) { - param.bitsPerCode = conf.at("bitsPerCode"); - } else { - param.bitsPerCode = 8; - } -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissCpuIVFSQ::BuildParam& param) -{ - parse_base_build_param(conf, param); - param.quantizer_type = conf.at("quantizer_type"); -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissCpu::SearchParam& param) -{ - param.nprobe = conf.at("nprobe"); - if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); } - if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - return std::make_unique>(metric, dim, param); -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - - (void)dev_list; - return std::make_unique>(metric, dim, param); -} - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - std::unique_ptr> ann; - - if constexpr (std::is_same_v) { - raft::bench::ann::Metric metric = parse_metric(distance); - if (algo == "faiss_cpu_ivf_flat") { - ann = make_algo(metric, dim, conf, dev_list); - } else if (algo == "faiss_cpu_ivf_pq") { - ann = make_algo(metric, dim, conf); - } else if (algo == "faiss_cpu_ivf_sq") { - ann = make_algo(metric, dim, conf); - } else if (algo == "faiss_cpu_flat") { - ann = std::make_unique>(metric, dim); - } - } - - if constexpr (std::is_same_v) {} - - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - if (algo == "faiss_cpu_ivf_flat" || algo == "faiss_cpu_ivf_pq" || algo == "faiss_cpu_ivf_sq") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } else if (algo == "faiss_cpu_flat") { - auto param = std::make_unique::SearchParam>(); - return param; - } - // else - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -} // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } -#endif diff --git a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h deleted file mode 100644 index c7ce4595b5..0000000000 --- a/cpp/bench/ann/src/faiss/faiss_cpu_wrapper.h +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "../common/thread_pool.hpp" - -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace { - -faiss::MetricType parse_metric_type(raft::bench::ann::Metric metric) -{ - if (metric == raft::bench::ann::Metric::kInnerProduct) { - return faiss::METRIC_INNER_PRODUCT; - } else if (metric == raft::bench::ann::Metric::kEuclidean) { - return faiss::METRIC_L2; - } else { - throw std::runtime_error("faiss supports only metric type of inner product and L2"); - } -} -} // namespace - -namespace raft::bench::ann { - -template -class FaissCpu : public ANN { - public: - using typename ANN::AnnSearchParam; - struct SearchParam : public AnnSearchParam { - int nprobe; - float refine_ratio = 1.0; - int num_threads = omp_get_num_procs(); - }; - - struct BuildParam { - int nlist = 1; - int ratio = 2; - }; - - FaissCpu(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), - metric_type_(parse_metric_type(metric)), - nlist_{param.nlist}, - training_sample_fraction_{1.0 / double(param.ratio)} - { - static_assert(std::is_same_v, "faiss support only float type"); - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - - void init_quantizer(int dim) - { - if (this->metric_type_ == faiss::MetricType::METRIC_L2) { - this->quantizer_ = std::make_shared(dim); - } else if (this->metric_type_ == faiss::MetricType::METRIC_INNER_PRODUCT) { - this->quantizer_ = std::make_shared(dim); - } - } - - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const final; - - AlgoProperty get_preference() const override - { - AlgoProperty property; - // to enable building big dataset which is larger than memory - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Host; - return property; - } - - protected: - template - void save_(const std::string& file) const; - - template - void load_(const std::string& file); - - std::shared_ptr index_; - std::shared_ptr quantizer_; - std::shared_ptr index_refine_; - faiss::MetricType metric_type_; - int nlist_; - double training_sample_fraction_; - - int num_threads_; - std::shared_ptr thread_pool_; -}; - -template -void FaissCpu::build(const T* dataset, size_t nrow) -{ - auto index_ivf = dynamic_cast(index_.get()); - if (index_ivf != nullptr) { - // set the min/max training size for clustering to use the whole provided training set. - double trainset_size = training_sample_fraction_ * static_cast(nrow); - double points_per_centroid = trainset_size / static_cast(nlist_); - int max_ppc = std::ceil(points_per_centroid); - int min_ppc = std::floor(points_per_centroid); - if (min_ppc < index_ivf->cp.min_points_per_centroid) { - RAFT_LOG_WARN( - "The suggested training set size %zu (data size %zu, training sample ratio %f) yields %d " - "points per cluster (n_lists = %d). This is smaller than the FAISS default " - "min_points_per_centroid = %d.", - static_cast(trainset_size), - nrow, - training_sample_fraction_, - min_ppc, - nlist_, - index_ivf->cp.min_points_per_centroid); - } - index_ivf->cp.max_points_per_centroid = max_ppc; - index_ivf->cp.min_points_per_centroid = min_ppc; - } - index_->train(nrow, dataset); // faiss::IndexFlat::train() will do nothing - assert(index_->is_trained); - index_->add(nrow, dataset); - index_refine_ = std::make_shared(this->index_.get(), dataset); -} - -template -void FaissCpu::set_search_param(const AnnSearchParam& param) -{ - auto search_param = dynamic_cast(param); - int nprobe = search_param.nprobe; - assert(nprobe <= nlist_); - dynamic_cast(index_.get())->nprobe = nprobe; - - if (search_param.refine_ratio > 1.0) { - this->index_refine_.get()->k_factor = search_param.refine_ratio; - } - - if (!thread_pool_ || num_threads_ != search_param.num_threads) { - num_threads_ = search_param.num_threads; - thread_pool_ = std::make_shared(num_threads_); - } -} - -template -void FaissCpu::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - static_assert(sizeof(size_t) == sizeof(faiss::idx_t), - "sizes of size_t and faiss::idx_t are different"); - - thread_pool_->submit( - [&](int i) { - // Use thread pool for batch size = 1. FAISS multi-threads internally for batch size > 1. - index_->search(batch_size, queries, k, distances, reinterpret_cast(neighbors)); - }, - 1); -} - -template -template -void FaissCpu::save_(const std::string& file) const -{ - faiss::write_index(index_.get(), file.c_str()); -} - -template -template -void FaissCpu::load_(const std::string& file) -{ - index_ = std::shared_ptr(dynamic_cast(faiss::read_index(file.c_str()))); -} - -template -class FaissCpuIVFFlat : public FaissCpu { - public: - using typename FaissCpu::BuildParam; - - FaissCpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissCpu(metric, dim, param) - { - this->init_quantizer(dim); - this->index_ = std::make_shared( - this->quantizer_.get(), dim, param.nlist, this->metric_type_); - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override { this->template load_(file); } - - std::unique_ptr> copy() - { - return std::make_unique>(*this); // use copy constructor - } -}; - -template -class FaissCpuIVFPQ : public FaissCpu { - public: - struct BuildParam : public FaissCpu::BuildParam { - int M; - int bitsPerCode; - bool use_precomputed_table; - }; - - FaissCpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissCpu(metric, dim, param) - { - this->init_quantizer(dim); - this->index_ = std::make_shared( - this->quantizer_.get(), dim, param.nlist, param.M, param.bitsPerCode, this->metric_type_); - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override { this->template load_(file); } - - std::unique_ptr> copy() - { - return std::make_unique>(*this); // use copy constructor - } -}; - -// TODO: Enable this in cmake -// ref: https://github.com/rapidsai/raft/issues/1876 -template -class FaissCpuIVFSQ : public FaissCpu { - public: - struct BuildParam : public FaissCpu::BuildParam { - std::string quantizer_type; - }; - - FaissCpuIVFSQ(Metric metric, int dim, const BuildParam& param) : FaissCpu(metric, dim, param) - { - faiss::ScalarQuantizer::QuantizerType qtype; - if (param.quantizer_type == "fp16") { - qtype = faiss::ScalarQuantizer::QT_fp16; - } else if (param.quantizer_type == "int8") { - qtype = faiss::ScalarQuantizer::QT_8bit; - } else { - throw std::runtime_error("FaissCpuIVFSQ supports only fp16 and int8 but got " + - param.quantizer_type); - } - - this->init_quantizer(dim); - this->index_ = std::make_shared( - this->quantizer_.get(), dim, param.nlist, qtype, this->metric_type_, true); - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override - { - this->template load_(file); - } - - std::unique_ptr> copy() - { - return std::make_unique>(*this); // use copy constructor - } -}; - -template -class FaissCpuFlat : public FaissCpu { - public: - FaissCpuFlat(Metric metric, int dim) - : FaissCpu(metric, dim, typename FaissCpu::BuildParam{}) - { - this->index_ = std::make_shared(dim, this->metric_type_); - } - - // class FaissCpu is more like a IVF class, so need special treating here - void set_search_param(const typename ANN::AnnSearchParam& param) override - { - auto search_param = dynamic_cast::SearchParam&>(param); - if (!this->thread_pool_ || this->num_threads_ != search_param.num_threads) { - this->num_threads_ = search_param.num_threads; - this->thread_pool_ = std::make_shared(this->num_threads_); - } - }; - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override { this->template load_(file); } - - std::unique_ptr> copy() - { - return std::make_unique>(*this); // use copy constructor - } -}; - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu b/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu deleted file mode 100644 index b47c497e3d..0000000000 --- a/cpp/bench/ann/src/faiss/faiss_gpu_benchmark.cu +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" - -#undef WARP_SIZE -#include "faiss_gpu_wrapper.h" - -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -void parse_base_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissGpu::BuildParam& param) -{ - param.nlist = conf.at("nlist"); - if (conf.contains("ratio")) { param.ratio = conf.at("ratio"); } -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissGpuIVFFlat::BuildParam& param) -{ - parse_base_build_param(conf, param); - if (conf.contains("use_raft")) { - param.use_raft = conf.at("use_raft"); - } else { - param.use_raft = false; - } -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissGpuIVFPQ::BuildParam& param) -{ - parse_base_build_param(conf, param); - param.M = conf.at("M"); - if (conf.contains("usePrecomputed")) { - param.usePrecomputed = conf.at("usePrecomputed"); - } else { - param.usePrecomputed = false; - } - if (conf.contains("useFloat16")) { - param.useFloat16 = conf.at("useFloat16"); - } else { - param.useFloat16 = false; - } - if (conf.contains("use_raft")) { - param.use_raft = conf.at("use_raft"); - } else { - param.use_raft = false; - } - if (conf.contains("bitsPerCode")) { - param.bitsPerCode = conf.at("bitsPerCode"); - } else { - param.bitsPerCode = 8; - } -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissGpuIVFSQ::BuildParam& param) -{ - parse_base_build_param(conf, param); - param.quantizer_type = conf.at("quantizer_type"); -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::FaissGpu::SearchParam& param) -{ - param.nprobe = conf.at("nprobe"); - if (conf.contains("refine_ratio")) { param.refine_ratio = conf.at("refine_ratio"); } -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - return std::make_unique>(metric, dim, param); -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - - (void)dev_list; - return std::make_unique>(metric, dim, param); -} - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - std::unique_ptr> ann; - - if constexpr (std::is_same_v) { - raft::bench::ann::Metric metric = parse_metric(distance); - if (algo == "faiss_gpu_ivf_flat") { - ann = make_algo(metric, dim, conf, dev_list); - } else if (algo == "faiss_gpu_ivf_pq") { - ann = make_algo(metric, dim, conf); - } else if (algo == "faiss_gpu_ivf_sq") { - ann = make_algo(metric, dim, conf); - } else if (algo == "faiss_gpu_flat") { - ann = std::make_unique>(metric, dim); - } - } - - if constexpr (std::is_same_v) {} - - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - if (algo == "faiss_gpu_ivf_flat" || algo == "faiss_gpu_ivf_pq" || algo == "faiss_gpu_ivf_sq") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } else if (algo == "faiss_gpu_flat") { - auto param = std::make_unique::SearchParam>(); - return param; - } - // else - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -} // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) -{ - rmm::mr::cuda_memory_resource cuda_mr; - // Construct a resource that uses a coalescing best-fit pool allocator - // and is initially sized to half of free device memory. - rmm::mr::pool_memory_resource pool_mr{ - &cuda_mr, rmm::percent_of_free_device_memory(50)}; - // Updates the current device resource pointer to `pool_mr` - auto old_mr = rmm::mr::set_current_device_resource(&pool_mr); - auto ret = raft::bench::ann::run_main(argc, argv); - // Restores the current device resource pointer to its previous value - rmm::mr::set_current_device_resource(old_mr); - return ret; -} -#endif diff --git a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h b/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h deleted file mode 100644 index 6955201c5d..0000000000 --- a/cpp/bench/ann/src/faiss/faiss_gpu_wrapper.h +++ /dev/null @@ -1,515 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#ifndef FAISS_WRAPPER_H_ -#define FAISS_WRAPPER_H_ - -#include "../common/ann_types.hpp" -#include "../raft/raft_ann_bench_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace { - -faiss::MetricType parse_metric_faiss(raft::bench::ann::Metric metric) -{ - if (metric == raft::bench::ann::Metric::kInnerProduct) { - return faiss::METRIC_INNER_PRODUCT; - } else if (metric == raft::bench::ann::Metric::kEuclidean) { - return faiss::METRIC_L2; - } else { - throw std::runtime_error("faiss supports only metric type of inner product and L2"); - } -} - -// note BLAS library can still use multi-threading, and -// setting environment variable like OPENBLAS_NUM_THREADS can control it -class OmpSingleThreadScope { - public: - OmpSingleThreadScope() - { - max_threads_ = omp_get_max_threads(); - omp_set_num_threads(1); - } - ~OmpSingleThreadScope() - { - // the best we can do - omp_set_num_threads(max_threads_); - } - - private: - int max_threads_; -}; - -} // namespace - -namespace raft::bench::ann { - -template -class FaissGpu : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - struct SearchParam : public AnnSearchParam { - int nprobe; - float refine_ratio = 1.0; - auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; } - }; - - struct BuildParam { - int nlist = 1; - int ratio = 2; - }; - - FaissGpu(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), - gpu_resource_{std::make_shared()}, - metric_type_(parse_metric_faiss(metric)), - nlist_{param.nlist}, - training_sample_fraction_{1.0 / double(param.ratio)} - { - static_assert(std::is_same_v, "faiss support only float type"); - RAFT_CUDA_TRY(cudaGetDevice(&device_)); - } - - void build(const T* dataset, size_t nrow) final; - - virtual void set_search_param(const FaissGpu::AnnSearchParam& param) {} - - void set_search_dataset(const T* dataset, size_t nrow) override { dataset_ = dataset; } - - // TODO: if the number of results is less than k, the remaining elements of 'neighbors' - // will be filled with (size_t)-1 - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const final; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return gpu_resource_->getDefaultStream(device_); - } - - AlgoProperty get_preference() const override - { - AlgoProperty property; - // to enable building big dataset which is larger than GPU memory - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Device; - return property; - } - - protected: - template - void save_(const std::string& file) const; - - template - void load_(const std::string& file); - - /** [NOTE Multithreading] - * - * `gpu_resource_` is a shared resource: - * 1. It uses a shared_ptr under the hood, so the copies of it refer to the same - * resource implementation instance - * 2. GpuIndex is probably keeping a reference to it, as it's passed to the constructor - * - * To avoid copying the index (database) in each thread, we make both the index and - * the gpu_resource shared. - * This means faiss GPU streams are possibly shared among the CPU threads; - * the throughput search mode may be inaccurate. - * - * WARNING: we haven't investigated whether faiss::gpu::GpuIndex or - * faiss::gpu::StandardGpuResources are thread-safe. - * - */ - mutable std::shared_ptr gpu_resource_; - std::shared_ptr index_; - std::shared_ptr index_refine_{nullptr}; - faiss::MetricType metric_type_; - int nlist_; - int device_; - double training_sample_fraction_; - std::shared_ptr search_params_; - std::shared_ptr refine_search_params_{nullptr}; - const T* dataset_; - float refine_ratio_ = 1.0; - Objective metric_objective_; -}; - -template -void FaissGpu::build(const T* dataset, size_t nrow) -{ - OmpSingleThreadScope omp_single_thread; - auto index_ivf = dynamic_cast(index_.get()); - if (index_ivf != nullptr) { - // set the min/max training size for clustering to use the whole provided training set. - double trainset_size = training_sample_fraction_ * static_cast(nrow); - double points_per_centroid = trainset_size / static_cast(nlist_); - int max_ppc = std::ceil(points_per_centroid); - int min_ppc = std::floor(points_per_centroid); - if (min_ppc < index_ivf->cp.min_points_per_centroid) { - RAFT_LOG_WARN( - "The suggested training set size %zu (data size %zu, training sample ratio %f) yields %d " - "points per cluster (n_lists = %d). This is smaller than the FAISS default " - "min_points_per_centroid = %d.", - static_cast(trainset_size), - nrow, - training_sample_fraction_, - min_ppc, - nlist_, - index_ivf->cp.min_points_per_centroid); - } - index_ivf->cp.max_points_per_centroid = max_ppc; - index_ivf->cp.min_points_per_centroid = min_ppc; - } - index_->train(nrow, dataset); // faiss::gpu::GpuIndexFlat::train() will do nothing - assert(index_->is_trained); - index_->add(nrow, dataset); -} - -template -void FaissGpu::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - ASSERT(Objective::LATENCY, "l2Knn: rowMajorIndex and rowMajorQuery should have same layout"); - using IdxT = faiss::idx_t; - static_assert(sizeof(size_t) == sizeof(faiss::idx_t), - "sizes of size_t and faiss::idx_t are different"); - - if (refine_ratio_ > 1.0) { - if (raft::get_device_for_address(queries) >= 0) { - uint32_t k0 = static_cast(refine_ratio_ * k); - auto distances_tmp = raft::make_device_matrix( - gpu_resource_->getRaftHandle(device_), batch_size, k0); - auto candidates = - raft::make_device_matrix(gpu_resource_->getRaftHandle(device_), batch_size, k0); - index_->search(batch_size, - queries, - k0, - distances_tmp.data_handle(), - candidates.data_handle(), - this->search_params_.get()); - - auto queries_host = raft::make_host_matrix(batch_size, index_->d); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - auto dataset_v = raft::make_host_matrix_view( - this->dataset_, index_->ntotal, index_->d); - - raft::device_resources handle_ = gpu_resource_->getRaftHandle(device_); - - raft::copy(queries_host.data_handle(), queries, queries_host.size(), handle_.get_stream()); - raft::copy(candidates_host.data_handle(), - candidates.data_handle(), - candidates_host.size(), - handle_.get_stream()); - - // wait for the queries to copy to host in 'stream` - handle_.sync_stream(); - - raft::runtime::neighbors::refine(handle_, - dataset_v, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - parse_metric_type(this->metric_)); - - raft::copy(neighbors, - (size_t*)neighbors_host.data_handle(), - neighbors_host.size(), - handle_.get_stream()); - raft::copy( - distances, distances_host.data_handle(), distances_host.size(), handle_.get_stream()); - } else { - index_refine_->search(batch_size, - queries, - k, - distances, - reinterpret_cast(neighbors), - this->refine_search_params_.get()); - } - } else { - index_->search(batch_size, - queries, - k, - distances, - reinterpret_cast(neighbors), - this->search_params_.get()); - } -} - -template -template -void FaissGpu::save_(const std::string& file) const -{ - OmpSingleThreadScope omp_single_thread; - - auto cpu_index = std::make_unique(); - dynamic_cast(index_.get())->copyTo(cpu_index.get()); - faiss::write_index(cpu_index.get(), file.c_str()); -} - -template -template -void FaissGpu::load_(const std::string& file) -{ - OmpSingleThreadScope omp_single_thread; - - std::unique_ptr cpu_index(dynamic_cast(faiss::read_index(file.c_str()))); - assert(cpu_index); - - try { - dynamic_cast(index_.get())->copyFrom(cpu_index.get()); - - } catch (const std::exception& e) { - std::cout << "Error loading index file: " << std::string(e.what()) << std::endl; - } -} - -template -class FaissGpuIVFFlat : public FaissGpu { - public: - struct BuildParam : public FaissGpu::BuildParam { - bool use_raft; - }; - - FaissGpuIVFFlat(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) - { - faiss::gpu::GpuIndexIVFFlatConfig config; - config.device = this->device_; - config.use_raft = param.use_raft; - this->index_ = std::make_shared( - this->gpu_resource_.get(), dim, param.nlist, this->metric_type_, config); - } - - void set_search_param(const typename FaissGpu::AnnSearchParam& param) override - { - auto search_param = dynamic_cast::SearchParam&>(param); - int nprobe = search_param.nprobe; - assert(nprobe <= nlist_); - - faiss::IVFSearchParameters faiss_search_params; - faiss_search_params.nprobe = nprobe; - this->search_params_ = std::make_shared(faiss_search_params); - this->refine_ratio_ = search_param.refine_ratio; - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override - { - this->template load_(file); - } - std::unique_ptr> copy() override { return std::make_unique>(*this); }; -}; - -template -class FaissGpuIVFPQ : public FaissGpu { - public: - struct BuildParam : public FaissGpu::BuildParam { - int M; - bool useFloat16; - bool usePrecomputed; - bool use_raft; - int bitsPerCode; - }; - - FaissGpuIVFPQ(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) - { - faiss::gpu::GpuIndexIVFPQConfig config; - config.useFloat16LookupTables = param.useFloat16; - config.usePrecomputedTables = param.usePrecomputed; - config.use_raft = param.use_raft; - config.interleavedLayout = param.use_raft; - config.device = this->device_; - - this->index_ = std::make_shared(this->gpu_resource_.get(), - dim, - param.nlist, - param.M, - param.bitsPerCode, - this->metric_type_, - config); - } - - void set_search_param(const typename FaissGpu::AnnSearchParam& param) override - { - auto search_param = dynamic_cast::SearchParam&>(param); - int nprobe = search_param.nprobe; - assert(nprobe <= nlist_); - this->refine_ratio_ = search_param.refine_ratio; - faiss::IVFPQSearchParameters faiss_search_params; - faiss_search_params.nprobe = nprobe; - - this->search_params_ = std::make_shared(faiss_search_params); - - if (search_param.refine_ratio > 1.0) { - this->index_refine_ = - std::make_shared(this->index_.get(), this->dataset_); - this->index_refine_.get()->k_factor = search_param.refine_ratio; - faiss::IndexRefineSearchParameters faiss_refine_search_params; - faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor; - faiss_refine_search_params.base_index_params = this->search_params_.get(); - this->refine_search_params_ = - std::make_unique(faiss_refine_search_params); - } - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override - { - this->template load_(file); - } - std::unique_ptr> copy() override { return std::make_unique>(*this); }; -}; - -// TODO: Enable this in cmake -// ref: https://github.com/rapidsai/raft/issues/1876 -template -class FaissGpuIVFSQ : public FaissGpu { - public: - struct BuildParam : public FaissGpu::BuildParam { - std::string quantizer_type; - }; - - FaissGpuIVFSQ(Metric metric, int dim, const BuildParam& param) : FaissGpu(metric, dim, param) - { - faiss::ScalarQuantizer::QuantizerType qtype; - if (param.quantizer_type == "fp16") { - qtype = faiss::ScalarQuantizer::QT_fp16; - } else if (param.quantizer_type == "int8") { - qtype = faiss::ScalarQuantizer::QT_8bit; - } else { - throw std::runtime_error("FaissGpuIVFSQ supports only fp16 and int8 but got " + - param.quantizer_type); - } - - faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; - config.device = this->device_; - this->index_ = std::make_shared( - this->gpu_resource_.get(), dim, param.nlist, qtype, this->metric_type_, true, config); - } - - void set_search_param(const typename FaissGpu::AnnSearchParam& param) override - { - auto search_param = dynamic_cast::SearchParam&>(param); - int nprobe = search_param.nprobe; - assert(nprobe <= nlist_); - - faiss::IVFSearchParameters faiss_search_params; - faiss_search_params.nprobe = nprobe; - - this->search_params_ = std::make_shared(faiss_search_params); - this->refine_ratio_ = search_param.refine_ratio; - if (search_param.refine_ratio > 1.0) { - this->index_refine_ = - std::make_shared(this->index_.get(), this->dataset_); - this->index_refine_.get()->k_factor = search_param.refine_ratio; - faiss::IndexRefineSearchParameters faiss_refine_search_params; - faiss_refine_search_params.k_factor = this->index_refine_.get()->k_factor; - faiss_refine_search_params.base_index_params = this->search_params_.get(); - this->refine_search_params_ = - std::make_unique(faiss_refine_search_params); - } - } - - void save(const std::string& file) const override - { - this->template save_( - file); - } - void load(const std::string& file) override - { - this->template load_( - file); - } - std::unique_ptr> copy() override { return std::make_unique>(*this); }; -}; - -template -class FaissGpuFlat : public FaissGpu { - public: - FaissGpuFlat(Metric metric, int dim) - : FaissGpu(metric, dim, typename FaissGpu::BuildParam{}) - { - faiss::gpu::GpuIndexFlatConfig config; - config.device = this->device_; - this->index_ = std::make_shared( - this->gpu_resource_.get(), dim, this->metric_type_, config); - } - void set_search_param(const typename FaissGpu::AnnSearchParam& param) override - { - auto search_param = dynamic_cast::SearchParam&>(param); - int nprobe = search_param.nprobe; - assert(nprobe <= nlist_); - - this->search_params_ = std::make_shared(); - } - - void save(const std::string& file) const override - { - this->template save_(file); - } - void load(const std::string& file) override - { - this->template load_(file); - } - std::unique_ptr> copy() override { return std::make_unique>(*this); }; -}; - -} // namespace raft::bench::ann - -#endif diff --git a/cpp/bench/ann/src/ggnn/ggnn_benchmark.cu b/cpp/bench/ann/src/ggnn/ggnn_benchmark.cu deleted file mode 100644 index 48d41388d4..0000000000 --- a/cpp/bench/ann/src/ggnn/ggnn_benchmark.cu +++ /dev/null @@ -1,130 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" -#include "ggnn_wrapper.cuh" - -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::Ggnn::BuildParam& param) -{ - param.k = conf.at("k"); - - if (conf.contains("k_build")) { param.k_build = conf.at("k_build"); } - if (conf.contains("segment_size")) { param.segment_size = conf.at("segment_size"); } - if (conf.contains("num_layers")) { param.num_layers = conf.at("num_layers"); } - if (conf.contains("tau")) { param.tau = conf.at("tau"); } - if (conf.contains("refine_iterations")) { - param.refine_iterations = conf.at("refine_iterations"); - } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::Ggnn::SearchParam& param) -{ - param.tau = conf.at("tau"); - - if (conf.contains("block_dim")) { param.block_dim = conf.at("block_dim"); } - if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); } - if (conf.contains("cache_size")) { param.cache_size = conf.at("cache_size"); } - if (conf.contains("sorted_size")) { param.sorted_size = conf.at("sorted_size"); } -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - return std::make_unique>(metric, dim, param); -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - - (void)dev_list; - return std::make_unique>(metric, dim, param); -} - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - raft::bench::ann::Metric metric = parse_metric(distance); - std::unique_ptr> ann; - - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "ggnn") { ann = make_algo(metric, dim, conf); } - } - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "ggnn") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } - } - // else - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -} // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } -#endif diff --git a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh b/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh deleted file mode 100644 index 59cf3df806..0000000000 --- a/cpp/bench/ann/src/ggnn/ggnn_wrapper.cuh +++ /dev/null @@ -1,322 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../common/ann_types.hpp" -#include "../common/util.hpp" - -#include - -#include - -#include -#include - -namespace raft::bench::ann { - -template -class GgnnImpl; - -template -class Ggnn : public ANN, public AnnGPU { - public: - struct BuildParam { - int k_build{24}; // KBuild - int segment_size{32}; // S - int num_layers{4}; // L - float tau{0.5}; - int refine_iterations{2}; - int k; // GGNN requires to know k during building - }; - - using typename ANN::AnnSearchParam; - struct SearchParam : public AnnSearchParam { - float tau; - int block_dim{32}; - int max_iterations{400}; - int cache_size{512}; - int sorted_size{256}; - auto needs_dataset() const -> bool override { return true; } - }; - - Ggnn(Metric metric, int dim, const BuildParam& param); - - void build(const T* dataset, size_t nrow) override { impl_->build(dataset, nrow); } - - void set_search_param(const AnnSearchParam& param) override { impl_->set_search_param(param); } - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override - { - impl_->search(queries, batch_size, k, neighbors, distances); - } - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return dynamic_cast(impl_.get())->get_sync_stream(); - } - - void save(const std::string& file) const override { impl_->save(file); } - void load(const std::string& file) override { impl_->load(file); } - std::unique_ptr> copy() override { return std::make_unique>(*this); }; - - AlgoProperty get_preference() const override { return impl_->get_preference(); } - - void set_search_dataset(const T* dataset, size_t nrow) override - { - impl_->set_search_dataset(dataset, nrow); - }; - - private: - std::shared_ptr> impl_; -}; - -template -Ggnn::Ggnn(Metric metric, int dim, const BuildParam& param) : ANN(metric, dim) -{ - // ggnn/src/sift1m.cu - if (metric == Metric::kEuclidean && dim == 128 && param.k_build == 24 && param.k == 10 && - param.segment_size == 32) { - impl_ = std::make_shared>(metric, dim, param); - } - // ggnn/src/deep1b_multi_gpu.cu, and adapt it deep1B - else if (metric == Metric::kEuclidean && dim == 96 && param.k_build == 24 && param.k == 10 && - param.segment_size == 32) { - impl_ = std::make_shared>(metric, dim, param); - } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 24 && param.k == 10 && - param.segment_size == 32) { - impl_ = std::make_shared>(metric, dim, param); - } else if (metric == Metric::kInnerProduct && dim == 96 && param.k_build == 96 && param.k == 10 && - param.segment_size == 64) { - impl_ = std::make_shared>(metric, dim, param); - } - // ggnn/src/glove200.cu, adapt it to glove100 - else if (metric == Metric::kInnerProduct && dim == 100 && param.k_build == 96 && param.k == 10 && - param.segment_size == 64) { - impl_ = std::make_shared>(metric, dim, param); - } else { - throw std::runtime_error( - "ggnn: not supported combination of metric, dim and build param; " - "see Ggnn's constructor in ggnn_wrapper.cuh for available combinations"); - } -} - -template -class GgnnImpl : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - - GgnnImpl(Metric metric, int dim, const typename Ggnn::BuildParam& param); - - void build(const T* dataset, size_t nrow) override; - - void set_search_param(const AnnSearchParam& param) override; - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override; - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override { return stream_; } - - void save(const std::string& file) const override; - void load(const std::string& file) override; - std::unique_ptr> copy() override - { - auto r = std::make_unique>(*this); - // set the thread-local stream to the copied handle. - r->stream_ = raft::bench::ann::get_stream_from_global_pool(); - return r; - }; - - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::Device; - property.query_memory_type = MemoryType::Device; - return property; - } - - void set_search_dataset(const T* dataset, size_t nrow) override; - - private: - using ANN::metric_; - using ANN::dim_; - - using GGNNGPUInstance = GGNNGPUInstance; - std::shared_ptr ggnn_; - typename Ggnn::BuildParam build_param_; - typename Ggnn::SearchParam search_param_; - cudaStream_t stream_; - const T* base_dataset = nullptr; - size_t base_n_rows = 0; - std::optional graph_file = std::nullopt; - - void load_impl() - { - if (base_dataset == nullptr) { return; } - if (base_n_rows == 0) { return; } - int device; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - ggnn_ = std::make_shared( - device, base_n_rows, build_param_.num_layers, true, build_param_.tau); - ggnn_->set_base_data(base_dataset); - ggnn_->set_stream(get_sync_stream()); - if (graph_file.has_value()) { - auto& ggnn_host = ggnn_->ggnn_cpu_buffers.at(0); - auto& ggnn_device = ggnn_->ggnn_shards.at(0); - ggnn_->set_stream(get_sync_stream()); - - ggnn_host.load(graph_file.value()); - ggnn_host.uploadAsync(ggnn_device); - RAFT_CUDA_TRY(cudaStreamSynchronize(ggnn_device.stream)); - } - } -}; - -template -GgnnImpl::GgnnImpl(Metric metric, - int dim, - const typename Ggnn::BuildParam& param) - : ANN(metric, dim), - build_param_(param), - stream_(raft::bench::ann::get_stream_from_global_pool()) -{ - if (metric_ == Metric::kInnerProduct) { - if (measure != Cosine) { throw std::runtime_error("mis-matched metric"); } - } else if (metric_ == Metric::kEuclidean) { - if (measure != Euclidean) { throw std::runtime_error("mis-matched metric"); } - } else { - throw std::runtime_error( - "ggnn supports only metric type of InnerProduct, Cosine and Euclidean"); - } - - if (dim != D) { throw std::runtime_error("mis-matched dim"); } -} - -template -void GgnnImpl::build(const T* dataset, size_t nrow) -{ - base_dataset = dataset; - base_n_rows = nrow; - graph_file = std::nullopt; - load_impl(); - ggnn_->build(0); - for (int i = 0; i < build_param_.refine_iterations; ++i) { - ggnn_->refine(); - } -} - -template -void GgnnImpl::set_search_dataset(const T* dataset, size_t nrow) -{ - if (base_dataset != dataset || base_n_rows != nrow) { - base_dataset = dataset; - base_n_rows = nrow; - load_impl(); - } -} - -template -void GgnnImpl::set_search_param(const AnnSearchParam& param) -{ - search_param_ = dynamic_cast::SearchParam&>(param); -} - -template -void GgnnImpl::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - static_assert(sizeof(size_t) == sizeof(int64_t), "sizes of size_t and GGNN's KeyT are different"); - if (k != KQuery) { - throw std::runtime_error( - "k = " + std::to_string(k) + - ", but this GGNN instance only supports k = " + std::to_string(KQuery)); - } - - ggnn_->set_stream(get_sync_stream()); - RAFT_CUDA_TRY(cudaMemcpyToSymbol(c_tau_query, &search_param_.tau, sizeof(float))); - - const int block_dim = search_param_.block_dim; - const int max_iterations = search_param_.max_iterations; - const int cache_size = search_param_.cache_size; - const int sorted_size = search_param_.sorted_size; - // default value - if (block_dim == 32 && max_iterations == 400 && cache_size == 512 && sorted_size == 256) { - ggnn_->template queryLayer<32, 400, 512, 256, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } - // ggnn/src/sift1m.cu - else if (block_dim == 32 && max_iterations == 200 && cache_size == 256 && sorted_size == 64) { - ggnn_->template queryLayer<32, 200, 256, 64, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } - // ggnn/src/sift1m.cu - else if (block_dim == 32 && max_iterations == 400 && cache_size == 448 && sorted_size == 64) { - ggnn_->template queryLayer<32, 400, 448, 64, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } - // ggnn/src/glove200.cu - else if (block_dim == 128 && max_iterations == 2000 && cache_size == 2048 && sorted_size == 32) { - ggnn_->template queryLayer<128, 2000, 2048, 32, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } - // for glove100 - else if (block_dim == 64 && max_iterations == 400 && cache_size == 512 && sorted_size == 32) { - ggnn_->template queryLayer<64, 400, 512, 32, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } else if (block_dim == 128 && max_iterations == 2000 && cache_size == 1024 && - sorted_size == 32) { - ggnn_->template queryLayer<128, 2000, 1024, 32, false>( - queries, batch_size, reinterpret_cast(neighbors), distances); - } else { - throw std::runtime_error("ggnn: not supported search param"); - } -} - -template -void GgnnImpl::save(const std::string& file) const -{ - auto& ggnn_host = ggnn_->ggnn_cpu_buffers.at(0); - auto& ggnn_device = ggnn_->ggnn_shards.at(0); - ggnn_->set_stream(get_sync_stream()); - - ggnn_host.downloadAsync(ggnn_device); - RAFT_CUDA_TRY(cudaStreamSynchronize(ggnn_device.stream)); - ggnn_host.store(file); -} - -template -void GgnnImpl::load(const std::string& file) -{ - if (!graph_file.has_value() || graph_file.value() != file) { - graph_file = file; - load_impl(); - } -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_benchmark.cpp b/cpp/bench/ann/src/hnswlib/hnswlib_benchmark.cpp deleted file mode 100644 index df82c68830..0000000000 --- a/cpp/bench/ann/src/hnswlib/hnswlib_benchmark.cpp +++ /dev/null @@ -1,120 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" -#include "hnswlib_wrapper.h" - -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::HnswLib::BuildParam& param) -{ - param.ef_construction = conf.at("efConstruction"); - param.M = conf.at("M"); - if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::HnswLib::SearchParam& param) -{ - param.ef = conf.at("ef"); - if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - return std::make_unique>(metric, dim, param); -} - -template class Algo> -std::unique_ptr> make_algo(raft::bench::ann::Metric metric, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - typename Algo::BuildParam param; - parse_build_param(conf, param); - - (void)dev_list; - return std::make_unique>(metric, dim, param); -} - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - raft::bench::ann::Metric metric = parse_metric(distance); - std::unique_ptr> ann; - - if constexpr (std::is_same_v) { - if (algo == "hnswlib") { ann = make_algo(metric, dim, conf); } - } - - if constexpr (std::is_same_v) { - if (algo == "hnswlib") { ann = make_algo(metric, dim, conf); } - } - - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - if (algo == "hnswlib") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } - // else - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -}; // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } -#endif diff --git a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h b/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h deleted file mode 100644 index 5743632bf4..0000000000 --- a/cpp/bench/ann/src/hnswlib/hnswlib_wrapper.h +++ /dev/null @@ -1,239 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "../common/thread_pool.hpp" - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -struct hnsw_dist_t { - using type = void; -}; - -template <> -struct hnsw_dist_t { - using type = float; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - -template -class HnswLib : public ANN { - public: - // https://github.com/nmslib/hnswlib/blob/master/ALGO_PARAMS.md - struct BuildParam { - int M; - int ef_construction; - int num_threads = omp_get_num_procs(); - }; - - using typename ANN::AnnSearchParam; - struct SearchParam : public AnnSearchParam { - int ef; - int num_threads = 1; - }; - - HnswLib(Metric metric, int dim, const BuildParam& param); - - void build(const T* dataset, size_t nrow) override; - - void set_search_param(const AnnSearchParam& param) override; - void search(const T* query, - int batch_size, - int k, - AnnBase::index_type* indices, - float* distances) const override; - - void save(const std::string& path_to_index) const override; - void load(const std::string& path_to_index) override; - std::unique_ptr> copy() override { return std::make_unique>(*this); }; - - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Host; - return property; - } - - void set_base_layer_only() { appr_alg_->base_layer_only = true; } - - private: - void get_search_knn_results_(const T* query, - int k, - AnnBase::index_type* indices, - float* distances) const; - - std::shared_ptr::type>> appr_alg_; - std::shared_ptr::type>> space_; - - using ANN::metric_; - using ANN::dim_; - int ef_construction_; - int m_; - int num_threads_; - std::shared_ptr thread_pool_; - Objective metric_objective_; -}; - -template -HnswLib::HnswLib(Metric metric, int dim, const BuildParam& param) : ANN(metric, dim) -{ - assert(dim_ > 0); - static_assert(std::is_same_v || std::is_same_v); - if constexpr (std::is_same_v) { - if (metric_ != Metric::kEuclidean) { - throw std::runtime_error("hnswlib only supports Euclidean distance"); - } - } - - ef_construction_ = param.ef_construction; - m_ = param.M; - num_threads_ = param.num_threads; -} - -template -void HnswLib::build(const T* dataset, size_t nrow) -{ - if constexpr (std::is_same_v) { - if (metric_ == Metric::kInnerProduct) { - space_ = std::make_shared(dim_); - } else { - space_ = std::make_shared(dim_); - } - } else if constexpr (std::is_same_v) { - space_ = std::make_shared>(dim_); - } - - appr_alg_ = std::make_shared::type>>( - space_.get(), nrow, m_, ef_construction_); - - thread_pool_ = std::make_shared(num_threads_); - const size_t items_per_thread = nrow / (num_threads_ + 1); - - thread_pool_->submit( - [&](size_t i) { - if (i < items_per_thread && i % 10000 == 0) { - char buf[20]; - std::time_t now = std::time(nullptr); - std::strftime(buf, sizeof(buf), "%Y-%m-%d %H:%M:%S", std::localtime(&now)); - printf("%s building %zu / %zu\n", buf, i, items_per_thread); - fflush(stdout); - } - - appr_alg_->addPoint(dataset + i * dim_, i); - }, - nrow); -} - -template -void HnswLib::set_search_param(const AnnSearchParam& param_) -{ - auto param = dynamic_cast(param_); - appr_alg_->ef_ = param.ef; - metric_objective_ = param.metric_objective; - num_threads_ = param.num_threads; - - // Create a pool if multiple query threads have been set and the pool hasn't been created already - bool create_pool = (metric_objective_ == Objective::LATENCY && num_threads_ > 1 && !thread_pool_); - if (create_pool) { thread_pool_ = std::make_shared(num_threads_); } -} - -template -void HnswLib::search( - const T* query, int batch_size, int k, AnnBase::index_type* indices, float* distances) const -{ - auto f = [&](int i) { - // hnsw can only handle a single vector at a time. - get_search_knn_results_(query + i * dim_, k, indices + i * k, distances + i * k); - }; - if (metric_objective_ == Objective::LATENCY && num_threads_ > 1) { - thread_pool_->submit(f, batch_size); - } else { - for (int i = 0; i < batch_size; i++) { - f(i); - } - } -} - -template -void HnswLib::save(const std::string& path_to_index) const -{ - appr_alg_->saveIndex(std::string(path_to_index)); -} - -template -void HnswLib::load(const std::string& path_to_index) -{ - if constexpr (std::is_same_v) { - if (metric_ == Metric::kInnerProduct) { - space_ = std::make_shared(dim_); - } else { - space_ = std::make_shared(dim_); - } - } else if constexpr (std::is_same_v) { - space_ = std::make_shared>(dim_); - } - - appr_alg_ = std::make_shared::type>>( - space_.get(), path_to_index); -} - -template -void HnswLib::get_search_knn_results_(const T* query, - int k, - AnnBase::index_type* indices, - float* distances) const -{ - auto result = appr_alg_->searchKnn(query, k); - assert(result.size() >= static_cast(k)); - - for (int i = k - 1; i >= 0; --i) { - indices[i] = result.top().second; - distances[i] = result.top().first; - result.pop(); - } -} - -}; // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h b/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h deleted file mode 100644 index 48bf1d70d8..0000000000 --- a/cpp/bench/ann/src/raft/raft_ann_bench_param_parser.h +++ /dev/null @@ -1,275 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#define JSON_DIAGNOSTICS 1 -#include - -#undef WARP_SIZE -#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE -#include "raft_wrapper.h" -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT -#include "raft_ivf_flat_wrapper.h" -extern template class raft::bench::ann::RaftIvfFlatGpu; -extern template class raft::bench::ann::RaftIvfFlatGpu; -extern template class raft::bench::ann::RaftIvfFlatGpu; -#endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) -#include "raft_ivf_pq_wrapper.h" -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ -extern template class raft::bench::ann::RaftIvfPQ; -extern template class raft::bench::ann::RaftIvfPQ; -extern template class raft::bench::ann::RaftIvfPQ; -#endif -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) -#include "raft_cagra_wrapper.h" -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA -extern template class raft::bench::ann::RaftCagra; -extern template class raft::bench::ann::RaftCagra; -extern template class raft::bench::ann::RaftCagra; -extern template class raft::bench::ann::RaftCagra; -#endif - -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::BuildParam& param) -{ - param.n_lists = conf.at("nlist"); - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfFlatGpu::SearchParam& param) -{ - param.ivf_flat_params.n_probes = conf.at("nprobe"); -} -#endif - -#if defined(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || \ - defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::BuildParam& param) -{ - if (conf.contains("nlist")) { param.n_lists = conf.at("nlist"); } - if (conf.contains("niter")) { param.kmeans_n_iters = conf.at("niter"); } - if (conf.contains("ratio")) { param.kmeans_trainset_fraction = 1.0 / (double)conf.at("ratio"); } - if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } - if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } - if (conf.contains("codebook_kind")) { - std::string kind = conf.at("codebook_kind"); - if (kind == "cluster") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_CLUSTER; - } else if (kind == "subspace") { - param.codebook_kind = raft::neighbors::ivf_pq::codebook_gen::PER_SUBSPACE; - } else { - throw std::runtime_error("codebook_kind: '" + kind + - "', should be either 'cluster' or 'subspace'"); - } - } -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftIvfPQ::SearchParam& param) -{ - if (conf.contains("nprobe")) { param.pq_param.n_probes = conf.at("nprobe"); } - if (conf.contains("internalDistanceDtype")) { - std::string type = conf.at("internalDistanceDtype"); - if (type == "float") { - param.pq_param.internal_distance_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } else { - throw std::runtime_error("internalDistanceDtype: '" + type + - "', should be either 'float' or 'half'"); - } - } else { - // set half as default type - param.pq_param.internal_distance_dtype = CUDA_R_16F; - } - - if (conf.contains("smemLutDtype")) { - std::string type = conf.at("smemLutDtype"); - if (type == "float") { - param.pq_param.lut_dtype = CUDA_R_32F; - } else if (type == "half") { - param.pq_param.lut_dtype = CUDA_R_16F; - } else if (type == "fp8") { - param.pq_param.lut_dtype = CUDA_R_8U; - } else { - throw std::runtime_error("smemLutDtype: '" + type + - "', should be either 'float', 'half' or 'fp8'"); - } - } else { - // set half as default - param.pq_param.lut_dtype = CUDA_R_16F; - } - if (conf.contains("refine_ratio")) { - param.refine_ratio = conf.at("refine_ratio"); - if (param.refine_ratio < 1.0f) { throw std::runtime_error("refine_ratio should be >= 1.0"); } - } -} -#endif - -#if defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA) || defined(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB) -template -void parse_build_param(const nlohmann::json& conf, - raft::neighbors::experimental::nn_descent::index_params& param) -{ - if (conf.contains("graph_degree")) { param.graph_degree = conf.at("graph_degree"); } - if (conf.contains("intermediate_graph_degree")) { - param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); - } - // we allow niter shorthand for max_iterations - if (conf.contains("niter")) { param.max_iterations = conf.at("niter"); } - if (conf.contains("max_iterations")) { param.max_iterations = conf.at("max_iterations"); } - if (conf.contains("termination_threshold")) { - param.termination_threshold = conf.at("termination_threshold"); - } -} - -inline void parse_build_param(const nlohmann::json& conf, raft::neighbors::vpq_params& param) -{ - if (conf.contains("pq_bits")) { param.pq_bits = conf.at("pq_bits"); } - if (conf.contains("pq_dim")) { param.pq_dim = conf.at("pq_dim"); } - if (conf.contains("vq_n_centers")) { param.vq_n_centers = conf.at("vq_n_centers"); } - if (conf.contains("kmeans_n_iters")) { param.kmeans_n_iters = conf.at("kmeans_n_iters"); } - if (conf.contains("vq_kmeans_trainset_fraction")) { - param.vq_kmeans_trainset_fraction = conf.at("vq_kmeans_trainset_fraction"); - } - if (conf.contains("pq_kmeans_trainset_fraction")) { - param.pq_kmeans_trainset_fraction = conf.at("pq_kmeans_trainset_fraction"); - } -} - -nlohmann::json collect_conf_with_prefix(const nlohmann::json& conf, - const std::string& prefix, - bool remove_prefix = true) -{ - nlohmann::json out; - for (auto& i : conf.items()) { - if (i.key().compare(0, prefix.size(), prefix) == 0) { - auto new_key = remove_prefix ? i.key().substr(prefix.size()) : i.key(); - out[new_key] = i.value(); - } - } - return out; -} - -template -void parse_build_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagra::BuildParam& param) -{ - if (conf.contains("graph_degree")) { - param.cagra_params.graph_degree = conf.at("graph_degree"); - param.cagra_params.intermediate_graph_degree = param.cagra_params.graph_degree * 2; - } - if (conf.contains("intermediate_graph_degree")) { - param.cagra_params.intermediate_graph_degree = conf.at("intermediate_graph_degree"); - } - if (conf.contains("graph_build_algo")) { - if (conf.at("graph_build_algo") == "IVF_PQ") { - param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::IVF_PQ; - } else if (conf.at("graph_build_algo") == "NN_DESCENT") { - param.cagra_params.build_algo = raft::neighbors::cagra::graph_build_algo::NN_DESCENT; - } - } - nlohmann::json ivf_pq_build_conf = collect_conf_with_prefix(conf, "ivf_pq_build_"); - if (!ivf_pq_build_conf.empty()) { - raft::neighbors::ivf_pq::index_params bparam; - parse_build_param(ivf_pq_build_conf, bparam); - param.ivf_pq_build_params = bparam; - } - nlohmann::json ivf_pq_search_conf = collect_conf_with_prefix(conf, "ivf_pq_search_"); - if (!ivf_pq_search_conf.empty()) { - typename raft::bench::ann::RaftIvfPQ::SearchParam sparam; - parse_search_param(ivf_pq_search_conf, sparam); - param.ivf_pq_search_params = sparam.pq_param; - param.ivf_pq_refine_rate = sparam.refine_ratio; - } - nlohmann::json nn_descent_conf = collect_conf_with_prefix(conf, "nn_descent_"); - if (!nn_descent_conf.empty()) { - raft::neighbors::experimental::nn_descent::index_params nn_param; - nn_param.intermediate_graph_degree = 1.5 * param.cagra_params.intermediate_graph_degree; - parse_build_param(nn_descent_conf, nn_param); - if (nn_param.graph_degree != param.cagra_params.intermediate_graph_degree) { - nn_param.graph_degree = param.cagra_params.intermediate_graph_degree; - } - param.nn_descent_params = nn_param; - } - nlohmann::json comp_search_conf = collect_conf_with_prefix(conf, "compression_"); - if (!comp_search_conf.empty()) { - raft::neighbors::vpq_params vpq_pams; - parse_build_param(comp_search_conf, vpq_pams); - param.cagra_params.compression.emplace(vpq_pams); - } -} - -raft::bench::ann::AllocatorType parse_allocator(std::string mem_type) -{ - if (mem_type == "device") { - return raft::bench::ann::AllocatorType::Device; - } else if (mem_type == "host_pinned") { - return raft::bench::ann::AllocatorType::HostPinned; - } else if (mem_type == "host_huge_page") { - return raft::bench::ann::AllocatorType::HostHugePage; - } - THROW( - "Invalid value for memory type %s, must be one of [\"device\", \"host_pinned\", " - "\"host_huge_page\"", - mem_type.c_str()); -} - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagra::SearchParam& param) -{ - if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } - if (conf.contains("search_width")) { param.p.search_width = conf.at("search_width"); } - if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } - if (conf.contains("algo")) { - if (conf.at("algo") == "single_cta") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::SINGLE_CTA; - } else if (conf.at("algo") == "multi_cta") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_CTA; - } else if (conf.at("algo") == "multi_kernel") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::MULTI_KERNEL; - } else if (conf.at("algo") == "auto") { - param.p.algo = raft::neighbors::experimental::cagra::search_algo::AUTO; - } else { - std::string tmp = conf.at("algo"); - THROW("Invalid value for algo: %s", tmp.c_str()); - } - } - if (conf.contains("graph_memory_type")) { - param.graph_mem = parse_allocator(conf.at("graph_memory_type")); - } - if (conf.contains("internal_dataset_memory_type")) { - param.dataset_mem = parse_allocator(conf.at("internal_dataset_memory_type")); - } - // Same ratio as in IVF-PQ - param.refine_ratio = conf.value("refine_ratio", 1.0f); -} -#endif diff --git a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h b/cpp/bench/ann/src/raft/raft_ann_bench_utils.h deleted file mode 100644 index 9b086fdb23..0000000000 --- a/cpp/bench/ann/src/raft/raft_ann_bench_utils.h +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/util.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace raft::bench::ann { - -inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric) -{ - if (metric == raft::bench::ann::Metric::kInnerProduct) { - return raft::distance::DistanceType::InnerProduct; - } else if (metric == raft::bench::ann::Metric::kEuclidean) { - // Even for L2 expanded RAFT IVF Flat uses unexpanded formula - return raft::distance::DistanceType::L2Expanded; - } else { - throw std::runtime_error("raft supports only metric type of inner product and L2"); - } -} - -/** Report a more verbose error with a backtrace when OOM occurs on RMM side. */ -inline auto rmm_oom_callback(std::size_t bytes, void*) -> bool -{ - auto cuda_status = cudaGetLastError(); - size_t free = 0; - size_t total = 0; - RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free, &total)); - RAFT_FAIL( - "Failed to allocate %zu bytes using RMM memory resource. " - "NB: latest cuda status = %s, free memory = %zu, total memory = %zu.", - bytes, - cudaGetErrorName(cuda_status), - free, - total); -} - -/** - * This container keeps the part of raft state that should be shared among multiple copies of raft - * handles (in different CPU threads). - * An example of this is an RMM memory resource: if we had an RMM memory pool per thread, we'd - * quickly run out of memory. - */ -class shared_raft_resources { - public: - using pool_mr_type = rmm::mr::pool_memory_resource; - using mr_type = rmm::mr::failure_callback_resource_adaptor; - using large_mr_type = rmm::mr::managed_memory_resource; - - shared_raft_resources() - try : orig_resource_{rmm::mr::get_current_device_resource()}, - pool_resource_(orig_resource_, 1024 * 1024 * 1024ull), - resource_(&pool_resource_, rmm_oom_callback, nullptr), large_mr_() { - rmm::mr::set_current_device_resource(&resource_); - } catch (const std::exception& e) { - auto cuda_status = cudaGetLastError(); - size_t free = 0; - size_t total = 0; - RAFT_CUDA_TRY_NO_THROW(cudaMemGetInfo(&free, &total)); - RAFT_FAIL( - "Failed to initialize shared raft resources (NB: latest cuda status = %s, free memory = %zu, " - "total memory = %zu): %s", - cudaGetErrorName(cuda_status), - free, - total, - e.what()); - } - - shared_raft_resources(shared_raft_resources&&) = delete; - shared_raft_resources& operator=(shared_raft_resources&&) = delete; - shared_raft_resources(const shared_raft_resources& res) = delete; - shared_raft_resources& operator=(const shared_raft_resources& other) = delete; - - ~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); } - - auto get_large_memory_resource() noexcept - { - return static_cast(&large_mr_); - } - - private: - rmm::mr::device_memory_resource* orig_resource_; - pool_mr_type pool_resource_; - mr_type resource_; - large_mr_type large_mr_; -}; - -/** - * This struct is used by multiple raft benchmark wrappers. It serves as a thread-safe keeper of - * shared and private GPU resources (see below). - * - * - Accessing the same `configured_raft_resources` from concurrent threads is not safe. - * - Accessing the copies of `configured_raft_resources` from concurrent threads is safe. - * - There must be at most one "original" `configured_raft_resources` at any time, but as many - * copies of it as needed (modifies the program static state). - */ -class configured_raft_resources { - public: - /** - * This constructor has the shared state passed unmodified but creates the local state anew. - * It's used by the copy constructor. - */ - explicit configured_raft_resources(const std::shared_ptr& shared_res) - : shared_res_{shared_res}, - res_{std::make_unique( - rmm::cuda_stream_view(get_stream_from_global_pool()))} - { - // set the large workspace resource to the raft handle, but without the deleter - // (this resource is managed by the shared_res). - raft::resource::set_large_workspace_resource( - *res_, - std::shared_ptr(shared_res_->get_large_memory_resource(), - raft::void_op{})); - } - - /** Default constructor creates all resources anew. */ - configured_raft_resources() : configured_raft_resources{std::make_shared()} - { - } - - configured_raft_resources(configured_raft_resources&&); - configured_raft_resources& operator=(configured_raft_resources&&); - ~configured_raft_resources() = default; - configured_raft_resources(const configured_raft_resources& res) - : configured_raft_resources{res.shared_res_} - { - } - configured_raft_resources& operator=(const configured_raft_resources& other) - { - this->shared_res_ = other.shared_res_; - return *this; - } - - operator raft::resources&() noexcept { return *res_; } - operator const raft::resources&() const noexcept { return *res_; } - - /** Get the main stream */ - [[nodiscard]] auto get_sync_stream() const noexcept { return resource::get_cuda_stream(*res_); } - - private: - /** The resources shared among multiple raft handles / threads. */ - std::shared_ptr shared_res_; - /** - * Until we make the use of copies of raft::resources thread-safe, each benchmark wrapper must - * have its own copy of it. - */ - std::unique_ptr res_ = std::make_unique(); -}; - -inline configured_raft_resources::configured_raft_resources(configured_raft_resources&&) = default; -inline configured_raft_resources& configured_raft_resources::operator=( - configured_raft_resources&&) = default; - -/** A helper to refine the neighbors when the data is on device or on host. */ -template -void refine_helper(const raft::resources& res, - DatasetT dataset, - QueriesT queries, - CandidatesT candidates, - int k, - AnnBase::index_type* neighbors, - float* distances, - raft::distance::DistanceType metric) -{ - using data_type = typename DatasetT::value_type; - using index_type = AnnBase::index_type; - using extents_type = index_type; // device-side refine requires this - - static_assert(std::is_same_v); - static_assert(std::is_same_v); - static_assert(std::is_same_v); - - extents_type batch_size = queries.extent(0); - extents_type dim = queries.extent(1); - extents_type k0 = candidates.extent(1); - - if (raft::get_device_for_address(dataset.data_handle()) >= 0) { - auto dataset_device = raft::make_device_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - auto queries_device = raft::make_device_matrix_view( - queries.data_handle(), batch_size, dim); - auto candidates_device = raft::make_device_matrix_view( - candidates.data_handle(), batch_size, k0); - auto neighbors_device = - raft::make_device_matrix_view(neighbors, batch_size, k); - auto distances_device = - raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::refine(res, - dataset_device, - queries_device, - candidates_device, - neighbors_device, - distances_device, - metric); - } else { - auto dataset_host = raft::make_host_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - auto queries_host = raft::make_host_matrix(batch_size, dim); - auto candidates_host = raft::make_host_matrix(batch_size, k0); - auto neighbors_host = raft::make_host_matrix(batch_size, k); - auto distances_host = raft::make_host_matrix(batch_size, k); - - auto stream = resource::get_cuda_stream(res); - raft::copy(queries_host.data_handle(), queries.data_handle(), queries_host.size(), stream); - raft::copy( - candidates_host.data_handle(), candidates.data_handle(), candidates_host.size(), stream); - - raft::resource::sync_stream(res); // wait for the queries and candidates - raft::neighbors::refine(res, - dataset_host, - queries_host.view(), - candidates_host.view(), - neighbors_host.view(), - distances_host.view(), - metric); - - raft::copy(neighbors, neighbors_host.data_handle(), neighbors_host.size(), stream); - raft::copy(distances, distances_host.data_handle(), distances_host.size(), stream); - } -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu deleted file mode 100644 index 8bb4d9423c..0000000000 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ /dev/null @@ -1,140 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" -#include "raft_ann_bench_param_parser.h" - -#include - -#include - -#define JSON_DIAGNOSTICS 1 -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - [[maybe_unused]] raft::bench::ann::Metric metric = parse_metric(distance); - std::unique_ptr> ann; - - if constexpr (std::is_same_v) { -#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE - if (algo == "raft_brute_force") { - ann = std::make_unique>(metric, dim); - } -#endif - } - - if constexpr (std::is_same_v) {} - -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "raft_ivf_flat") { - typename raft::bench::ann::RaftIvfFlatGpu::BuildParam param; - parse_build_param(conf, param); - ann = std::make_unique>(metric, dim, param); - } - } -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ - if (algo == "raft_ivf_pq") { - typename raft::bench::ann::RaftIvfPQ::BuildParam param; - parse_build_param(conf, param); - ann = std::make_unique>(metric, dim, param); - } -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA - if (algo == "raft_cagra") { - typename raft::bench::ann::RaftCagra::BuildParam param; - parse_build_param(conf, param); - ann = std::make_unique>(metric, dim, param); - } -#endif - - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ -#ifdef RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE - if (algo == "raft_brute_force") { - auto param = std::make_unique::AnnSearchParam>(); - return param; - } -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT - if constexpr (std::is_same_v || std::is_same_v || - std::is_same_v) { - if (algo == "raft_ivf_flat") { - auto param = - std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } - } -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_IVF_PQ - if (algo == "raft_ivf_pq") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } -#endif -#ifdef RAFT_ANN_BENCH_USE_RAFT_CAGRA - if (algo == "raft_cagra") { - auto param = std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } -#endif - - // else - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -}; // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(half); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) { return raft::bench::ann::run_main(argc, argv); } -#endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_float.cu b/cpp/bench/ann/src/raft/raft_cagra_float.cu deleted file mode 100644 index 058f5bf34a..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_float.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_cagra_wrapper.h" - -namespace raft::bench::ann { -template class RaftCagra; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_half.cu b/cpp/bench/ann/src/raft/raft_cagra_half.cu deleted file mode 100644 index a015819ec5..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_half.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_cagra_wrapper.h" - -namespace raft::bench::ann { -template class RaftCagra; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu b/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu deleted file mode 100644 index d9ef1d74a3..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../common/ann_types.hpp" -#include "raft_ann_bench_param_parser.h" -#include "raft_cagra_hnswlib_wrapper.h" - -#include -#include -#include - -#define JSON_DIAGNOSTICS 1 -#include - -namespace raft::bench::ann { - -template -void parse_search_param(const nlohmann::json& conf, - typename raft::bench::ann::RaftCagraHnswlib::SearchParam& param) -{ - param.ef = conf.at("ef"); - if (conf.contains("numThreads")) { param.num_threads = conf.at("numThreads"); } -} - -template -std::unique_ptr> create_algo(const std::string& algo, - const std::string& distance, - int dim, - const nlohmann::json& conf, - const std::vector& dev_list) -{ - // stop compiler warning; not all algorithms support multi-GPU so it may not be used - (void)dev_list; - - [[maybe_unused]] raft::bench::ann::Metric metric = parse_metric(distance); - std::unique_ptr> ann; - - if constexpr (std::is_same_v or std::is_same_v) { - if (algo == "raft_cagra_hnswlib") { - typename raft::bench::ann::RaftCagraHnswlib::BuildParam param; - parse_build_param(conf, param); - ann = std::make_unique>(metric, dim, param); - } - } - - if (!ann) { throw std::runtime_error("invalid algo: '" + algo + "'"); } - - return ann; -} - -template -std::unique_ptr::AnnSearchParam> create_search_param( - const std::string& algo, const nlohmann::json& conf) -{ - if (algo == "raft_cagra_hnswlib") { - auto param = - std::make_unique::SearchParam>(); - parse_search_param(conf, *param); - return param; - } - - throw std::runtime_error("invalid algo: '" + algo + "'"); -} - -} // namespace raft::bench::ann - -REGISTER_ALGO_INSTANCE(float); -REGISTER_ALGO_INSTANCE(std::int8_t); -REGISTER_ALGO_INSTANCE(std::uint8_t); - -#ifdef ANN_BENCH_BUILD_MAIN -#include "../common/benchmark.hpp" -int main(int argc, char** argv) -{ - rmm::mr::cuda_memory_resource cuda_mr; - // Construct a resource that uses a coalescing best-fit pool allocator - // and is initially sized to half of free device memory. - rmm::mr::pool_memory_resource pool_mr{ - &cuda_mr, rmm::percent_of_free_device_memory(50)}; - // Updates the current device resource pointer to `pool_mr` - auto old_mr = rmm::mr::set_current_device_resource(&pool_mr); - auto ret = raft::bench::ann::run_main(argc, argv); - // Restores the current device resource pointer to its previous value - rmm::mr::set_current_device_resource(old_mr); - return ret; -} -#endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h deleted file mode 100644 index 1d2a1076ab..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_hnswlib_wrapper.h +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../hnswlib/hnswlib_wrapper.h" -#include "raft_cagra_wrapper.h" - -#include - -namespace raft::bench::ann { - -template -class RaftCagraHnswlib : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - using BuildParam = typename RaftCagra::BuildParam; - using SearchParam = typename HnswLib::SearchParam; - - RaftCagraHnswlib(Metric metric, int dim, const BuildParam& param, int concurrent_searches = 1) - : ANN(metric, dim), - cagra_build_{metric, dim, param, concurrent_searches, true}, - // HnswLib param values don't matter since we don't build with HnswLib - hnswlib_search_{metric, dim, typename HnswLib::BuildParam{50, 100}} - { - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return cagra_build_.get_sync_stream(); - } - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::Host; - return property; - } - - void save(const std::string& file) const override; - void load(const std::string&) override; - std::unique_ptr> copy() override - { - return std::make_unique>(*this); - } - - private: - RaftCagra cagra_build_; - HnswLib hnswlib_search_; -}; - -template -void RaftCagraHnswlib::build(const T* dataset, size_t nrow) -{ - cagra_build_.build(dataset, nrow); -} - -template -void RaftCagraHnswlib::set_search_param(const AnnSearchParam& param_) -{ - hnswlib_search_.set_search_param(param_); -} - -template -void RaftCagraHnswlib::save(const std::string& file) const -{ - cagra_build_.save_to_hnswlib(file); -} - -template -void RaftCagraHnswlib::load(const std::string& file) -{ - hnswlib_search_.load(file); - hnswlib_search_.set_base_layer_only(); -} - -template -void RaftCagraHnswlib::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - hnswlib_search_.search(queries, batch_size, k, neighbors, distances); -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu b/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu deleted file mode 100644 index be3b83ee60..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_int8_t.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_cagra_wrapper.h" - -namespace raft::bench::ann { -template class RaftCagra; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu b/cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu deleted file mode 100644 index c9679e404d..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_uint8_t.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_cagra_wrapper.h" - -namespace raft::bench::ann { -template class RaftCagra; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h deleted file mode 100644 index b03f875a8e..0000000000 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ /dev/null @@ -1,339 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "../common/cuda_huge_page_resource.hpp" -#include "../common/cuda_pinned_resource.hpp" -#include "raft_ann_bench_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -enum class AllocatorType { HostPinned, HostHugePage, Device }; -template -class RaftCagra : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - - struct SearchParam : public AnnSearchParam { - raft::neighbors::experimental::cagra::search_params p; - float refine_ratio; - AllocatorType graph_mem = AllocatorType::Device; - AllocatorType dataset_mem = AllocatorType::Device; - auto needs_dataset() const -> bool override { return true; } - }; - - struct BuildParam { - raft::neighbors::cagra::index_params cagra_params; - std::optional nn_descent_params = - std::nullopt; - std::optional ivf_pq_refine_rate = std::nullopt; - std::optional ivf_pq_build_params = std::nullopt; - std::optional ivf_pq_search_params = std::nullopt; - }; - - RaftCagra(Metric metric, - int dim, - const BuildParam& param, - int concurrent_searches = 1, - bool shall_include_dataset = false) - : ANN(metric, dim), - index_params_(param), - dimension_(dim), - need_dataset_update_(true), - shall_include_dataset_(shall_include_dataset), - dataset_(std::make_shared>( - std::move(make_device_matrix(handle_, 0, 0)))), - graph_(std::make_shared>( - std::move(make_device_matrix(handle_, 0, 0)))), - input_dataset_v_( - std::make_shared>(nullptr, 0, 0)), - graph_mem_(AllocatorType::Device), - dataset_mem_(AllocatorType::Device) - { - index_params_.cagra_params.metric = parse_metric_type(metric); - index_params_.ivf_pq_build_params->metric = parse_metric_type(metric); - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - - void set_search_dataset(const T* dataset, size_t nrow) override; - - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override; - void search_base(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return handle_.get_sync_stream(); - } - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::Device; - return property; - } - void save(const std::string& file) const override; - void load(const std::string&) override; - void save_to_hnswlib(const std::string& file) const; - std::unique_ptr> copy() override; - - private: - // handle_ must go first to make sure it dies last and all memory allocated in pool - configured_raft_resources handle_{}; - raft::mr::cuda_pinned_resource mr_pinned_; - raft::mr::cuda_huge_page_resource mr_huge_page_; - AllocatorType graph_mem_; - AllocatorType dataset_mem_; - float refine_ratio_; - BuildParam index_params_; - bool need_dataset_update_; - bool shall_include_dataset_; - raft::neighbors::cagra::search_params search_params_; - std::shared_ptr> index_; - int dimension_; - std::shared_ptr> graph_; - std::shared_ptr> dataset_; - std::shared_ptr> input_dataset_v_; - - inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type) - { - switch (mem_type) { - case (AllocatorType::HostPinned): return &mr_pinned_; - case (AllocatorType::HostHugePage): return &mr_huge_page_; - default: return rmm::mr::get_current_device_resource(); - } - } -}; - -template -void RaftCagra::build(const T* dataset, size_t nrow) -{ - auto dataset_view = - raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); - - auto& params = index_params_.cagra_params; - - // Do include the compressed dataset for the CAGRA-Q - bool include_dataset = params.compression.has_value() || shall_include_dataset_; - - index_ = std::make_shared>( - std::move(raft::neighbors::cagra::detail::build(handle_, - params, - dataset_view, - index_params_.nn_descent_params, - index_params_.ivf_pq_refine_rate, - index_params_.ivf_pq_build_params, - index_params_.ivf_pq_search_params, - include_dataset))); -} - -inline std::string allocator_to_string(AllocatorType mem_type) -{ - if (mem_type == AllocatorType::Device) { - return "device"; - } else if (mem_type == AllocatorType::HostPinned) { - return "host_pinned"; - } else if (mem_type == AllocatorType::HostHugePage) { - return "host_huge_page"; - } - return ""; -} - -template -void RaftCagra::set_search_param(const AnnSearchParam& param) -{ - auto search_param = dynamic_cast(param); - search_params_ = search_param.p; - refine_ratio_ = search_param.refine_ratio; - if (search_param.graph_mem != graph_mem_) { - // Move graph to correct memory space - graph_mem_ = search_param.graph_mem; - RAFT_LOG_DEBUG("moving graph to new memory space: %s", allocator_to_string(graph_mem_).c_str()); - // We create a new graph and copy to it from existing graph - auto mr = get_mr(graph_mem_); - auto new_graph = make_device_mdarray( - handle_, mr, make_extents(index_->graph().extent(0), index_->graph_degree())); - - raft::copy(new_graph.data_handle(), - index_->graph().data_handle(), - index_->graph().size(), - resource::get_cuda_stream(handle_)); - - index_->update_graph(handle_, make_const_mdspan(new_graph.view())); - // update_graph() only stores a view in the index. We need to keep the graph object alive. - *graph_ = std::move(new_graph); - } - - if (search_param.dataset_mem != dataset_mem_ || need_dataset_update_) { - dataset_mem_ = search_param.dataset_mem; - - // First free up existing memory - *dataset_ = make_device_matrix(handle_, 0, 0); - index_->update_dataset(handle_, make_const_mdspan(dataset_->view())); - - // Allocate space using the correct memory resource. - RAFT_LOG_DEBUG("moving dataset to new memory space: %s", - allocator_to_string(dataset_mem_).c_str()); - - auto mr = get_mr(dataset_mem_); - raft::neighbors::cagra::detail::copy_with_padding(handle_, *dataset_, *input_dataset_v_, mr); - - auto dataset_view = raft::make_device_strided_matrix_view( - dataset_->data_handle(), dataset_->extent(0), this->dim_, dataset_->extent(1)); - index_->update_dataset(handle_, dataset_view); - - need_dataset_update_ = false; - } -} - -template -void RaftCagra::set_search_dataset(const T* dataset, size_t nrow) -{ - using ds_idx_type = decltype(index_->data().n_rows()); - bool is_vpq = - dynamic_cast*>(&index_->data()) || - dynamic_cast*>(&index_->data()); - // It can happen that we are re-using a previous algo object which already has - // the dataset set. Check if we need update. - if (static_cast(input_dataset_v_->extent(0)) != nrow || - input_dataset_v_->data_handle() != dataset) { - *input_dataset_v_ = make_device_matrix_view(dataset, nrow, this->dim_); - need_dataset_update_ = !is_vpq; // ignore update if this is a VPQ dataset. - } -} - -template -void RaftCagra::save(const std::string& file) const -{ - raft::neighbors::cagra::serialize(handle_, file, *index_); -} - -template -void RaftCagra::save_to_hnswlib(const std::string& file) const -{ - raft::neighbors::cagra::serialize_to_hnswlib(handle_, file, *index_); -} - -template -void RaftCagra::load(const std::string& file) -{ - index_ = std::make_shared>( - std::move(raft::neighbors::cagra::deserialize(handle_, file))); -} - -template -std::unique_ptr> RaftCagra::copy() -{ - return std::make_unique>(*this); // use copy constructor -} - -template -void RaftCagra::search_base( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - static_assert(std::is_integral_v); - static_assert(std::is_integral_v); - - IdxT* neighbors_IdxT; - std::optional> neighbors_storage{std::nullopt}; - if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { - neighbors_IdxT = reinterpret_cast(neighbors); - } else { - neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage->data(); - } - - auto queries_view = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto neighbors_view = raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); - auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::cagra::search( - handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - - if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { - raft::linalg::unaryOp(neighbors, - neighbors_IdxT, - batch_size * k, - raft::cast_op(), - raft::resource::get_cuda_stream(handle_)); - } -} - -template -void RaftCagra::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - auto k0 = static_cast(refine_ratio_ * k); - const bool disable_refinement = k0 <= static_cast(k); - const raft::resources& res = handle_; - - if (disable_refinement) { - search_base(queries, batch_size, k, neighbors, distances); - } else { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto candidate_ixs = - raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = - raft::make_device_matrix(res, batch_size, k0); - search_base( - queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); - refine_helper( - res, *input_dataset_v_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); - } -} -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ivf_flat.cu deleted file mode 100644 index bcd23723a4..0000000000 --- a/cpp/bench/ann/src/raft/raft_ivf_flat.cu +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_ivf_flat_wrapper.h" - -namespace raft::bench::ann { -template class RaftIvfFlatGpu; -template class RaftIvfFlatGpu; -template class RaftIvfFlatGpu; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h deleted file mode 100644 index 83a3a63aba..0000000000 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ /dev/null @@ -1,165 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "raft_ann_bench_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::bench::ann { - -template -class RaftIvfFlatGpu : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - - struct SearchParam : public AnnSearchParam { - raft::neighbors::ivf_flat::search_params ivf_flat_params; - }; - - using BuildParam = raft::neighbors::ivf_flat::index_params; - - RaftIvfFlatGpu(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), index_params_(param), dimension_(dim) - { - index_params_.metric = parse_metric_type(metric); - index_params_.conservative_memory_allocation = true; - RAFT_CUDA_TRY(cudaGetDevice(&device_)); - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return handle_.get_sync_stream(); - } - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::HostMmap; - property.query_memory_type = MemoryType::Device; - return property; - } - void save(const std::string& file) const override; - void load(const std::string&) override; - std::unique_ptr> copy() override; - - private: - // handle_ must go first to make sure it dies last and all memory allocated in pool - configured_raft_resources handle_{}; - BuildParam index_params_; - raft::neighbors::ivf_flat::search_params search_params_; - std::shared_ptr> index_; - int device_; - int dimension_; -}; - -template -void RaftIvfFlatGpu::build(const T* dataset, size_t nrow) -{ - index_ = std::make_shared>(std::move( - raft::neighbors::ivf_flat::build(handle_, index_params_, dataset, IdxT(nrow), dimension_))); -} - -template -void RaftIvfFlatGpu::set_search_param(const AnnSearchParam& param) -{ - auto search_param = dynamic_cast(param); - search_params_ = search_param.ivf_flat_params; - assert(search_params_.n_probes <= index_params_.n_lists); -} - -template -void RaftIvfFlatGpu::save(const std::string& file) const -{ - raft::neighbors::ivf_flat::serialize(handle_, file, *index_); - return; -} - -template -void RaftIvfFlatGpu::load(const std::string& file) -{ - index_ = std::make_shared>( - std::move(raft::neighbors::ivf_flat::deserialize(handle_, file))); - return; -} - -template -std::unique_ptr> RaftIvfFlatGpu::copy() -{ - return std::make_unique>(*this); // use copy constructor -} - -template -void RaftIvfFlatGpu::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - static_assert(std::is_integral_v); - static_assert(std::is_integral_v); - - IdxT* neighbors_IdxT; - std::optional> neighbors_storage{std::nullopt}; - if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { - neighbors_IdxT = reinterpret_cast(neighbors); - } else { - neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage->data(); - } - raft::neighbors::ivf_flat::search(handle_, - search_params_, - *index_, - queries, - batch_size, - k, - neighbors_IdxT, - distances, - resource::get_workspace_resource(handle_)); - if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { - raft::linalg::unaryOp(neighbors, - neighbors_IdxT, - batch_size * k, - raft::cast_op(), - raft::resource::get_cuda_stream(handle_)); - } -} -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ivf_pq.cu deleted file mode 100644 index d4f68c1c7d..0000000000 --- a/cpp/bench/ann/src/raft/raft_ivf_pq.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "raft_ivf_pq_wrapper.h" - -namespace raft::bench::ann { -template class RaftIvfPQ; -template class RaftIvfPQ; -template class RaftIvfPQ; -template class RaftIvfPQ; -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h deleted file mode 100644 index 7201467969..0000000000 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ /dev/null @@ -1,203 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "raft_ann_bench_utils.h" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::bench::ann { - -template -class RaftIvfPQ : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - using ANN::dim_; - - struct SearchParam : public AnnSearchParam { - raft::neighbors::ivf_pq::search_params pq_param; - float refine_ratio = 1.0f; - auto needs_dataset() const -> bool override { return refine_ratio > 1.0f; } - }; - - using BuildParam = raft::neighbors::ivf_pq::index_params; - - RaftIvfPQ(Metric metric, int dim, const BuildParam& param) - : ANN(metric, dim), index_params_(param), dimension_(dim) - { - index_params_.metric = parse_metric_type(metric); - } - - void build(const T* dataset, size_t nrow) final; - - void set_search_param(const AnnSearchParam& param) override; - void set_search_dataset(const T* dataset, size_t nrow) override; - - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const override; - void search_base(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const; - - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return handle_.get_sync_stream(); - } - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::Host; - property.query_memory_type = MemoryType::Device; - return property; - } - void save(const std::string& file) const override; - void load(const std::string&) override; - std::unique_ptr> copy() override; - - private: - // handle_ must go first to make sure it dies last and all memory allocated in pool - configured_raft_resources handle_{}; - BuildParam index_params_; - raft::neighbors::ivf_pq::search_params search_params_; - std::shared_ptr> index_; - int dimension_; - float refine_ratio_ = 1.0; - raft::device_matrix_view dataset_; -}; - -template -void RaftIvfPQ::save(const std::string& file) const -{ - raft::neighbors::ivf_pq::serialize(handle_, file, *index_); -} - -template -void RaftIvfPQ::load(const std::string& file) -{ - index_ = std::make_shared>( - std::move(raft::neighbors::ivf_pq::deserialize(handle_, file))); -} - -template -void RaftIvfPQ::build(const T* dataset, size_t nrow) -{ - auto dataset_v = raft::make_device_matrix_view(dataset, IdxT(nrow), dim_); - std::make_shared>( - std::move(raft::neighbors::ivf_pq::build(handle_, index_params_, dataset_v))) - .swap(index_); -} - -template -std::unique_ptr> RaftIvfPQ::copy() -{ - return std::make_unique>(*this); // use copy constructor -} - -template -void RaftIvfPQ::set_search_param(const AnnSearchParam& param) -{ - auto search_param = dynamic_cast(param); - search_params_ = search_param.pq_param; - refine_ratio_ = search_param.refine_ratio; - assert(search_params_.n_probes <= index_params_.n_lists); -} - -template -void RaftIvfPQ::set_search_dataset(const T* dataset, size_t nrow) -{ - dataset_ = raft::make_device_matrix_view(dataset, nrow, index_->dim()); -} - -template -void RaftIvfPQ::search_base( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - static_assert(std::is_integral_v); - static_assert(std::is_integral_v); - - IdxT* neighbors_IdxT; - std::optional> neighbors_storage{std::nullopt}; - if constexpr (sizeof(IdxT) == sizeof(AnnBase::index_type)) { - neighbors_IdxT = reinterpret_cast(neighbors); - } else { - neighbors_storage.emplace(batch_size * k, resource::get_cuda_stream(handle_)); - neighbors_IdxT = neighbors_storage->data(); - } - - auto queries_view = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto neighbors_view = - raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); - auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::ivf_pq::search( - handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); - - if constexpr (sizeof(IdxT) != sizeof(AnnBase::index_type)) { - raft::linalg::unaryOp(neighbors, - neighbors_IdxT, - batch_size * k, - raft::cast_op(), - raft::resource::get_cuda_stream(handle_)); - } -} - -template -void RaftIvfPQ::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - auto k0 = static_cast(refine_ratio_ * k); - const bool disable_refinement = k0 <= static_cast(k); - const raft::resources& res = handle_; - - if (disable_refinement) { - search_base(queries, batch_size, k, neighbors, distances); - } else { - auto queries_v = - raft::make_device_matrix_view(queries, batch_size, dimension_); - auto candidate_ixs = - raft::make_device_matrix(res, batch_size, k0); - auto candidate_dists = - raft::make_device_matrix(res, batch_size, k0); - search_base( - queries, batch_size, k0, candidate_ixs.data_handle(), candidate_dists.data_handle()); - refine_helper( - res, dataset_, queries_v, candidate_ixs, k, neighbors, distances, index_->metric()); - } -} -} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_wrapper.h b/cpp/bench/ann/src/raft/raft_wrapper.h deleted file mode 100644 index 2c996058b2..0000000000 --- a/cpp/bench/ann/src/raft/raft_wrapper.h +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../common/ann_types.hpp" -#include "raft_ann_bench_utils.h" - -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace raft_temp { - -inline raft::distance::DistanceType parse_metric_type(raft::bench::ann::Metric metric) -{ - switch (metric) { - case raft::bench::ann::Metric::kInnerProduct: return raft::distance::DistanceType::InnerProduct; - case raft::bench::ann::Metric::kEuclidean: return raft::distance::DistanceType::L2Expanded; - default: throw std::runtime_error("raft supports only metric type of inner product and L2"); - } -} -} // namespace raft_temp - -namespace raft::bench::ann { - -// brute force KNN - RAFT -template -class RaftGpu : public ANN, public AnnGPU { - public: - using typename ANN::AnnSearchParam; - - RaftGpu(Metric metric, int dim); - - void build(const T*, size_t) final; - - void set_search_param(const AnnSearchParam& param) override; - - void search(const T* queries, - int batch_size, - int k, - AnnBase::index_type* neighbors, - float* distances) const final; - - // to enable dataset access from GPU memory - AlgoProperty get_preference() const override - { - AlgoProperty property; - property.dataset_memory_type = MemoryType::Device; - property.query_memory_type = MemoryType::Device; - return property; - } - [[nodiscard]] auto get_sync_stream() const noexcept -> cudaStream_t override - { - return handle_.get_sync_stream(); - } - void set_search_dataset(const T* dataset, size_t nrow) override; - void save(const std::string& file) const override; - void load(const std::string&) override; - std::unique_ptr> copy() override; - - protected: - // handle_ must go first to make sure it dies last and all memory allocated in pool - configured_raft_resources handle_{}; - std::shared_ptr> index_; - raft::distance::DistanceType metric_type_; - int device_; - const T* dataset_; - size_t nrow_; -}; - -template -RaftGpu::RaftGpu(Metric metric, int dim) - : ANN(metric, dim), metric_type_(raft_temp::parse_metric_type(metric)) -{ - static_assert(std::is_same_v || std::is_same_v, - "raft bfknn only supports float/double"); - RAFT_CUDA_TRY(cudaGetDevice(&device_)); -} - -template -void RaftGpu::build(const T* dataset, size_t nrow) -{ - auto dataset_view = raft::make_host_matrix_view(dataset, nrow, this->dim_); - index_ = std::make_shared>( - std::move(raft::neighbors::brute_force::build(handle_, dataset_view))); -} - -template -void RaftGpu::set_search_param(const AnnSearchParam&) -{ - // Nothing to set here as it is brute force implementation -} - -template -void RaftGpu::set_search_dataset(const T* dataset, size_t nrow) -{ - dataset_ = dataset; - nrow_ = nrow; -} - -template -void RaftGpu::save(const std::string& file) const -{ - raft::neighbors::brute_force::serialize(handle_, file, *index_); -} - -template -void RaftGpu::load(const std::string& file) -{ - index_ = std::make_shared>( - std::move(raft::neighbors::brute_force::deserialize(handle_, file))); -} - -template -void RaftGpu::search( - const T* queries, int batch_size, int k, AnnBase::index_type* neighbors, float* distances) const -{ - auto queries_view = - raft::make_device_matrix_view(queries, batch_size, this->dim_); - - auto neighbors_view = - raft::make_device_matrix_view(neighbors, batch_size, k); - auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - - raft::neighbors::brute_force::search( - handle_, *index_, queries_view, neighbors_view, distances_view); -} - -template -std::unique_ptr> RaftGpu::copy() -{ - return std::make_unique>(*this); // use copy constructor -} - -} // namespace raft::bench::ann diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 52c63ad73b..6bc8c802b4 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -74,49 +74,9 @@ function(ConfigureBench) endfunction() if(BUILD_PRIMS_BENCH) - ConfigureBench( - NAME - CORE_BENCH - PATH - core/bitset.cu - core/copy.cu - main.cpp - ) - - ConfigureBench( - NAME - UTIL_BENCH - PATH - util/popc.cu - main.cpp - ) - - ConfigureBench( - NAME CLUSTER_BENCH PATH cluster/kmeans_balanced.cu cluster/kmeans.cu - main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY - ) - - ConfigureBench( - NAME TUNE_DISTANCE PATH distance/tune_pairwise/kernel.cu - distance/tune_pairwise/bench.cu main.cpp - ) + ConfigureBench(NAME CORE_BENCH PATH core/bitset.cu core/copy.cu main.cpp) - ConfigureBench( - NAME - DISTANCE_BENCH - PATH - distance/distance_cosine.cu - distance/distance_exp_l2.cu - distance/distance_l1.cu - distance/distance_unexp_l2.cu - distance/fused_l2_nn.cu - distance/masked_nn.cu - distance/kernels.cu - main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY - ) + ConfigureBench(NAME UTIL_BENCH PATH util/popc.cu main.cpp) ConfigureBench( NAME @@ -137,54 +97,15 @@ if(BUILD_PRIMS_BENCH) ) ConfigureBench( - NAME MATRIX_BENCH PATH matrix/argmin.cu matrix/gather.cu - matrix/select_k.cu main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY - ) - - ConfigureBench( - NAME RANDOM_BENCH PATH random/make_blobs.cu random/permute.cu - random/rng.cu random/subsample.cu main.cpp + NAME MATRIX_BENCH PATH matrix/argmin.cu matrix/select_k.cu matrix/gather.cu main.cpp OPTIONAL + LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( - NAME - SPARSE_BENCH - PATH - sparse/bitmap_to_csr.cu - sparse/convert_csr.cu - sparse/select_k_csr.cu + NAME RANDOM_BENCH PATH random/make_blobs.cu random/permute.cu random/rng.cu random/subsample.cu main.cpp ) - ConfigureBench( - NAME - NEIGHBORS_BENCH - PATH - neighbors/knn/brute_force_float_int64_t.cu - neighbors/knn/brute_force_float_uint32_t.cu - neighbors/knn/cagra_float_uint32_t.cu - neighbors/knn/ivf_flat_filter_float_int64_t.cu - neighbors/knn/ivf_flat_float_int64_t.cu - neighbors/knn/ivf_flat_int8_t_int64_t.cu - neighbors/knn/ivf_flat_uint8_t_int64_t.cu - neighbors/knn/ivf_pq_float_int64_t.cu - neighbors/knn/ivf_pq_filter_float_int64_t.cu - neighbors/knn/ivf_pq_int8_t_int64_t.cu - neighbors/knn/ivf_pq_uint8_t_int64_t.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_search_filtering_float_int64_t.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_float_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_float_half_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true_bitset64.cu - ${RAFT_SOURCE_DIR}/src/neighbors/detail/ivf_pq_compute_similarity_half_half_bitset64.cu - neighbors/refine_float_int64_t.cu - neighbors/refine_uint8_t_int64_t.cu - main.cpp - OPTIONAL - LIB - EXPLICIT_INSTANTIATE_ONLY - ) + ConfigureBench(NAME SPARSE_BENCH PATH sparse/bitmap_to_csr.cu sparse/convert_csr.cu main.cpp) endif() diff --git a/cpp/bench/prims/cluster/kmeans.cu b/cpp/bench/prims/cluster/kmeans.cu deleted file mode 100644 index 6387211135..0000000000 --- a/cpp/bench/prims/cluster/kmeans.cu +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -namespace raft::bench::cluster { - -struct KMeansBenchParams { - DatasetParams data; - BlobsParams blobs; - raft::cluster::KMeansParams kmeans; -}; - -inline auto operator<<(std::ostream& os, const KMeansBenchParams& p) -> std::ostream& -{ - os << p.data.rows << "#" << p.data.cols << "#" << p.kmeans.n_clusters; - return os; -} - -template -struct KMeans : public BlobsFixture { - KMeans(const KMeansBenchParams& p) - : BlobsFixture(p.data, p.blobs), - params(p), - centroids(this->handle), - labels(this->handle) - { - } - - void run_benchmark(::benchmark::State& state) override - { - std::ostringstream label_stream; - label_stream << params; - state.SetLabel(label_stream.str()); - - raft::device_matrix_view X_view = this->X.view(); - std::optional> opt_weights_view = std::nullopt; - std::optional> centroids_view = - std::make_optional>(centroids.view()); - raft::device_vector_view labels_view = labels.view(); - raft::host_scalar_view inertia_view = raft::make_host_scalar_view(&inertia); - raft::host_scalar_view n_iter_view = raft::make_host_scalar_view(&n_iter); - - this->loop_on_state(state, [&]() { - raft::cluster::kmeans_fit_predict(this->handle, - params.kmeans, - X_view, - opt_weights_view, - centroids_view, - labels_view, - inertia_view, - n_iter_view); - }); - } - - void allocate_temp_buffers(const ::benchmark::State& state) override - { - centroids = - raft::make_device_matrix(this->handle, params.kmeans.n_clusters, params.data.cols); - labels = raft::make_device_vector(this->handle, params.data.rows); - } - - private: - KMeansBenchParams params; - raft::device_matrix centroids; - raft::device_vector labels; - T inertia; - IndexT n_iter; -}; // struct KMeans - -std::vector getKMeansInputs() -{ - std::vector out; - KMeansBenchParams p; - p.data.row_major = true; - p.blobs.cluster_std = 1.0; - p.blobs.shuffle = false; - p.blobs.center_box_min = -10.0; - p.blobs.center_box_max = 10.0; - p.blobs.seed = 12345ULL; - p.kmeans.init = raft::cluster::KMeansParams::KMeansPlusPlus; - p.kmeans.max_iter = 300; - p.kmeans.tol = 1e-4; - p.kmeans.verbosity = RAFT_LEVEL_INFO; - p.kmeans.metric = raft::distance::DistanceType::L2Expanded; - p.kmeans.inertia_check = true; - std::vector> row_cols_k = { - {1000000, 20, 1000}, - {3000000, 50, 20}, - {10000000, 50, 5}, - }; - for (auto& rck : row_cols_k) { - p.data.rows = std::get<0>(rck); - p.data.cols = std::get<1>(rck); - p.blobs.n_clusters = std::get<2>(rck); - p.kmeans.n_clusters = std::get<2>(rck); - out.push_back(p); - } - return out; -} - -// note(lsugy): commenting out int64_t because the templates are not compiled in the distance -// library, resulting in long compilation times. -RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); -RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); -// RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); -// RAFT_BENCH_REGISTER((KMeans), "", getKMeansInputs()); - -} // namespace raft::bench::cluster diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu deleted file mode 100644 index dc05783989..0000000000 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include - -namespace raft::bench::cluster { - -struct KMeansBalancedBenchParams { - DatasetParams data; - uint32_t n_lists; - raft::cluster::kmeans_balanced_params kb_params; -}; - -template -struct KMeansBalanced : public fixture { - KMeansBalanced(const KMeansBalancedBenchParams& p) : params(p), X(handle), centroids(handle) {} - - void run_benchmark(::benchmark::State& state) override - { - this->loop_on_state(state, [this]() { - raft::device_matrix_view X_view = this->X.view(); - raft::device_matrix_view centroids_view = this->centroids.view(); - raft::cluster::kmeans_balanced::fit( - this->handle, this->params.kb_params, X_view, centroids_view); - }); - } - - void allocate_data(const ::benchmark::State& state) override - { - X = raft::make_device_matrix(handle, params.data.rows, params.data.cols); - - raft::random::RngState rng{1234}; - constexpr T kRangeMax = std::is_integral_v ? std::numeric_limits::max() : T(1); - constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); - if constexpr (std::is_integral_v) { - raft::random::uniformInt( - handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax); - } else { - raft::random::uniform( - handle, rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax); - } - resource::sync_stream(handle, stream); - } - - void allocate_temp_buffers(const ::benchmark::State& state) override - { - centroids = - raft::make_device_matrix(this->handle, params.n_lists, params.data.cols); - } - - private: - KMeansBalancedBenchParams params; - raft::device_matrix X; - raft::device_matrix centroids; -}; // struct KMeansBalanced - -std::vector getKMeansBalancedInputs() -{ - std::vector out; - KMeansBalancedBenchParams p; - p.data.row_major = true; - p.kb_params.n_iters = 20; - p.kb_params.metric = raft::distance::DistanceType::L2Expanded; - std::vector> row_cols = { - {100000, 128}, {1000000, 128}, {10000000, 128}, - // The following dataset sizes are too large for most GPUs. - // {100000000, 128}, - }; - for (auto& rc : row_cols) { - p.data.rows = rc.first; - p.data.cols = rc.second; - for (auto n_lists : std::vector({1000, 10000, 100000})) { - p.n_lists = n_lists; - out.push_back(p); - } - } - return out; -} - -// Note: the datasets sizes are too large for 32-bit index types. -RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); - -} // namespace raft::bench::cluster diff --git a/cpp/bench/prims/distance/distance_common.cuh b/cpp/bench/prims/distance/distance_common.cuh deleted file mode 100644 index 8368062168..0000000000 --- a/cpp/bench/prims/distance/distance_common.cuh +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include - -#include - -namespace raft::bench::distance { - -struct distance_params { - int m, n, k; - bool isRowMajor; -}; // struct distance_params - -template -struct distance : public fixture { - distance(const distance_params& p) - : params(p), - x(p.m * p.k, stream), - y(p.n * p.k, stream), - out(p.m * p.n, stream), - workspace(0, stream) - { - RAFT_CUDA_TRY(cudaMemsetAsync(x.data(), 0, x.size() * sizeof(T), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(y.data(), 0, y.size() * sizeof(T), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(out.data(), 0, out.size() * sizeof(T), stream)); - worksize = raft::distance::getWorkspaceSize( - x.data(), y.data(), params.m, params.n, params.k); - workspace.resize(worksize, stream); - } - - void run_benchmark(::benchmark::State& state) override - { - loop_on_state(state, [this]() { - raft::distance::distance(handle, - x.data(), - y.data(), - out.data(), - params.m, - params.n, - params.k, - (void*)workspace.data(), - worksize, - params.isRowMajor); - }); - } - - private: - distance_params params; - rmm::device_uvector x, y, out; - rmm::device_uvector workspace; - size_t worksize; -}; // struct Distance - -const std::vector dist_input_vecs{ - {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, - {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, - {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, - {16384, 256, 16384, true}, {16384, 512, 16384, true}, {16384, 1024, 16384, true}, - {16384, 16384, 32, true}, {16384, 16384, 64, true}, {16384, 16384, 128, true}, - {16384, 16384, 256, true}, {16384, 16384, 512, true}, {16384, 16384, 1024, true}, - {16384, 16384, 16384, true}, {32, 16384, 16384, false}, {64, 16384, 16384, false}, - {128, 16384, 16384, false}, {256, 16384, 16384, false}, {512, 16384, 16384, false}, - {1024, 16384, 16384, false}, {16384, 32, 16384, false}, {16384, 64, 16384, false}, - {16384, 128, 16384, false}, {16384, 256, 16384, false}, {16384, 512, 16384, false}, - {16384, 1024, 16384, false}, {16384, 16384, 32, false}, {16384, 16384, 64, false}, - {16384, 16384, 128, false}, {16384, 16384, 256, false}, {16384, 16384, 512, false}, - {16384, 16384, 1024, false}, {16384, 16384, 16384, false} - -}; - -#define DIST_BENCH_REGISTER(Name, Metric) \ - using Name##F = distance; \ - RAFT_BENCH_REGISTER(Name##F, "", dist_input_vecs); \ - using Name##D = distance; \ - RAFT_BENCH_REGISTER(Name##D, "", dist_input_vecs); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/distance_cosine.cu b/cpp/bench/prims/distance/distance_cosine.cu deleted file mode 100644 index c8ac8067c8..0000000000 --- a/cpp/bench/prims/distance/distance_cosine.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "distance_common.cuh" - -namespace raft::bench::distance { - -DIST_BENCH_REGISTER(DistanceCosine, raft::distance::DistanceType::CosineExpanded); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/distance_exp_l2.cu b/cpp/bench/prims/distance/distance_exp_l2.cu deleted file mode 100644 index 52b7fff05c..0000000000 --- a/cpp/bench/prims/distance/distance_exp_l2.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "distance_common.cuh" - -namespace raft::bench::distance { - -DIST_BENCH_REGISTER(DistanceL2Sq, raft::distance::DistanceType::L2Expanded); -DIST_BENCH_REGISTER(DistanceL2Sqrt, raft::distance::DistanceType::L2SqrtExpanded); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/distance_l1.cu b/cpp/bench/prims/distance/distance_l1.cu deleted file mode 100644 index e80db48ef0..0000000000 --- a/cpp/bench/prims/distance/distance_l1.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "distance_common.cuh" - -namespace raft::bench::distance { - -DIST_BENCH_REGISTER(DistanceL1, raft::distance::DistanceType::L1); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/distance_unexp_l2.cu b/cpp/bench/prims/distance/distance_unexp_l2.cu deleted file mode 100644 index 7ac1a8a4b5..0000000000 --- a/cpp/bench/prims/distance/distance_unexp_l2.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "distance_common.cuh" - -namespace raft::bench::distance { - -DIST_BENCH_REGISTER(DistanceUnexpL2Sq, raft::distance::DistanceType::L2Unexpanded); -DIST_BENCH_REGISTER(DistanceUnexpL2Sqrt, raft::distance::DistanceType::L2SqrtUnexpanded); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu deleted file mode 100644 index a263bef6ba..0000000000 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ /dev/null @@ -1,163 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include - -#include - -namespace raft::bench::distance { - -struct fusedl2nn_inputs { - int64_t m, n, k; -}; // struct fusedl2nn_inputs - -inline auto operator<<(std::ostream& os, const fusedl2nn_inputs& p) -> std::ostream& -{ - os << p.m << "#" << p.n << "#" << p.k; - return os; -} - -template -struct fusedl2nn : public fixture { - fusedl2nn(const fusedl2nn_inputs& p) - : params(p), - workspace(this->handle), - x(this->handle), - y(this->handle), - x_norm(this->handle), - y_norm(this->handle), - out(this->handle) - { - } - - void allocate_data(const ::benchmark::State& state) override - { - x = raft::make_device_matrix(handle, params.m, params.k); - y = raft::make_device_matrix(handle, params.n, params.k); - x_norm = raft::make_device_vector(handle, params.m); - y_norm = raft::make_device_vector(handle, params.n); - out = raft::make_device_vector(handle, params.m); - - raft::random::RngState rng{1234}; - raft::random::uniform( - handle, rng, x.data_handle(), params.m * params.k, (DataT)-1.0, (DataT)1.0); - raft::random::uniform( - handle, rng, y.data_handle(), params.n * params.k, (DataT)-1.0, (DataT)1.0); - - // Pre-compute norms - raft::linalg::rowNorm(x_norm.data_handle(), - x.data_handle(), - params.k, - params.m, - raft::linalg::L2Norm, - true, - stream); - raft::linalg::rowNorm(y_norm.data_handle(), - y.data_handle(), - params.k, - params.n, - raft::linalg::L2Norm, - true, - stream); - resource::sync_stream(handle, stream); - } - - void allocate_temp_buffers(const ::benchmark::State& state) override - { - workspace = raft::make_device_vector(handle, params.m * sizeof(IdxT)); - } - - void run_benchmark(::benchmark::State& state) override - { - std::ostringstream label_stream; - label_stream << params; - state.SetLabel(label_stream.str()); - - loop_on_state(state, [this]() { - raft::distance::fusedL2NNMinReduce(out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - static_cast(params.m), - static_cast(params.n), - static_cast(params.k), - (void*)workspace.data_handle(), - false, - true, - stream); - }); - - int64_t num_flops = 2 * params.m * params.n * params.k; - - int64_t read_elts = params.n * params.k + params.m * params.k; - int64_t write_elts = params.m; - - state.counters["FLOP/s"] = benchmark::Counter( - num_flops, benchmark::Counter::kIsIterationInvariantRate, benchmark::Counter::OneK::kIs1000); - - state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - } - - private: - fusedl2nn_inputs params; - raft::device_matrix x, y; - raft::device_vector x_norm, y_norm; - raft::device_vector out; - raft::device_vector workspace; -}; // struct fusedl2nn - -template -std::vector getFusedL2NNInputs() -{ - std::vector inputs; - std::vector m_list = {100000, 1000000}; - if constexpr (sizeof(IdxT) == 8) { m_list.push_back(10000000); } - std::vector n_list = {100, 1000, 10000}; - std::vector k_list = {64, 128, 256}; - for (auto m : m_list) { - for (auto n : n_list) { - for (auto k : k_list) { - inputs.push_back({m, n, k}); - } - } - } - return inputs; -} - -#define FUSEDL2NN_BENCH(DataT, IdxT, OutT) \ - RAFT_BENCH_REGISTER((fusedl2nn), "", getFusedL2NNInputs()) - -FUSEDL2NN_BENCH(float, int, float); -FUSEDL2NN_BENCH(double, int, double); -FUSEDL2NN_BENCH(float, int, (raft::KeyValuePair)); -FUSEDL2NN_BENCH(double, int, (raft::KeyValuePair)); -FUSEDL2NN_BENCH(float, int64_t, float); -FUSEDL2NN_BENCH(double, int64_t, double); -FUSEDL2NN_BENCH(float, int64_t, (raft::KeyValuePair)); -FUSEDL2NN_BENCH(double, int64_t, (raft::KeyValuePair)); - -} // namespace raft::bench::distance diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu deleted file mode 100644 index eb86330637..0000000000 --- a/cpp/bench/prims/distance/kernels.cu +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include - -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace raft::bench::distance::kernels { - -using namespace raft::distance::kernels; -struct GramTestParams { - int m; // m parameter of the GEMM - int k; // k parameter of the GEMM - int n; // n parameter of the GEMM - KernelParams kernel_params; - bool is_row_major; -}; // struct GramTestParams - -template -struct GramMatrix : public fixture { - GramMatrix(const GramTestParams& p) - : params(p), handle(stream), A(0, stream), B(0, stream), C(0, stream) - { - kernel = std::unique_ptr>( - KernelFactory::create(p.kernel_params, resource::get_cublas_handle(handle))); - - A.resize(params.m * params.k, stream); - B.resize(params.k * params.n, stream); - C.resize(params.m * params.n, stream); - raft::random::RngState rng(123456ULL); - raft::random::uniform(handle, rng, A.data(), params.m * params.k, T(-1.0), T(1.0)); - raft::random::uniform(handle, rng, B.data(), params.k * params.n, T(-1.0), T(1.0)); - } - - ~GramMatrix() - { - A.release(); - B.release(); - C.release(); - } - - void run_benchmark(::benchmark::State& state) override - { - if (!this->kernel) { state.SkipWithError("Kernel matrix is not initialized"); } - loop_on_state(state, [this]() { - (*this->kernel)(A.data(), - this->params.m, - this->params.k, - B.data(), - this->params.n, - C.data(), - this->params.is_row_major, - this->stream); - }); - } - - private: - const raft::device_resources handle; - std::unique_ptr> kernel; - GramTestParams params; - - rmm::device_uvector A; // input matrix A, size [m * k] - rmm::device_uvector B; // input matrix B, size [n * k] - rmm::device_uvector C; // output matrix C, size [m*n] -}; - -static std::vector getInputs() -{ - std::vector param_vec; - std::vector kernel_params{KernelParams{LINEAR, 3, 1, 0}, - KernelParams{POLYNOMIAL, 2, 1.3, 1}, - KernelParams{TANH, 2, 0.5, 2.4}, - KernelParams{RBF, 2, 0.5, 0}}; - struct TestSize { - int m; - int k; - int n; - }; - std::vector data_size{{4096, 10, 1024}, - {4096, 100, 1024}, - {4096, 1000, 1024}, - {4096, 10000, 1024}, - {100000, 10, 1024}, - {100000, 100, 1024}, - {100000, 1000, 1024}}; - - param_vec.reserve(kernel_params.size() * data_size.size()); - for (TestSize s : data_size) { - for (auto kernel : kernel_params) { - for (bool row_major : {false, true}) { - param_vec.push_back(GramTestParams{s.m, s.k, s.n, kernel, row_major}); - } - } - } - return param_vec; -} - -RAFT_BENCH_REGISTER(GramMatrix, "", getInputs()); -RAFT_BENCH_REGISTER(GramMatrix, "", getInputs()); - -} // namespace raft::bench::distance::kernels diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu deleted file mode 100644 index 979d438b67..0000000000 --- a/cpp/bench/prims/distance/masked_nn.cu +++ /dev/null @@ -1,264 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft::bench::distance::masked_nn { - -// Introduce various sparsity patterns -enum AdjacencyPattern { - checkerboard = 0, - checkerboard_4 = 1, - checkerboard_64 = 2, - all_true = 3, - all_false = 4 -}; - -struct Params { - int m, n, k, num_groups; - AdjacencyPattern pattern; -}; // struct Params - -RAFT_KERNEL init_adj(AdjacencyPattern pattern, - int n, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs) -{ - int m = adj.extent(0); - int num_groups = adj.extent(1); - - for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; - idx_m += blockDim.y * gridDim.y) { - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; - idx_g += blockDim.x * gridDim.x) { - switch (pattern) { - case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; - case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; - case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; - case all_true: adj(idx_m, idx_g) = true; break; - case all_false: adj(idx_m, idx_g) = false; break; - default: assert(false && "unknown pattern"); - } - } - } - // Each group is of size n / num_groups. - // - // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive - // scan of the group lengths) - // - // - The first group always starts at index zero, so we do not store it. - // - // - The group_idxs[num_groups - 1] should always equal n. - - if (blockIdx.y == 0 && threadIdx.y == 0) { - const int g_stride = blockDim.x * gridDim.x; - for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { - group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); - } - group_idxs(num_groups - 1) = n; - } -} - -template -struct masked_l2_nn : public fixture { - using DataT = T; - using IdxT = int; - using OutT = raft::KeyValuePair; - using RedOpT = raft::distance::MinAndDistanceReduceOp; - using PairRedOpT = raft::distance::KVPMinReduce; - using ParamT = raft::distance::masked_l2_nn_params; - - // Parameters - Params params; - // Data - raft::device_vector out; - raft::device_matrix x, y; - raft::device_vector xn, yn; - raft::device_matrix adj; - raft::device_vector group_idxs; - - masked_l2_nn(const Params& p) - : params(p), - out{raft::make_device_vector(handle, p.m)}, - x{raft::make_device_matrix(handle, p.m, p.k)}, - y{raft::make_device_matrix(handle, p.n, p.k)}, - xn{raft::make_device_vector(handle, p.m)}, - yn{raft::make_device_vector(handle, p.n)}, - adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, - group_idxs{raft::make_device_vector(handle, p.num_groups)} - { - raft::random::RngState r(123456ULL); - - uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); - uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); - raft::linalg::rowNorm( - xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); - raft::linalg::rowNorm( - yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); - raft::distance::initialize, int>( - handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); - - dim3 block(32, 32); - dim3 grid(10, 10); - init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - void run_benchmark(::benchmark::State& state) override - { - bool init_out = true; - bool sqrt = false; - ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - - loop_on_state(state, [this, masked_l2_params]() { - // It is sufficient to only benchmark the L2-squared metric - raft::distance::masked_l2_nn(handle, - masked_l2_params, - x.view(), - y.view(), - xn.view(), - yn.view(), - adj.view(), - group_idxs.view(), - out.view()); - }); - - // Virtual flop count if no skipping had occurred. - size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); - - int64_t read_elts = params.n * params.k + params.m * params.k; - int64_t write_elts = params.m; - - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - size_t virtual_min_flops = 0; - switch (params.pattern) { - case checkerboard: - case checkerboard_4: - case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; - case all_true: virtual_min_flops = virtual_flops; break; - case all_false: virtual_min_flops = 0; break; - default: assert(false && "unknown pattern"); - } - - // VFLOP/s is the "virtual" flop count that would have executed if there was - // no adjacency pattern. This is useful for comparing to fusedL2NN - state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - // Virtual min flops is the number of flops that would have been executed if - // the algorithm had actually skipped each computation that it could have - // skipped. - state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["m"] = benchmark::Counter(params.m); - state.counters["n"] = benchmark::Counter(params.n); - state.counters["k"] = benchmark::Counter(params.k); - state.counters["num_groups"] = benchmark::Counter(params.num_groups); - state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); - state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); - - state.counters["SM count"] = raft::getMultiProcessorCount(); - } -}; - -const std::vector masked_l2_nn_input_vecs = { - // Very fat matrices... - {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, - {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, - - // Representative matrices... - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, - - {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, - {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_true}, - {16384, 16384, 64, 32, AdjacencyPattern::all_true}, - {16384, 16384, 128, 32, AdjacencyPattern::all_true}, - {16384, 16384, 256, 32, AdjacencyPattern::all_true}, - {16384, 16384, 512, 32, AdjacencyPattern::all_true}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, - - {16384, 16384, 32, 32, AdjacencyPattern::all_false}, - {16384, 16384, 64, 32, AdjacencyPattern::all_false}, - {16384, 16384, 128, 32, AdjacencyPattern::all_false}, - {16384, 16384, 256, 32, AdjacencyPattern::all_false}, - {16384, 16384, 512, 32, AdjacencyPattern::all_false}, - {16384, 16384, 1024, 32, AdjacencyPattern::all_false}, - {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, -}; - -RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); -// We don't benchmark double to keep compile times in check when not using the -// distance library. - -} // namespace raft::bench::distance::masked_nn diff --git a/cpp/bench/prims/distance/tune_pairwise/bench.cu b/cpp/bench/prims/distance/tune_pairwise/bench.cu deleted file mode 100644 index 81105cdefe..0000000000 --- a/cpp/bench/prims/distance/tune_pairwise/bench.cu +++ /dev/null @@ -1,155 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// Tuning benchmarks. -// -// Goals: -// -// 1. Fast compile times to maintain iteration speed. -// 2. Create benchmarks that can inform the design of the kernels. -// -// Non-goals: -// -// 1. Measure every distance operation. Instead measures just one distance -// operation at the same time. -// 2. Be useful for finding performance regressions. This is handled by the -// normal benchmarks. -// -// So far, both goals are partly achieved. -// -// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. -// When the internals of a pairwise distance kernel is changed, this file is not -// recompiled. -// -// RE 2, benchmarks with intent: this file contains a benchmark to check the -// maximal throughput of a kernel. Measuring other things, like performance on -// skinny or wide matrices is not yet implemented. - -#include "kernel.cuh" // launch_kernel - -#include // RAFT_BENCH_REGISTER - -#include // pairwise_matrix_params - -#include // rmm::device_uvector - -#include // std::min -#include // std::vector - -namespace raft::bench::distance::tune { - -// Max throughput benchmark. -// -// Goal: Measure the maximum distances/sec that can be computed. -// -// To achieve this, we make sure that: -// -// - Input data size is a multiple of the block tile size. -// -// - Perfect distribution of work between SMs, i.e. the number of block tiles is -// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). -// -// - Multiple iterations over Kblk are executed (num_k_iters). -struct throughput_param { - int num_waves; - int occupancy; - int num_k_iters; -}; - -const std::vector throughput_params{ - // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves - // maximum throughput. No need to pick higher values. - {32, 4, 32}, -}; - -struct throughput_bench : public fixture { - const throughput_param p; - - throughput_bench(const throughput_param& p_) : p(p_) {} - - void run_benchmark(::benchmark::State& state) override - { - // Get block size: - int block_m, block_n, block_k; - get_block_size(block_m, block_n, block_k); - - // Determine number of blocks that will be launched. This informs the size - // of the inputs as well as the grid size. - const int num_sms = raft::getMultiProcessorCount(); - const int max_occupancy = get_max_occupancy(); - const int occupancy = std::min(p.occupancy, max_occupancy); - const int num_blocks = occupancy * num_sms; - dim3 grid(num_blocks); - - // Create input sizes that are a multiple of the block tile size. - size_t m = block_m; - size_t n = block_n * p.num_waves * num_blocks; - size_t k = block_k * p.num_k_iters; - - // DataT, OutT, IdxT, etc, are defined in tuned_kernel.cuh - rmm::device_uvector x_vec(m * k, stream); - rmm::device_uvector y_vec(n * k, stream); - rmm::device_uvector x_norm_vec(m, stream); - rmm::device_uvector y_norm_vec(n, stream); - rmm::device_uvector out_vec(m * n, stream); - - auto x = x_vec.data(); - auto y = y_vec.data(); - auto x_norm = x_norm_vec.data(); - auto y_norm = y_norm_vec.data(); - auto out = out_vec.data(); - FinOpT fin_op{}; - - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = row_major ? k : m; - IdxT ldy = row_major ? k : n; - IdxT ld_out = row_major ? n : m; - - // Template parameters of pairwise_matrix_params are defined in kernel.cuh - pairwise_matrix_params kparams{ - IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; - - // Run benchmark - loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); }); - - // Report metrics. We don't report flop/s because we do not know for each - // distance operation how many flops it costs. For L2_unexp and l1, we can - // double this number to get the flop/s. For l2 expanded, core_ops/s should - // equal flop/s (modulo the sqrt and subtracting from the norm). - size_t num_core_ops = m * n * k; - size_t read_elts = n * k + m * k; - size_t write_elts = m * n; - - state.counters["m"] = benchmark::Counter(m); - state.counters["n"] = benchmark::Counter(n); - state.counters["k"] = benchmark::Counter(k); - state.counters["occupancy"] = benchmark::Counter(occupancy); - state.counters["# waves"] = benchmark::Counter(p.num_waves); - state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); - - state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - - state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), - benchmark::Counter::kIsIterationInvariantRate, - benchmark::Counter::OneK::kIs1000); - } -}; - -RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params); - -} // namespace raft::bench::distance::tune diff --git a/cpp/bench/prims/distance/tune_pairwise/kernel.cu b/cpp/bench/prims/distance/tune_pairwise/kernel.cu deleted file mode 100644 index 42173c51f5..0000000000 --- a/cpp/bench/prims/distance/tune_pairwise/kernel.cu +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "kernel.cuh" - -#include // pairwise_matrix_sm60_wrapper -#include // raft::linalg::Policy4x4 -#include // raft::util::arch::SM_compute_arch - -namespace raft::bench::distance::tune { - -// Distance op -using OpT = raft::distance::detail::ops::lp_unexp_distance_op; -constexpr float metric_arg = 2.0; -OpT distance_op{metric_arg}; - -// Kernel policy -constexpr int vec_len = 1; -using Policy = typename raft::linalg::Policy4x4::Policy; - -// Architecture -namespace arch = raft::util::arch; -constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - -void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) -{ - dim3 block(Policy::Nthreads); - int smem_size = OpT::shared_mem_size(); - - // Obtain function pointer to kernel - auto kernel = raft::distance::detail::pairwise_matrix_kernel; - - kernel<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -void get_block_size(int& m, int& n, int& k) -{ - m = Policy::Mblk; - n = Policy::Nblk; - k = Policy::Kblk; -} - -void* get_kernel_ptr() -{ - auto kernel = raft::distance::detail::pairwise_matrix_kernel; - return reinterpret_cast(kernel); -} - -int get_max_occupancy() -{ - void* kernel_ptr = get_kernel_ptr(); - int max_occupancy; - int smem_size = OpT::shared_mem_size(); - - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); - - return max_occupancy; -} - -} // namespace raft::bench::distance::tune diff --git a/cpp/bench/prims/distance/tune_pairwise/kernel.cuh b/cpp/bench/prims/distance/tune_pairwise/kernel.cuh deleted file mode 100644 index 5da54a343c..0000000000 --- a/cpp/bench/prims/distance/tune_pairwise/kernel.cuh +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // lp_unexp_distance_op -#include // pairwise_matrix_params - -namespace raft::bench::distance::tune { - -// Launch one specific kernel with the following template parameters -constexpr bool row_major = true; -using DataT = float; -using AccT = float; -using OutT = DataT; -using IdxT = int; - -using FinOpT = raft::identity_op; - -using pairwise_matrix_params = - raft::distance::detail::pairwise_matrix_params; - -// Launches kernel -void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t); - -// Describes the block size that is decided by the policy -void get_block_size(int& m, int& n, int& k); - -int get_max_occupancy(); - -} // namespace raft::bench::distance::tune diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh deleted file mode 100644 index acbeba375a..0000000000 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ /dev/null @@ -1,208 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include -#include -#include - -#include - -#include - -namespace raft::bench::neighbors { - -struct params { - /** Size of the dataset. */ - size_t n_samples; - /** Number of dimensions in the dataset. */ - int n_dims; - /** The batch size -- number of KNN searches. */ - int n_queries; - /** Number of nearest neighbours to find for every probe. */ - int k; - /** kNN graph degree*/ - int degree; - int itopk_size; - int block_size; - int search_width; - int max_iterations; - /** Ratio of removed indices. */ - double removed_ratio; -}; - -template -struct CagraBench : public fixture { - explicit CagraBench(const params& ps) - : fixture(true), - params_(ps), - queries_(make_device_matrix(handle, ps.n_queries, ps.n_dims)), - dataset_(make_device_matrix(handle, ps.n_samples, ps.n_dims)), - knn_graph_(make_device_matrix(handle, ps.n_samples, ps.degree)), - removed_indices_bitset_(handle, ps.n_samples) - { - // Generate random dataset and queriees - raft::random::RngState state{42}; - constexpr T kRangeMax = std::is_integral_v ? std::numeric_limits::max() : T(1); - constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); - if constexpr (std::is_integral_v) { - raft::random::uniformInt( - handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax); - raft::random::uniformInt( - handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax); - } else { - raft::random::uniform( - handle, state, dataset_.data_handle(), dataset_.size(), kRangeMin, kRangeMax); - raft::random::uniform( - handle, state, queries_.data_handle(), queries_.size(), kRangeMin, kRangeMax); - } - - // Generate random knn graph - - raft::random::uniformInt( - handle, state, knn_graph_.data_handle(), knn_graph_.size(), 0, ps.n_samples - 1); - - auto metric = raft::distance::DistanceType::L2Expanded; - - auto removed_indices = - raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); - thrust::sequence( - resource::get_thrust_policy(handle), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(handle, removed_indices.view()); - index_.emplace(raft::neighbors::cagra::index( - handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); - } - - void run_benchmark(::benchmark::State& state) override - { - raft::neighbors::cagra::search_params search_params; - search_params.max_queries = 1024; - search_params.itopk_size = params_.itopk_size; - search_params.team_size = 0; - search_params.thread_block_size = params_.block_size; - search_params.search_width = params_.search_width; - - auto indices = make_device_matrix(handle, params_.n_queries, params_.k); - auto distances = make_device_matrix(handle, params_.n_queries, params_.k); - auto ind_v = make_device_matrix_view( - indices.data_handle(), params_.n_queries, params_.k); - auto dist_v = make_device_matrix_view( - distances.data_handle(), params_.n_queries, params_.k); - - auto queries_v = make_const_mdspan(queries_.view()); - if (params_.removed_ratio > 0) { - auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); - loop_on_state(state, [&]() { - raft::neighbors::cagra::search_with_filtering( - this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter); - }); - } else { - loop_on_state(state, [&]() { - raft::neighbors::cagra::search( - this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); - }); - } - - double data_size = params_.n_samples * params_.n_dims * sizeof(T); - double graph_size = params_.n_samples * params_.degree * sizeof(IdxT); - - int iterations = params_.max_iterations; - if (iterations == 0) { - // see search_plan_impl::adjust_search_params() - double r = params_.itopk_size / static_cast(params_.search_width); - iterations = 1 + std::min(r * 1.1, r + 10); - } - state.counters["dataset (GiB)"] = data_size / (1 << 30); - state.counters["graph (GiB)"] = graph_size / (1 << 30); - state.counters["n_rows"] = params_.n_samples; - state.counters["n_cols"] = params_.n_dims; - state.counters["degree"] = params_.degree; - state.counters["n_queries"] = params_.n_queries; - state.counters["k"] = params_.k; - state.counters["itopk_size"] = params_.itopk_size; - state.counters["block_size"] = params_.block_size; - state.counters["search_width"] = params_.search_width; - state.counters["iterations"] = iterations; - state.counters["removed_ratio"] = params_.removed_ratio; - } - - private: - const params params_; - std::optional> index_; - raft::device_matrix queries_; - raft::device_matrix dataset_; - raft::device_matrix knn_graph_; - raft::core::bitset removed_indices_bitset_; -}; - -inline const std::vector generate_inputs() -{ - std::vector inputs = - raft::util::itertools::product({2000000ull}, // n_samples - {128, 256, 512, 1024}, // dataset dim - {1000}, // n_queries - {32}, // k - {64}, // knn graph degree - {64}, // itopk_size - {0}, // block_size - {1}, // search_width - {0}, // max_iterations - {0.0} // removed_ratio - ); - auto inputs2 = raft::util::itertools::product({2000000ull, 10000000ull}, // n_samples - {128}, // dataset dim - {1000}, // n_queries - {32}, // k - {64}, // knn graph degree - {64}, // itopk_size - {64, 128, 256, 512, 1024}, // block_size - {1}, // search_width - {0}, // max_iterations - {0.0} // removed_ratio - ); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - - inputs2 = raft::util::itertools::product( - {2000000ull, 10000000ull}, // n_samples - {128}, // dataset dim - {1, 10, 10000}, // n_queries - {255}, // k - {64}, // knn graph degree - {300}, // itopk_size - {256}, // block_size - {2}, // search_width - {0}, // max_iterations - {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio - ); - inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); - return inputs; -} - -const std::vector kCagraInputs = generate_inputs(); - -#define CAGRA_REGISTER(ValT, IdxT, inputs) \ - namespace BENCHMARK_PRIVATE_NAME(knn) { \ - using AnnCagra = CagraBench; \ - RAFT_BENCH_REGISTER(AnnCagra, #ValT "/" #IdxT, inputs); \ - } - -} // namespace raft::bench::neighbors diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh deleted file mode 100644 index 6499078623..0000000000 --- a/cpp/bench/prims/neighbors/knn.cuh +++ /dev/null @@ -1,516 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft::bench::spatial { - -struct params { - /** Size of the dataset. */ - size_t n_samples; - /** Number of dimensions in the dataset. */ - size_t n_dims; - /** The batch size -- number of KNN searches. */ - size_t n_queries; - /** Number of nearest neighbours to find for every probe. */ - size_t k; - /** Ratio of removed indices. */ - double removed_ratio; -}; - -inline auto operator<<(std::ostream& os, const params& p) -> std::ostream& -{ - os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#" - << p.removed_ratio; - return os; -} - -enum class TransferStrategy { NO_COPY, COPY_PLAIN, COPY_PINNED, MAP_PINNED, MANAGED }; // NOLINT -enum class Scope { BUILD, SEARCH, BUILD_SEARCH }; // NOLINT - -inline auto operator<<(std::ostream& os, const TransferStrategy& ts) -> std::ostream& -{ - switch (ts) { - case TransferStrategy::NO_COPY: os << "NO_COPY"; break; - case TransferStrategy::COPY_PLAIN: os << "COPY_PLAIN"; break; - case TransferStrategy::COPY_PINNED: os << "COPY_PINNED"; break; - case TransferStrategy::MAP_PINNED: os << "MAP_PINNED"; break; - case TransferStrategy::MANAGED: os << "MANAGED"; break; - default: os << "UNKNOWN"; - } - return os; -} - -inline auto operator<<(std::ostream& os, const Scope& s) -> std::ostream& -{ - switch (s) { - case Scope::BUILD: os << "BUILD"; break; - case Scope::SEARCH: os << "SEARCH"; break; - case Scope::BUILD_SEARCH: os << "BUILD_SEARCH"; break; - default: os << "UNKNOWN"; - } - return os; -} - -struct device_resource { - public: - explicit device_resource(bool managed) : managed_(managed) - { - if (managed_) { - res_ = new rmm::mr::managed_memory_resource(); - } else { - res_ = rmm::mr::get_current_device_resource(); - } - } - - ~device_resource() - { - if (managed_) { delete res_; } - } - - [[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; } - - private: - const bool managed_; - rmm::mr::device_memory_resource* res_; -}; - -template -struct host_uvector { - host_uvector(size_t n, bool pinned) : n_(n) - { - if (pinned) { - res_ = new rmm::mr::pinned_memory_resource(); - } else { - res_ = new rmm::mr::new_delete_resource(); - } - arr_ = static_cast(res_->allocate(n_ * sizeof(T))); - } - - ~host_uvector() noexcept - { - res_->deallocate(arr_, n_ * sizeof(T)); - delete res_; - } - - auto data() -> T* { return arr_; } - [[nodiscard]] auto size() const -> size_t { return n_; } - - private: - rmm::mr::host_memory_resource* res_; - size_t n_; - T* arr_; -}; - -template -struct ivf_flat_knn { - using dist_t = float; - - std::optional> index; - raft::neighbors::ivf_flat::index_params index_params; - raft::neighbors::ivf_flat::search_params search_params; - params ps; - - ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) - { - index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_flat::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - dist_t* out_dists, - IdxT* out_idxs) - { - search_params.n_probes = 20; - raft::neighbors::ivf_flat::search(handle, - search_params, - *index, - search_items, - ps.n_queries, - ps.k, - out_idxs, - out_dists, - resource::get_workspace_resource(handle)); - } -}; - -template -struct ivf_pq_knn { - using dist_t = float; - - std::optional> index; - raft::neighbors::ivf_pq::index_params index_params; - raft::neighbors::ivf_pq::search_params search_params; - params ps; - - ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) - { - index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); - index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - dist_t* out_dists, - IdxT* out_idxs) - { - search_params.n_probes = 20; - auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto idxs_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto dists_view = - raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); - raft::neighbors::ivf_pq::search( - handle, search_params, *index, queries_view, idxs_view, dists_view); - } -}; - -template -struct brute_force_knn { - using dist_t = ValT; - - ValT* index; - params ps; - - brute_force_knn(const raft::device_resources& handle, const params& ps, const ValT* data) - : index(const_cast(data)), ps(ps) - { - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - dist_t* out_dists, - IdxT* out_idxs) - { - std::vector input{index}; - std::vector sizes{ps.n_samples}; - raft::spatial::knn::brute_force_knn(handle, - input, - sizes, - ps.n_dims, - const_cast(search_items), - ps.n_queries, - out_idxs, - out_dists, - ps.k); - } -}; - -template -struct ivf_flat_filter_knn { - using dist_t = float; - - std::optional> index; - raft::neighbors::ivf_flat::index_params index_params; - raft::neighbors::ivf_flat::search_params search_params; - raft::core::bitset removed_indices_bitset_; - params ps; - - ivf_flat_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) - : ps(ps), removed_indices_bitset_(handle, ps.n_samples) - { - index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; - index.emplace(raft::neighbors::ivf_flat::build( - handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); - auto removed_indices = - raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); - thrust::sequence( - resource::get_thrust_policy(handle), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(handle, removed_indices.view()); - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - dist_t* out_dists, - IdxT* out_idxs) - { - search_params.n_probes = 20; - auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto neighbors_view = raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto distance_view = raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); - auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); - - if (ps.removed_ratio > 0) { - raft::neighbors::ivf_flat::search_with_filtering( - handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); - } else { - raft::neighbors::ivf_flat::search( - handle, search_params, *index, queries_view, neighbors_view, distance_view); - } - } -}; - -template -struct ivf_pq_filter_knn { - using dist_t = float; - - std::optional> index; - raft::neighbors::ivf_pq::index_params index_params; - raft::neighbors::ivf_pq::search_params search_params; - raft::core::bitset removed_indices_bitset_; - params ps; - - ivf_pq_filter_knn(const raft::device_resources& handle, const params& ps, const ValT* data) - : ps(ps), removed_indices_bitset_(handle, ps.n_samples) - { - index_params.n_lists = 4096; - index_params.metric = raft::distance::DistanceType::L2Expanded; - auto data_view = raft::make_device_matrix_view(data, ps.n_samples, ps.n_dims); - index.emplace(raft::neighbors::ivf_pq::build(handle, index_params, data_view)); - auto removed_indices = - raft::make_device_vector(handle, ps.removed_ratio * ps.n_samples); - thrust::sequence( - resource::get_thrust_policy(handle), - thrust::device_pointer_cast(removed_indices.data_handle()), - thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0))); - removed_indices_bitset_.set(handle, removed_indices.view()); - } - - void search(const raft::device_resources& handle, - const ValT* search_items, - dist_t* out_dists, - IdxT* out_idxs) - { - search_params.n_probes = 20; - auto queries_view = - raft::make_device_matrix_view(search_items, ps.n_queries, ps.n_dims); - auto neighbors_view = - raft::make_device_matrix_view(out_idxs, ps.n_queries, ps.k); - auto distance_view = - raft::make_device_matrix_view(out_dists, ps.n_queries, ps.k); - auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view()); - - if (ps.removed_ratio > 0) { - raft::neighbors::ivf_pq::search_with_filtering( - handle, search_params, *index, queries_view, neighbors_view, distance_view, filter); - } else { - raft::neighbors::ivf_pq::search( - handle, search_params, *index, queries_view, neighbors_view, distance_view); - } - } -}; - -template -struct knn : public fixture { - explicit knn(const params& p, const TransferStrategy& strategy, const Scope& scope) - : fixture(true), - params_(p), - strategy_(strategy), - scope_(scope), - dev_mem_res_(strategy == TransferStrategy::MANAGED), - data_host_(0), - search_items_(p.n_queries * p.n_dims, stream), - out_dists_(p.n_queries * p.k, stream), - out_idxs_(p.n_queries * p.k, stream) - { - raft::random::RngState state{42}; - gen_data(state, search_items_, search_items_.size(), stream); - try { - size_t total_size = p.n_samples * p.n_dims; - data_host_.resize(total_size); - constexpr size_t kGenMinibatchSize = 1024 * 1024 * 1024; - rmm::device_uvector d(std::min(kGenMinibatchSize, total_size), stream); - for (size_t offset = 0; offset < total_size; offset += kGenMinibatchSize) { - size_t actual_size = std::min(total_size - offset, kGenMinibatchSize); - gen_data(state, d, actual_size, stream); - copy(data_host_.data() + offset, d.data(), actual_size, stream); - } - } catch (std::bad_alloc& e) { - data_does_not_fit_ = true; - } - } - - template - void gen_data(raft::random::RngState& state, // NOLINT - rmm::device_uvector& vec, - size_t n, - rmm::cuda_stream_view stream) - { - constexpr T kRangeMax = std::is_integral_v ? std::numeric_limits::max() : T(1); - constexpr T kRangeMin = std::is_integral_v ? std::numeric_limits::min() : T(-1); - if constexpr (std::is_integral_v) { - raft::random::uniformInt(handle, state, vec.data(), n, kRangeMin, kRangeMax); - } else { - raft::random::uniform(handle, state, vec.data(), n, kRangeMin, kRangeMax); - } - } - - void run_benchmark(::benchmark::State& state) override - { - if (data_does_not_fit_) { - state.SkipWithError("The data size is too big to fit into the host memory."); - } - if (scope_ == Scope::SEARCH && strategy_ != TransferStrategy::NO_COPY) { - state.SkipWithError( - "When benchmarking without index building (Scope::SEARCH), the data must be already on the " - "device (TransferStrategy::NO_COPY)"); - } - - try { - std::ostringstream label_stream; - label_stream << params_ << "#" << strategy_ << "#" << scope_; - state.SetLabel(label_stream.str()); - raft::device_resources handle(stream); - std::optional index; - - if (scope_ == Scope::SEARCH) { // also implies TransferStrategy::NO_COPY - rmm::device_uvector data(data_host_.size(), stream); - copy(data.data(), data_host_.data(), data_host_.size(), stream); - index.emplace(handle, params_, data.data()); - stream.synchronize(); - } - - // benchmark loop - for (auto _ : state) { - // managed or plain device memory initialized anew every time - rmm::device_uvector data(data_host_.size(), stream, dev_mem_res_.get()); - ValT* data_ptr = data.data(); - size_t allocation_size = data_host_.size() * sizeof(ValT); - - // Non-benchmarked part: using different methods to copy the data if necessary - switch (strategy_) { - case TransferStrategy::NO_COPY: // copy data to GPU before starting the timer. - copy(data_ptr, data_host_.data(), data_host_.size(), stream); - break; - case TransferStrategy::COPY_PINNED: - RAFT_CUDA_TRY( - cudaHostRegister(data_host_.data(), allocation_size, cudaHostRegisterDefault)); - break; - case TransferStrategy::MAP_PINNED: - RAFT_CUDA_TRY( - cudaHostRegister(data_host_.data(), allocation_size, cudaHostRegisterMapped)); - RAFT_CUDA_TRY(cudaHostGetDevicePointer(&data_ptr, data_host_.data(), 0)); - break; - case TransferStrategy::MANAGED: // sic! using std::memcpy rather than cuda copy - RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, - allocation_size, - cudaMemAdviseSetPreferredLocation, - resource::get_device_id(handle))); - RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, - allocation_size, - cudaMemAdviseSetAccessedBy, - resource::get_device_id(handle))); - RAFT_CUDA_TRY(cudaMemAdvise(data_ptr, - allocation_size, - cudaMemAdviseSetReadMostly, - resource::get_device_id(handle))); - std::memcpy(data_ptr, data_host_.data(), allocation_size); - break; - default: break; - } - - flush_L2_cache(); - { - // Timer synchronizes the stream, so all prior gpu work should be done before it sets off. - cuda_event_timer timer(state, stream); - switch (strategy_) { - case TransferStrategy::COPY_PLAIN: - case TransferStrategy::COPY_PINNED: - copy(data_ptr, data_host_.data(), data_host_.size(), stream); - default: break; - } - - if (scope_ != Scope::SEARCH) { index.emplace(handle, params_, data_ptr); } - if (scope_ != Scope::BUILD) { - index->search(handle, search_items_.data(), out_dists_.data(), out_idxs_.data()); - } - } - - if (scope_ != Scope::SEARCH) { index.reset(); } - - switch (strategy_) { - case TransferStrategy::COPY_PINNED: - case TransferStrategy::MAP_PINNED: - RAFT_CUDA_TRY(cudaHostUnregister(data_host_.data())); - break; - default: break; - } - } - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } catch (std::bad_alloc& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - const TransferStrategy strategy_; - const Scope scope_; - device_resource dev_mem_res_; - bool data_does_not_fit_ = false; - - std::vector data_host_; - rmm::device_uvector search_items_; - rmm::device_uvector out_dists_; - rmm::device_uvector out_idxs_; -}; - -inline const std::vector kInputs{ - {2000000, 128, 1000, 32, 0}, {10000000, 128, 1000, 32, 0}, {10000, 8192, 1000, 32, 0}}; - -const std::vector kInputsFilter = - raft::util::itertools::product({size_t(10000000)}, // n_samples - {size_t(128)}, // n_dim - {size_t(1000)}, // n_queries - {size_t(255)}, // k - {0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio - ); -inline const std::vector kAllStrategies{ - TransferStrategy::NO_COPY, TransferStrategy::MAP_PINNED, TransferStrategy::MANAGED}; -inline const std::vector kNoCopyOnly{TransferStrategy::NO_COPY}; - -inline const std::vector kScopeFull{Scope::BUILD_SEARCH}; -inline const std::vector kAllScopes{Scope::BUILD_SEARCH, Scope::SEARCH, Scope::BUILD}; - -#define KNN_REGISTER(ValT, IdxT, ImplT, inputs, strats, scope) \ - namespace BENCHMARK_PRIVATE_NAME(knn) { \ - using KNN = knn>; \ - RAFT_BENCH_REGISTER(KNN, #ValT "/" #IdxT "/" #ImplT, inputs, strats, scope); \ - } - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/brute_force_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_float_int64_t.cu deleted file mode 100644 index 7df0599670..0000000000 --- a/cpp/bench/prims/neighbors/knn/brute_force_float_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(float, int64_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/brute_force_float_uint32_t.cu b/cpp/bench/prims/neighbors/knn/brute_force_float_uint32_t.cu deleted file mode 100644 index 9704d39e76..0000000000 --- a/cpp/bench/prims/neighbors/knn/brute_force_float_uint32_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(float, uint32_t, brute_force_knn, kInputs, kAllStrategies, kScopeFull); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/cagra_float_uint32_t.cu b/cpp/bench/prims/neighbors/knn/cagra_float_uint32_t.cu deleted file mode 100644 index 5d762f6e85..0000000000 --- a/cpp/bench/prims/neighbors/knn/cagra_float_uint32_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../cagra_bench.cuh" - -namespace raft::bench::neighbors { - -CAGRA_REGISTER(float, uint32_t, kCagraInputs); - -} // namespace raft::bench::neighbors diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu deleted file mode 100644 index bf5118ceae..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_flat_filter_float_int64_t.cu +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY // Enable instantiation of search with filter -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(float, int64_t, ivf_flat_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu deleted file mode 100644 index fbbb4f9acc..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_flat_float_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(float, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu deleted file mode 100644 index 7067dbe1b6..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_flat_int8_t_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(int8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu deleted file mode 100644 index 91fada3c28..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_flat_uint8_t_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(uint8_t, int64_t, ivf_flat_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu deleted file mode 100644 index 1840eca99d..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_filter_float_int64_t.cu +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -#include -#include -namespace raft::bench::spatial { - -KNN_REGISTER(float, int64_t, ivf_pq_filter_knn, kInputsFilter, kNoCopyOnly, kScopeFull); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu deleted file mode 100644 index 83c4973c3a..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_float_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(float, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu deleted file mode 100644 index 4ea281b11a..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_int8_t_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(int8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu deleted file mode 100644 index 3313a49ba2..0000000000 --- a/cpp/bench/prims/neighbors/knn/ivf_pq_uint8_t_int64_t.cu +++ /dev/null @@ -1,23 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../knn.cuh" - -namespace raft::bench::spatial { - -KNN_REGISTER(uint8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/prims/neighbors/refine.cuh b/cpp/bench/prims/neighbors/refine.cuh deleted file mode 100644 index 0360babd82..0000000000 --- a/cpp/bench/prims/neighbors/refine.cuh +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -#include -#include - -using namespace raft::neighbors; - -namespace raft::bench::neighbors { - -template -inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::ostream& -{ - os << p.n_rows << "#" << p.dim << "#" << p.n_queries << "#" << p.k0 << "#" << p.k << "#" - << (p.host_data ? "host" : "device"); - return os; -} - -template -class RefineAnn : public fixture { - public: - RefineAnn(RefineInputs p) : data(handle_, p) {} - - void run_benchmark(::benchmark::State& state) override - { - std::ostringstream label_stream; - label_stream << data.p; - state.SetLabel(label_stream.str()); - - auto old_mr = rmm::mr::get_current_device_resource(); - rmm::mr::pool_memory_resource pool_mr( - old_mr, rmm::percent_of_free_device_memory(50)); - rmm::mr::set_current_device_resource(&pool_mr); - - if (data.p.host_data) { - loop_on_state(state, [this]() { - raft::neighbors::refine(handle_, - data.dataset_host.view(), - data.queries_host.view(), - data.candidates_host.view(), - data.refined_indices_host.view(), - data.refined_distances_host.view(), - data.p.metric); - }); - } else { - loop_on_state(state, [&]() { - raft::neighbors::refine(handle_, - data.dataset.view(), - data.queries.view(), - data.candidates.view(), - data.refined_indices.view(), - data.refined_distances.view(), - data.p.metric); - }); - } - rmm::mr::set_current_device_resource(old_mr); - } - - private: - raft::device_resources handle_; - RefineHelper data; -}; - -template -std::vector> getInputs() -{ - std::vector> out; - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; - for (bool host_data : {true, false}) { - for (T n_queries : {1000, 10000}) { - for (T dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); - } - } - } - return out; -} - -} // namespace raft::bench::neighbors diff --git a/cpp/bench/prims/neighbors/refine_float_int64_t.cu b/cpp/bench/prims/neighbors/refine_float_int64_t.cu deleted file mode 100644 index d69a157eca..0000000000 --- a/cpp/bench/prims/neighbors/refine_float_int64_t.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "refine.cuh" - -#include - -using namespace raft::neighbors; - -namespace raft::bench::neighbors { -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); -} // namespace raft::bench::neighbors diff --git a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu deleted file mode 100644 index 9da536b6c7..0000000000 --- a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "refine.cuh" - -#include - -using namespace raft::neighbors; - -namespace raft::bench::neighbors { -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); -} // namespace raft::bench::neighbors diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh deleted file mode 100644 index f2c83abdd3..0000000000 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { -template -class UnionFind { - public: - value_idx next_label; - std::vector parent; - std::vector size; - - value_idx n_indices; - - UnionFind(value_idx N_) - : n_indices(2 * N_ - 1), parent(2 * N_ - 1, -1), size(2 * N_ - 1, 1), next_label(N_) - { - memset(size.data() + N_, 0, (size.size() - N_) * sizeof(value_idx)); - } - - value_idx find(value_idx n) - { - value_idx p; - p = n; - - while (parent[n] != -1) - n = parent[n]; - - // path compression - while (parent[p] != n) { - p = parent[p == -1 ? n_indices - 1 : p]; - parent[p == -1 ? n_indices - 1 : p] = n; - } - return n; - } - - void perform_union(value_idx m, value_idx n) - { - size[next_label] = size[m] + size[n]; - parent[m] = next_label; - parent[n] = next_label; - - next_label += 1; - } -}; - -/** - * Agglomerative labeling on host. This has not been found to be a bottleneck - * in the algorithm. A parallel version of this can be done using a parallel - * variant of Kruskal's MST algorithm - * (ref http://cucis.ece.northwestern.edu/publications/pdf/HenPat12.pdf), - * which breaks apart the sorted MST results into overlapping subsets and - * independently runs Kruskal's algorithm on each subset, merging them back - * together into a single hierarchy when complete. Unfortunately, - * this is nontrivial and the speedup wouldn't be useful until this - * becomes a bottleneck. - * - * @tparam value_idx - * @tparam value_t - * @param[in] handle the raft handle - * @param[in] rows src edges of the sorted MST - * @param[in] cols dst edges of the sorted MST - * @param[in] nnz the number of edges in the sorted MST - * @param[out] out_src parents of output - * @param[out] out_dst children of output - * @param[out] out_delta distances of output - * @param[out] out_size cluster sizes of output - */ -template -void build_dendrogram_host(raft::resources const& handle, - const value_idx* rows, - const value_idx* cols, - const value_t* data, - size_t nnz, - value_idx* children, - value_t* out_delta, - value_idx* out_size) -{ - auto stream = resource::get_cuda_stream(handle); - - value_idx n_edges = nnz; - - std::vector mst_src_h(n_edges); - std::vector mst_dst_h(n_edges); - std::vector mst_weights_h(n_edges); - - update_host(mst_src_h.data(), rows, n_edges, stream); - update_host(mst_dst_h.data(), cols, n_edges, stream); - update_host(mst_weights_h.data(), data, n_edges, stream); - - resource::sync_stream(handle, stream); - - std::vector children_h(n_edges * 2); - std::vector out_size_h(n_edges); - std::vector out_delta_h(n_edges); - - UnionFind U(nnz + 1); - - for (std::size_t i = 0; i < nnz; i++) { - value_idx a = mst_src_h[i]; - value_idx b = mst_dst_h[i]; - value_t delta = mst_weights_h[i]; - - value_idx aa = U.find(a); - value_idx bb = U.find(b); - - value_idx children_idx = i * 2; - - children_h[children_idx] = aa; - children_h[children_idx + 1] = bb; - out_delta_h[i] = delta; - out_size_h[i] = U.size[aa] + U.size[bb]; - - U.perform_union(aa, bb); - } - - raft::update_device(children, children_h.data(), n_edges * 2, stream); - raft::update_device(out_size, out_size_h.data(), n_edges, stream); - raft::update_device(out_delta, out_delta_h.data(), n_edges, stream); -} - -template -RAFT_KERNEL write_levels_kernel(const value_idx* children, value_idx* parents, value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - if (tid < n_vertices) { - value_idx level = tid / 2; - value_idx child = children[tid]; - parents[child] = level; - } -} - -/** - * Instead of propagating a label from roots to children, - * the children each iterate up the tree until they find - * the label of their parent. This increases the potential - * parallelism. - * @tparam value_idx - * @param children - * @param parents - * @param n_leaves - * @param labels - */ -template -RAFT_KERNEL inherit_labels(const value_idx* children, - const value_idx* levels, - std::size_t n_leaves, - value_idx* labels, - int cut_level, - value_idx n_vertices) -{ - value_idx tid = blockDim.x * blockIdx.x + threadIdx.x; - - if (tid < n_vertices) { - value_idx node = children[tid]; - value_idx cur_level = tid / 2; - - /** - * Any roots above the cut level should be ignored. - * Any leaves at the cut level should already be labeled - */ - if (cur_level > cut_level) return; - - value_idx cur_parent = node; - value_idx label = labels[cur_parent]; - - while (label == -1) { - cur_parent = cur_level + n_leaves; - cur_level = levels[cur_parent]; - label = labels[cur_parent]; - } - - labels[node] = label; - } -} - -template -struct init_label_roots { - init_label_roots(value_idx* labels_) : labels(labels_) {} - - template - __host__ __device__ void operator()(Tuple t) - { - labels[thrust::get<1>(t)] = thrust::get<0>(t); - } - - private: - value_idx* labels; -}; - -/** - * Cuts the dendrogram at a particular level where the number of nodes - * is equal to n_clusters, then propagates the resulting labels - * to all the children. - * - * @tparam value_idx - * @param handle - * @param labels - * @param children - * @param n_clusters - * @param n_leaves - */ -template -void extract_flattened_clusters(raft::resources const& handle, - value_idx* labels, - const value_idx* children, - size_t n_clusters, - size_t n_leaves) -{ - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Handle special case where n_clusters == 1 - if (n_clusters == 1) { - thrust::fill(thrust_policy, labels, labels + n_leaves, 0); - } else { - /** - * Compute levels for each node - * - * 1. Initialize "levels" array of size n_leaves * 2 - * - * 2. For each entry in children, write parent - * out for each of the children - */ - - auto n_edges = (n_leaves - 1) * 2; - - thrust::device_ptr d_ptr = thrust::device_pointer_cast(children); - value_idx n_vertices = *(thrust::max_element(thrust_policy, d_ptr, d_ptr + n_edges)) + 1; - - // Prevent potential infinite loop from labeling disconnected - // connectivities graph. - RAFT_EXPECTS(n_leaves > 0, "n_leaves must be positive"); - RAFT_EXPECTS( - static_cast(n_vertices) == static_cast((n_leaves - 1) * 2), - "Multiple components found in MST or MST is invalid. " - "Cannot find single-linkage solution."); - - rmm::device_uvector levels(n_vertices, stream); - - value_idx n_blocks = ceildiv(n_vertices, (value_idx)tpb); - write_levels_kernel<<>>(children, levels.data(), n_vertices); - /** - * Step 1: Find label roots: - * - * 1. Copying children[children.size()-(n_clusters-1):] entries to - * separate arrayo - * 2. sort array - * 3. take first n_clusters entries - */ - - value_idx child_size = (n_clusters - 1) * 2; - rmm::device_uvector label_roots(child_size, stream); - - value_idx children_cpy_start = n_edges - child_size; - raft::copy_async(label_roots.data(), children + children_cpy_start, child_size, stream); - - thrust::sort(thrust_policy, - label_roots.data(), - label_roots.data() + (child_size), - thrust::greater()); - - rmm::device_uvector tmp_labels(n_vertices, stream); - - // Init labels to -1 - thrust::fill(thrust_policy, tmp_labels.data(), tmp_labels.data() + n_vertices, -1); - - // Write labels for cluster roots to "labels" - thrust::counting_iterator first(0); - - auto z_iter = thrust::make_zip_iterator( - thrust::make_tuple(first, label_roots.data() + (label_roots.size() - n_clusters))); - - thrust::for_each( - thrust_policy, z_iter, z_iter + n_clusters, init_label_roots(tmp_labels.data())); - - /** - * Step 2: Propagate labels by having children iterate through their parents - * 1. Initialize labels to -1 - * 2. For each element in levels array, propagate until parent's - * label is !=-1 - */ - value_idx cut_level = (n_edges / 2) - (n_clusters - 1); - - inherit_labels<<>>( - children, levels.data(), n_leaves, tmp_labels.data(), cut_level, n_vertices); - - // copy tmp labels to actual labels - raft::copy_async(labels, tmp_labels.data(), n_leaves, stream); - } -} - -}; // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh deleted file mode 100644 index c527b754c3..0000000000 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c); -}; - -/** - * Connectivities specialization to build a knn graph - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - auto thrust_policy = resource::get_thrust_policy(handle); - - // Need to symmetrize knn into undirected graph - raft::sparse::COO knn_graph_coo(stream); - - raft::sparse::neighbors::knn_graph(handle, X, m, n, metric, knn_graph_coo, c); - - indices.resize(knn_graph_coo.nnz, stream); - data.resize(knn_graph_coo.nnz, stream); - - // self-loops get max distance - auto transform_in = thrust::make_zip_iterator( - thrust::make_tuple(knn_graph_coo.rows(), knn_graph_coo.cols(), knn_graph_coo.vals())); - - thrust::transform(thrust_policy, - transform_in, - transform_in + knn_graph_coo.nnz, - knn_graph_coo.vals(), - [=] __device__(const thrust::tuple& tup) { - bool self_loop = thrust::get<0>(tup) == thrust::get<1>(tup); - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<2>(tup)); - }); - - raft::sparse::convert::sorted_coo_to_csr( - knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), m + 1, stream); - - // TODO: Wouldn't need to copy here if we could compute knn - // graph directly on the device uvectors - // ref: https://github.com/rapidsai/raft/issues/227 - raft::copy_async(indices.data(), knn_graph_coo.cols(), knn_graph_coo.nnz, stream); - raft::copy_async(data.data(), knn_graph_coo.vals(), knn_graph_coo.nnz, stream); - } -}; - -template -RAFT_KERNEL fill_indices2(value_idx* indices, size_t m, size_t nnz) -{ - value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; - if (tid >= nnz) return; - value_idx v = tid % m; - indices[tid] = v; -} - -/** - * Compute connected CSR of pairwise distances - * @tparam value_idx - * @tparam value_t - * @param handle - * @param X - * @param m - * @param n - * @param metric - * @param[out] indptr - * @param[out] indices - * @param[out] data - */ -template -void pairwise_distances(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - value_idx* indptr, - value_idx* indices, - value_t* data) -{ - auto stream = resource::get_cuda_stream(handle); - auto exec_policy = resource::get_thrust_policy(handle); - - value_idx nnz = m * m; - - value_idx blocks = raft::ceildiv(nnz, (value_idx)256); - fill_indices2<<>>(indices, m, nnz); - - thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m); - - raft::update_device(indptr + m, &nnz, 1, stream); - - // TODO: It would ultimately be nice if the MST could accept - // dense inputs directly so we don't need to double the memory - // usage to hand it a sparse array here. - distance::pairwise_distance(handle, X, X, data, m, m, n, metric); - // self-loops get max distance - auto transform_in = - thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); - - thrust::transform(exec_policy, - transform_in, - transform_in + nnz, - data, - [=] __device__(const thrust::tuple& tup) { - value_idx idx = thrust::get<0>(tup); - bool self_loop = idx % m == idx / m; - return (self_loop * std::numeric_limits::max()) + - (!self_loop * thrust::get<1>(tup)); - }); -} - -/** - * Connectivities specialization for pairwise distances - * @tparam value_idx - * @tparam value_t - */ -template -struct distance_graph_impl { - void run(const raft::resources& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) - { - auto stream = resource::get_cuda_stream(handle); - - size_t nnz = m * m; - - indices.resize(nnz, stream); - data.resize(nnz, stream); - - pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data()); - } -}; - -/** - * Returns a CSR connectivities graph based on the given linkage distance. - * @tparam value_idx - * @tparam value_t - * @tparam dist_type - * @param[in] handle raft handle - * @param[in] X dense data for which to construct connectivites - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metric to use - * @param[out] indptr indptr array of connectivities graph - * @param[out] indices column indices array of connectivities graph - * @param[out] data distances array of connectivities graph - * @param[out] c constant 'c' used for nearest neighbors-based distances - * which will guarantee k <= log(n) + c - */ -template -void get_distance_graph(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - rmm::device_uvector& indptr, - rmm::device_uvector& indices, - rmm::device_uvector& data, - int c) -{ - auto stream = resource::get_cuda_stream(handle); - - indptr.resize(m + 1, stream); - - distance_graph_impl dist_graph; - dist_graph.run(handle, X, m, n, metric, indptr, indices, data, c); -} - -}; // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh deleted file mode 100644 index 4efeedcbaa..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ /dev/null @@ -1,1255 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { - -// ========================================================= -// Init functions -// ========================================================= - -// Selects 'n_clusters' samples randomly from X -template -void initRandom(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids) -{ - common::nvtx::range fun_scope("initRandom"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_clusters = params.n_clusters; - detail::shuffleAndGather(handle, X, centroids, n_clusters, params.rng_state.seed); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: while |C| < k - * 3: Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - * 4: C = C U {x} - * 5: end for - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("kmeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // number of seeding trials for each center (except the first) - auto n_trials = 2 + static_cast(std::ceil(log(n_clusters))); - - RAFT_LOG_DEBUG( - "Run sequential k-means++ to select %d centroids from %d input samples " - "(%d seeding trials per iterations)", - n_clusters, - n_samples, - n_trials); - - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // temporary buffers - auto indices = raft::make_device_vector(handle, n_trials); - auto centroidCandidates = raft::make_device_matrix(handle, n_trials, n_features); - auto costPerCandidate = raft::make_device_vector(handle, n_trials); - auto minClusterDistance = raft::make_device_vector(handle, n_samples); - auto distBuffer = raft::make_device_matrix(handle, n_trials, n_samples); - - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - rmm::device_scalar clusterCost(stream); - rmm::device_scalar> minClusterIndexAndDistance(stream); - - // Device and matrix views - raft::device_vector_view indices_view(indices.data_handle(), n_trials); - auto const_weights_view = - raft::make_device_vector_view(minClusterDistance.data_handle(), n_samples); - auto const_indices_view = - raft::make_device_vector_view(indices.data_handle(), n_trials); - auto const_X_view = - raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); - raft::device_matrix_view candidates_view( - centroidCandidates.data_handle(), n_trials, n_features); - - // L2 norm of X: ||c||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - // <<< Step-1 >>>: C <-- sample a point uniformly at random from X - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + dis(gen) * n_features, 1, n_features); - int n_clusters_picked = 1; - - // store the chosen centroid in the buffer - raft::copy( - centroidsRawData.data_handle(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - // C = initial set of centroids - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // Calculate cluster distance, d^2(x, C), for all the points x in X to the nearest centroid - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - - // <<<< Step-2 >>> : while |C| < k - while (n_clusters_picked < n_clusters) { - // <<< Step-3 >>> : Sample x in X with probability p_x = d^2(x, C) / phi_X (C) - // Choose 'n_trials' centroid candidates from X with probability proportional to the squared - // distance to the nearest existing cluster - - raft::random::discrete(handle, rng, indices_view, const_weights_view); - raft::matrix::gather(handle, const_X_view, const_indices_view, candidates_view); - - // Calculate pairwise distance between X and the centroid candidates - // Output - pwd [n_trials x n_samples] - auto pwd = distBuffer.view(); - detail::pairwise_distance_kmeans( - handle, centroidCandidates.view(), X, pwd, workspace, metric); - - // Update nearest cluster distance for each centroid candidate - // Note pwd and minDistBuf points to same buffer which currently holds pairwise distance values. - // Outputs minDistanceBuf[n_trials x n_samples] where minDistance[i, :] contains updated - // minClusterDistance that includes candidate-i - auto minDistBuf = distBuffer.view(); - raft::linalg::matrixVectorOp(minDistBuf.data_handle(), - pwd.data_handle(), - minClusterDistance.data_handle(), - pwd.extent(1), - pwd.extent(0), - true, - true, - raft::min_op{}, - stream); - - // Calculate costPerCandidate[n_trials] where costPerCandidate[i] is the cluster cost when using - // centroid candidate-i - raft::linalg::reduce(costPerCandidate.data_handle(), - minDistBuf.data_handle(), - minDistBuf.extent(1), - minDistBuf.extent(0), - static_cast(0), - true, - true, - stream); - - // Greedy Choice - Choose the candidate that has minimum cluster cost - // ArgMin operation below identifies the index of minimum cost in costPerCandidate - { - // Determine temporary device storage requirements - size_t temp_storage_bytes = 0; - cub::DeviceReduce::ArgMin(nullptr, - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - // Allocate temporary storage - workspace.resize(temp_storage_bytes, stream); - - // Run argmin-reduction - cub::DeviceReduce::ArgMin(workspace.data(), - temp_storage_bytes, - costPerCandidate.data_handle(), - minClusterIndexAndDistance.data(), - costPerCandidate.extent(0), - stream); - - int bestCandidateIdx = -1; - raft::copy(&bestCandidateIdx, &minClusterIndexAndDistance.data()->key, 1, stream); - resource::sync_stream(handle); - /// <<< End of Step-3 >>> - - /// <<< Step-4 >>>: C = C U {x} - // Update minimum cluster distance corresponding to the chosen centroid candidate - raft::copy(minClusterDistance.data_handle(), - minDistBuf.data_handle() + bestCandidateIdx * n_samples, - n_samples, - stream); - - raft::copy(centroidsRawData.data_handle() + n_clusters_picked * n_features, - centroidCandidates.data_handle() + bestCandidateIdx * n_features, - n_features, - stream); - - ++n_clusters_picked; - /// <<< End of Step-4 >>> - } - - RAFT_LOG_DEBUG(" k-means++ - Sampled %d/%d centroids", n_clusters_picked, n_clusters); - } /// <<<< Step-5 >>> -} - -/** - * - * @tparam DataT - * @tparam IndexT - * @param handle - * @param[in] X input matrix (size n_samples, n_features) - * @param[in] weight number of samples currently assigned to each centroid - * @param[in] cur_centroids matrix of current centroids (size n_clusters, n_features) - * @param[in] l2norm_x - * @param[out] min_cluster_and_dist - * @param[out] new_centroids - * @param[out] new_weight - * @param[inout] workspace - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - - // TODO: Figure out how to best wrap iterator types in mdspan - LabelsIterator cluster_labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids, - rmm::device_uvector& workspace) -{ - auto n_clusters = centroids.extent(0); - auto n_samples = X.extent(0); - - workspace.resize(n_samples, resource::get_cuda_stream(handle)); - - // Calculates weighted sum of all the samples assigned to cluster-i and stores the - // result in new_centroids[i] - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - cluster_labels, - sample_weights.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - n_clusters, - new_centroids.data_handle(), - resource::get_cuda_stream(handle)); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), - cluster_labels, - weight_per_cluster.data_handle(), - (IndexT)1, - (IndexT)sample_weights.extent(0), - (IndexT)n_clusters, - resource::get_cuda_stream(handle)); - - // Computes new_centroids[i] = new_centroids[i]/weight_per_cluster[i] where - // new_centroids[n_clusters x n_features] - 2D array, new_centroids[i] has sum of all the - // samples assigned to cluster-i - // weight_per_cluster[n_clusters] - 1D array, weight_per_cluster[i] contains sum of weights in - // cluster-i. - // Note - when weight_per_cluster[i] is 0, new_centroids[i] is reset to 0 - raft::linalg::matrixVectorOp(new_centroids.data_handle(), - new_centroids.data_handle(), - weight_per_cluster.data_handle(), - new_centroids.extent(1), - new_centroids.extent(0), - true, - false, - raft::div_checkzero_op{}, - resource::get_cuda_stream(handle)); - - // copy centroids[i] to new_centroids[i] when weight_per_cluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(weight_per_cluster.data_handle()); - raft::matrix::gather_if( - const_cast(centroids.data_handle()), - static_cast(centroids.extent(1)), - static_cast(centroids.extent(0)), - itr_wt, - itr_wt, - static_cast(weight_per_cluster.size()), - new_centroids.data_handle(), - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the sum of weights in the cluster is 0 - return map.value == 0; - }, - raft::key_op{}, - resource::get_cuda_stream(handle)); -} - -// TODO: Resizing is needed to use mdarray instead of rmm::device_uvector -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("kmeans_fit_main"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // temporary buffer to store intermediate centroids, destructor releases the - // resource - auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - - // temporary buffer to store weights per cluster, destructor releases the - // resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); - - rmm::device_scalar clusterCostD(stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - RAFT_LOG_DEBUG( - "Calling KMeans.fit with %d samples of input data and the initialized " - "cluster centers", - n_samples); - - DataT priorClusteringCost = 0; - for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { - RAFT_LOG_DEBUG( - "KMeans.fit: Iteration-%d: fitting the model using the initialized " - "cluster centers", - n_iter[0]); - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of - // raft::KeyValuePair and converting them to just return the Key to be used - // in reduce_rows_by_key prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - update_centroids(handle, - X, - weight, - raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features), - itr, - wtInCluster.view(), - newCentroids.view(), - workspace); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, DataT(0)); - raft::linalg::mapThenSumReduce(sqrdNorm.data_handle(), - newCentroids.size(), - raft::sqdiff_op{}, - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(&sqrdNormError, sqrdNorm.data_handle(), sqrdNorm.size(), stream); - - raft::copy( - centroidsRawData.data_handle(), newCentroids.data_handle(), newCentroids.size(), stream); - - bool done = false; - if (params.inertia_check) { - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - DataT curClusteringCost = clusterCostD.value(stream); - - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; - } - - resource::sync_stream(handle, stream); - if (sqrdNormError < params.tol) done = true; - - if (done) { - RAFT_LOG_DEBUG("Threshold triggered after %d iterations. Terminating early.", n_iter[0]); - break; - } - } - - auto centroids = raft::make_device_matrix_view( - centroidsRawData.data_handle(), n_clusters, n_features); - - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - // calculate cluster cost phi_x(C) - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - inertia[0] = clusterCostD.value(stream); - - RAFT_LOG_DEBUG("KMeans.fit: completed after %d iterations with %f inertia[0] ", - n_iter[0] > params.max_iter ? n_iter[0] - 1 : n_iter[0], - inertia[0]); -} - -/* - * @brief Selects 'n_clusters' samples from X using scalable kmeans++ algorithm. - - * @note This is the algorithm described in - * "Scalable K-Means++", 2012, Bahman Bahmani, Benjamin Moseley, - * Andrea Vattani, Ravi Kumar, Sergei Vassilvitskii, - * https://arxiv.org/abs/1203.6402 - - * Scalable kmeans++ pseudocode - * 1: C = sample a point uniformly at random from X - * 2: psi = phi_X (C) - * 3: for O( log(psi) ) times do - * 4: C' = sample each point x in X independently with probability - * p_x = l * (d^2(x, C) / phi_X (C) ) - * 5: C = C U C' - * 6: end for - * 7: For x in C, set w_x to be the number of points in X closer to x than any - * other point in C - * 8: Recluster the weighted points in C into k clusters - - * TODO: Resizing is needed to use mdarray instead of rmm::device_uvector - - */ -template -void initScalableKMeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - common::nvtx::range fun_scope("initScalableKMeansPlusPlus"); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - - // <<<< Step-1 >>> : C <- sample a point uniformly at random from X - std::mt19937 gen(params.rng_state.seed); - std::uniform_int_distribution<> dis(0, n_samples - 1); - - auto cIdx = dis(gen); - auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + cIdx * n_features, 1, n_features); - - // flag the sample that is chosen as initial centroid - std::vector h_isSampleCentroid(n_samples); - std::fill(h_isSampleCentroid.begin(), h_isSampleCentroid.end(), 0); - h_isSampleCentroid[cIdx] = 1; - - // device buffer to flag the sample that is chosen as initial centroid - auto isSampleCentroid = raft::make_device_vector(handle, n_samples); - - raft::copy( - isSampleCentroid.data_handle(), h_isSampleCentroid.data(), isSampleCentroid.size(), stream); - - rmm::device_uvector centroidsBuf(initialCentroid.size(), stream); - - // reset buffer to store the chosen centroid - raft::copy(centroidsBuf.data(), initialCentroid.data_handle(), initialCentroid.size(), stream); - - auto potentialCentroids = raft::make_device_matrix_view( - centroidsBuf.data(), initialCentroid.extent(0), initialCentroid.extent(1)); - // <<< End of Step-1 >>> - - // temporary buffer to store L2 norm of centroids or distance matrix, - // destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - auto minClusterDistanceVec = raft::make_device_vector(handle, n_samples); - auto uniformRands = raft::make_device_vector(handle, n_samples); - rmm::device_scalar clusterCost(stream); - - // <<< Step-2 >>>: psi <- phi_X (C) - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // compute partial cluster cost from the samples in rank - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - auto psi = clusterCost.value(stream); - - // <<< End of Step-2 >>> - - // Scalable kmeans++ paper claims 8 rounds is sufficient - resource::sync_stream(handle, stream); - int niter = std::min(8, (int)ceil(log(psi))); - RAFT_LOG_DEBUG("KMeans||: psi = %g, log(psi) = %g, niter = %d ", psi, log(psi), niter); - - // <<<< Step-3 >>> : for O( log(psi) ) times do - for (int iter = 0; iter < niter; ++iter) { - RAFT_LOG_DEBUG("KMeans|| - Iteration %d: # potential centroids sampled - %d", - iter, - potentialCentroids.extent(0)); - - detail::minClusterDistanceCompute(handle, - X, - potentialCentroids, - minClusterDistanceVec.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - detail::computeClusterCost(handle, - minClusterDistanceVec.view(), - workspace, - raft::make_device_scalar_view(clusterCost.data()), - raft::identity_op{}, - raft::add_op{}); - - psi = clusterCost.value(stream); - - // <<<< Step-4 >>> : Sample each point x in X independently and identify new - // potentialCentroids - raft::random::uniform( - handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); - - detail::SamplingOp select_op(psi, - params.oversampling_factor, - n_clusters, - uniformRands.data_handle(), - isSampleCentroid.data_handle()); - - rmm::device_uvector CpRaw(0, stream); - detail::sampleCentroids(handle, - X, - minClusterDistanceVec.view(), - isSampleCentroid.view(), - select_op, - CpRaw, - workspace); - auto Cp = raft::make_device_matrix_view( - CpRaw.data(), CpRaw.size() / n_features, n_features); - /// <<<< End of Step-4 >>>> - - /// <<<< Step-5 >>> : C = C U C' - // append the data in Cp to the buffer holding the potentialCentroids - centroidsBuf.resize(centroidsBuf.size() + Cp.size(), stream); - raft::copy( - centroidsBuf.data() + centroidsBuf.size() - Cp.size(), Cp.data_handle(), Cp.size(), stream); - - IndexT tot_centroids = potentialCentroids.extent(0) + Cp.extent(0); - potentialCentroids = - raft::make_device_matrix_view(centroidsBuf.data(), tot_centroids, n_features); - /// <<<< End of Step-5 >>> - } /// <<<< Step-6 >>> - - RAFT_LOG_DEBUG("KMeans||: total # potential centroids sampled - %d", - potentialCentroids.extent(0)); - - if ((int)potentialCentroids.extent(0) > n_clusters) { - // <<< Step-7 >>>: For x in C, set w_x to be the number of pts closest to X - // temporary buffer to store the sample count per cluster, destructor - // releases the resource - auto weight = raft::make_device_vector(handle, potentialCentroids.extent(0)); - - detail::countSamplesInCluster( - handle, params, X, L2NormX.view(), potentialCentroids, workspace, weight.view()); - - // <<< end of Step-7 >>> - - // Step-8: Recluster the weighted points in C into k clusters - detail::kmeansPlusPlus( - handle, params, potentialCentroids, centroidsRawData, workspace); - - auto inertia = make_host_scalar(0); - auto n_iter = make_host_scalar(0); - KMeansParams default_params; - default_params.n_clusters = params.n_clusters; - - detail::kmeans_fit_main(handle, - default_params, - potentialCentroids, - weight.view(), - centroidsRawData, - inertia.view(), - n_iter.view(), - workspace); - - } else if ((int)potentialCentroids.extent(0) < n_clusters) { - // supplement with random - auto n_random_clusters = n_clusters - potentialCentroids.extent(0); - - RAFT_LOG_DEBUG( - "[Warning!] KMeans||: found fewer than %d centroids during " - "initialization (found %d centroids, remaining %d centroids will be " - "chosen randomly from input samples)", - n_clusters, - potentialCentroids.extent(0), - n_random_clusters); - - // generate `n_random_clusters` centroids - KMeansParams rand_params; - rand_params.init = KMeansParams::InitMethod::Random; - rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); - - // copy centroids generated during kmeans|| iteration to the buffer - raft::copy(centroidsRawData.data_handle() + n_random_clusters * n_features, - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } else { - // found the required n_clusters - raft::copy(centroidsRawData.data_handle(), - potentialCentroids.data_handle(), - potentialCentroids.size(), - stream); - } -} - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. It must be noted - * that the data must be in row-major format and stored in device accessible - * location. - * @param[in] n_samples Number of samples in the input X. - * @param[in] n_features Number of features or the dimensions of each - * sample. - * @param[in] sample_weight Optional weights for each observation in X. - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] Otherwise, generated centroids from the - * kmeans algorithm is stored at the address pointed by 'centroids'. - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - common::nvtx::range fun_scope("kmeans_fit"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - // Display a message if the batch size is smaller than n_samples but will be ignored - if (params.batch_samples < (int)n_samples && - (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_samples=%d was passed, but batch_samples=%d will be used (reason: " - "batch_samples has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_samples, - (int)n_samples); - } - // Display a message if batch_centroids is set and a fusedL2NN-compatible metric is used - if (params.batch_centroids != 0 && params.batch_centroids != params.n_clusters && - (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded)) { - RAFT_LOG_DEBUG( - "batch_centroids=%d was passed, but batch_centroids=%d will be used (reason: " - "batch_centroids has no impact on the memory footprint when FusedL2NN can be used)", - params.batch_centroids, - params.n_clusters); - } - - logger::get(RAFT_NAME).set_level(params.verbosity); - - // Allocate memory - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - checkWeight(handle, weight.view(), workspace); - - auto centroidsRawData = raft::make_device_matrix(handle, n_clusters, n_features); - - auto n_init = params.n_init; - if (params.init == KMeansParams::InitMethod::Array && n_init != 1) { - RAFT_LOG_DEBUG( - "Explicit initial center position passed: performing only one init in " - "k-means instead of n_init=%d", - n_init); - n_init = 1; - } - - std::mt19937 gen(params.rng_state.seed); - inertia[0] = std::numeric_limits::max(); - - for (auto seed_iter = 0; seed_iter < n_init; ++seed_iter) { - KMeansParams iter_params = params; - iter_params.rng_state.seed = gen(); - - DataT iter_inertia = std::numeric_limits::max(); - IndexT n_current_iter = 0; - if (iter_params.init == KMeansParams::InitMethod::Random) { - // initializing with random samples from input dataset - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers by " - "randomly choosing from the " - "input data.", - seed_iter + 1, - n_init); - initRandom(handle, iter_params, X, centroidsRawData.view()); - } else if (iter_params.init == KMeansParams::InitMethod::KMeansPlusPlus) { - // default method to initialize is kmeans++ - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers using " - "k-means++ algorithm.", - seed_iter + 1, - n_init); - if (iter_params.oversampling_factor == 0) - detail::kmeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - else - detail::initScalableKMeansPlusPlus( - handle, iter_params, X, centroidsRawData.view(), workspace); - } else if (iter_params.init == KMeansParams::InitMethod::Array) { - RAFT_LOG_DEBUG( - "KMeans.fit (Iteration-%d/%d): initialize cluster centers from " - "the ndarray array input " - "passed to init argument.", - seed_iter + 1, - n_init); - raft::copy( - centroidsRawData.data_handle(), centroids.data_handle(), n_clusters * n_features, stream); - } else { - THROW("unknown initialization method to select initial centers"); - } - - detail::kmeans_fit_main(handle, - iter_params, - X, - weight.view(), - centroidsRawData.view(), - raft::make_host_scalar_view(&iter_inertia), - raft::make_host_scalar_view(&n_current_iter), - workspace); - if (iter_inertia < inertia[0]) { - inertia[0] = iter_inertia; - n_iter[0] = n_current_iter; - raft::copy( - centroids.data_handle(), centroidsRawData.data_handle(), n_clusters * n_features, stream); - } - RAFT_LOG_DEBUG("KMeans.fit after iteration-%d/%d: inertia - %f, n_iter[0] - %d", - seed_iter + 1, - n_init, - inertia[0], - n_iter[0]); - } - RAFT_LOG_DEBUG("KMeans.fit: async call returned (fit could still be running on the device)"); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView = std::nullopt; - if (sample_weight) - sample_weightView = - raft::make_device_vector_view(sample_weight, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit( - handle, params, XView, sample_weightView, centroidsView, inertiaView, n_iterView); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - common::nvtx::range fun_scope("kmeans_predict"); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - cudaStream_t stream = resource::get_cuda_stream(handle); - // Check that parameters are valid - if (sample_weight.has_value()) - RAFT_EXPECTS(sample_weight.value().extent(0) == n_samples, - "invalid parameter (sample_weight!=n_samples)"); - RAFT_EXPECTS(params.n_clusters > 0, "invalid parameter (n_clusters<=0)"); - RAFT_EXPECTS(params.tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(params.oversampling_factor >= 0, "invalid parameter (oversampling_factor<0)"); - RAFT_EXPECTS((int)centroids.extent(0) == params.n_clusters, - "invalid parameter (centroids.extent(0) != n_clusters)"); - RAFT_EXPECTS(centroids.extent(1) == n_features, - "invalid parameter (centroids.extent(1) != n_features)"); - - logger::get(RAFT_NAME).set_level(params.verbosity); - auto metric = params.metric; - - // Allocate memory - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto weight = raft::make_device_vector(handle, n_samples); - if (sample_weight.has_value()) - raft::copy(weight.data_handle(), sample_weight.value().data_handle(), n_samples, stream); - else - thrust::fill(resource::get_thrust_policy(handle), - weight.data_handle(), - weight.data_handle() + weight.size(), - 1); - - // check if weights sum up to n_samples - if (normalize_weight) checkWeight(handle, weight.view(), workspace); - - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // L2 norm of X: ||x||^2 - auto L2NormX = raft::make_device_vector(handle, n_samples); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - raft::linalg::rowNorm(L2NormX.data_handle(), - X.data_handle(), - X.extent(1), - X.extent(0), - raft::linalg::L2Norm, - true, - stream); - } - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to a sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - auto l2normx_view = - raft::make_device_vector_view(L2NormX.data_handle(), n_samples); - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance.view(), - l2normx_view, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // calculate cluster cost phi_x(C) - rmm::device_scalar clusterCostD(stream); - // TODO: add different templates for InType of binaryOp to avoid thrust transform - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - weight.data_handle(), - minClusterAndDistance.data_handle(), - [=] __device__(const raft::KeyValuePair kvp, DataT wt) { - raft::KeyValuePair res; - res.value = kvp.value * wt; - res.key = kvp.key; - return res; - }); - - detail::computeClusterCost(handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - raft::value_op{}, - raft::add_op{}); - - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - labels.data_handle(), - raft::key_op{}); - - inertia[0] = clusterCostD.value(stream); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - - detail::kmeans_predict(handle, - params, - XView, - sample_weightView, - centroidsView, - labelsView, - normalize_weight, - inertiaView); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - common::nvtx::range fun_scope("kmeans_fit_predict"); - if (!centroids.has_value()) { - auto n_features = X.extent(1); - auto centroids_matrix = - raft::make_device_matrix(handle, params.n_clusters, n_features); - detail::kmeans_fit( - handle, params, X, sample_weight, centroids_matrix.view(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids_matrix.view(), labels, true, inertia); - } else { - detail::kmeans_fit( - handle, params, X, sample_weight, centroids.value(), inertia, n_iter); - detail::kmeans_predict( - handle, params, X, sample_weight, centroids.value(), labels, true, inertia); - } -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - std::optional> sample_weightView{std::nullopt}; - if (sample_weight) - sample_weightView.emplace( - raft::make_device_vector_view(sample_weight, n_samples)); - std::optional> centroidsView{std::nullopt}; - if (centroids) - centroidsView.emplace( - raft::make_device_matrix_view(centroids, params.n_clusters, n_features)); - auto labelsView = raft::make_device_vector_view(labels, n_samples); - auto inertiaView = raft::make_host_scalar_view(&inertia); - auto n_iterView = raft::make_host_scalar_view(&n_iter); - - detail::kmeans_fit_predict( - handle, params, XView, sample_weightView, centroidsView, labelsView, inertiaView, n_iterView); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @param[in] handle The handle to the cuML library context that - * manages the CUDA resources. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * @param[out] X_new X transformed in the new space.. - */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - common::nvtx::range fun_scope("kmeans_transform"); - logger::get(RAFT_NAME).set_level(params.verbosity); - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = params.n_clusters; - auto metric = params.metric; - - // Device-accessible allocation of expandable storage used as temporary buffers - rmm::device_uvector workspace(0, stream); - auto dataBatchSize = getDataBatchSize(params.batch_samples, n_samples); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < (IndexT)n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min(static_cast(dataBatchSize), static_cast(n_samples - dIdx)); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + n_features * dIdx, ns, n_features); - - // pairwiseDistanceView [ns x n_clusters] - auto pairwiseDistanceView = raft::make_device_matrix_view( - X_new.data_handle() + n_clusters * dIdx, ns, n_clusters); - - // calculate pairwise distance between cluster centroids and current batch - // of input dataset - pairwise_distance_kmeans( - handle, datasetView, centroids, pairwiseDistanceView, workspace, metric); - } -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - auto XView = raft::make_device_matrix_view(X, n_samples, n_features); - auto centroidsView = - raft::make_device_matrix_view(centroids, params.n_clusters, n_features); - auto X_newView = raft::make_device_matrix_view(X_new, n_samples, n_features); - - detail::kmeans_transform(handle, params, XView, centroidsView, X_newView); -} -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh b/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh deleted file mode 100644 index 97755351c4..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_auto_find_k.cuh +++ /dev/null @@ -1,230 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -template -void compute_dispersion(raft::resources const& handle, - raft::device_matrix_view X, - KMeansParams& params, - raft::device_matrix_view centroids_view, - raft::device_vector_view labels, - raft::device_vector_view clusterSizes, - rmm::device_uvector& workspace, - raft::host_vector_view clusterDispertionView, - raft::host_vector_view resultsView, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - int val, - idx_t n, - idx_t d) -{ - auto centroids_const_view = - raft::make_device_matrix_view(centroids_view.data_handle(), val, d); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - auto cluster_sizes_view = - raft::make_device_vector_view(clusterSizes_ptr, val); - - params.n_clusters = val; - - raft::cluster::detail::kmeans_fit_predict( - handle, params, X, std::nullopt, std::make_optional(centroids_view), labels, residual, n_iter); - - detail::countLabels(handle, labels.data_handle(), clusterSizes.data_handle(), n, val, workspace); - - resultsView[val] = residual[0]; - clusterDispertionView[val] = raft::stats::cluster_dispersion( - handle, centroids_const_view, cluster_sizes_view, std::nullopt, n); -} - -template -void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view residual, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-2) -{ - idx_t n = X.extent(0); - idx_t d = X.extent(1); - - RAFT_EXPECTS(n >= 1, "n must be >= 1"); - RAFT_EXPECTS(d >= 1, "d must be >= 1"); - RAFT_EXPECTS(kmin >= 1, "kmin must be >= 1"); - RAFT_EXPECTS(kmax <= n, "kmax must be <= number of data samples in X"); - RAFT_EXPECTS(tol >= 0, "tolerance must be >= 0"); - RAFT_EXPECTS(maxiter >= 0, "maxiter must be >= 0"); - // Allocate memory - // Device memory - - auto centroids = raft::make_device_matrix(handle, kmax, X.extent(1)); - auto clusterSizes = raft::make_device_vector(handle, kmax); - auto labels = raft::make_device_vector(handle, n); - - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - idx_t* clusterSizes_ptr = clusterSizes.data_handle(); - - // Host memory - auto results = raft::make_host_vector(kmax + 1); - auto clusterDispersion = raft::make_host_vector(kmax + 1); - - auto clusterDispertionView = clusterDispersion.view(); - auto resultsView = results.view(); - - // Loop to find *best* k - // Perform k-means in binary search - int left = kmin; // must be at least 2 - int right = kmax; // int(floor(len(data)/2)) #assumption of clusters of size 2 at least - int mid = ((unsigned int)left + (unsigned int)right) >> 1; - int oldmid = mid; - int tests = 0; - double objective[3]; // 0= left of mid, 1= right of mid - if (left == 1) left = 2; // at least do 2 clusters - - KMeansParams params; - params.max_iter = maxiter; - params.tol = tol; - - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), left, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - left, - n, - d); - - // eval right edge0 - resultsView[right] = 1e20; - while (resultsView[right] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), right, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - right, - n, - d); - - tests += 1; - } - - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - while (left < right - 1) { - resultsView[mid] = 1e20; - tests = 0; - while (resultsView[mid] > resultsView[left] && tests < 3) { - centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), mid, d); - compute_dispersion(handle, - X, - params, - centroids_view, - labels.view(), - clusterSizes.view(), - workspace, - clusterDispertionView, - resultsView, - residual, - n_iter, - mid, - n, - d); - - if (resultsView[mid] > resultsView[left] && (mid + 1) < right) { - mid += 1; - resultsView[mid] = 1e20; - } else if (resultsView[mid] > resultsView[left] && (mid - 1) > left) { - mid -= 1; - resultsView[mid] = 1e20; - } - tests += 1; - } - - // maximize Calinski-Harabasz Index, minimize resid/ cluster - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - right) / (right - 1) * clusterDispertionView[right] / resultsView[right]; - objective[2] = (n - mid) / (mid - 1) * clusterDispertionView[mid] / resultsView[mid]; - objective[0] = (objective[2] - objective[0]) / (mid - left); - objective[1] = (objective[1] - objective[2]) / (right - mid); - - if (objective[0] > 0 && objective[1] < 0) { - // our point is in the left-of-mid side - right = mid; - } else { - left = mid; - } - oldmid = mid; - mid = ((unsigned int)right + (unsigned int)left) >> 1; - } - - best_k[0] = right; - objective[0] = (n - left) / (left - 1) * clusterDispertionView[left] / resultsView[left]; - objective[1] = (n - oldmid) / (oldmid - 1) * clusterDispertionView[oldmid] / resultsView[oldmid]; - if (objective[1] < objective[0]) { best_k[0] = left; } - - // if best_k isn't what we just ran, re-run to get correct centroids and dist data on return-> - // this saves memory - if (best_k[0] != oldmid) { - auto centroids_view = - raft::make_device_matrix_view(centroids.data_handle(), best_k[0], d); - - params.n_clusters = best_k[0]; - raft::cluster::detail::kmeans_fit_predict(handle, - params, - X, - std::nullopt, - std::make_optional(centroids_view), - labels.view(), - residual, - n_iter); - } -} -} // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh deleted file mode 100644 index 0a5a3ba5aa..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh +++ /dev/null @@ -1,1089 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include - -namespace raft::cluster::detail { - -constexpr static inline float kAdjustCentersWeight = 7.0f; - -/** - * @brief Predict labels for the dataset; floating-point types only. - * - * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows - * * n_cluster * sizeof(MathT)). - * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * - * @param[in] handle The raft handle. - * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[inout] mr (optional) Memory resource to use for temporary allocations - */ -template -inline std::enable_if_t> predict_core( - const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const MathT* dataset, - const MathT* dataset_norm, - IdxT n_rows, - LabelT* labels, - rmm::device_async_resource_ref mr) -{ - auto stream = resource::get_cuda_stream(handle); - switch (params.metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: { - auto workspace = raft::make_device_mdarray( - handle, mr, make_extents((sizeof(int)) * n_rows)); - - auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( - handle, mr, make_extents(n_rows)); - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - thrust::fill(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - auto centroidsNorm = - raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); - raft::linalg::rowNorm( - centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); - - raft::distance::fusedL2NNMinReduce, IdxT>( - minClusterAndDistance.data_handle(), - dataset, - centers, - dataset_norm, - centroidsNorm.data_handle(), - n_rows, - n_clusters, - dim, - (void*)workspace.data_handle(), - (params.metric == raft::distance::DistanceType::L2Expanded) ? false : true, - false, - stream); - - // todo(lsugy): use KVP + iterator in caller. - // Copy keys to output labels - thrust::transform(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + n_rows, - labels, - raft::compose_op, raft::key_op>()); - break; - } - case raft::distance::DistanceType::InnerProduct: { - // TODO: pass buffer - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - - MathT alpha = -1.0; - MathT beta = 0.0; - - linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); - - auto distances_const_view = raft::make_device_matrix_view( - distances.data(), n_rows, n_clusters); - auto labels_view = raft::make_device_vector_view(labels, n_rows); - raft::matrix::argmin(handle, distances_const_view, labels_view); - break; - } - default: { - RAFT_FAIL("The chosen distance metric is not supported (%d)", int(params.metric)); - } - } -} - -/** - * @brief Suggest a minibatch size for kmeans prediction. - * - * This function is used as a heuristic to split the work over a large dataset - * to reduce the size of temporary memory allocations. - * - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * - * @param[in] n_clusters number of clusters in kmeans clustering - * @param[in] n_rows Number of samples in the dataset - * @param[in] dim Number of features in the dataset - * @param[in] metric Distance metric - * @param[in] needs_conversion Whether the data needs to be converted to MathT - * @return A suggested minibatch size and the expected memory cost per-row (in bytes) - */ -template -constexpr auto calc_minibatch_size(IdxT n_clusters, - IdxT n_rows, - IdxT dim, - raft::distance::DistanceType metric, - bool needs_conversion) -> std::tuple -{ - n_clusters = std::max(1, n_clusters); - - // Estimate memory needs per row (i.e element of the batch). - size_t mem_per_row = 0; - switch (metric) { - // fusedL2NN needs a mutex and a key-value pair for each row. - case distance::DistanceType::L2Expanded: - case distance::DistanceType::L2SqrtExpanded: { - mem_per_row += sizeof(int); - mem_per_row += sizeof(raft::KeyValuePair); - } break; - // Other metrics require storing a distance matrix. - default: { - mem_per_row += sizeof(MathT) * n_clusters; - } - } - - // If we need to convert to MathT, space required for the converted batch. - if (!needs_conversion) { mem_per_row += sizeof(MathT) * dim; } - - // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. - IdxT minibatch_size = (1 << 30) / mem_per_row; - minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); - minibatch_size = std::min(minibatch_size, n_rows); - return std::make_tuple(minibatch_size, mem_per_row); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * @note all pointers must be accessible on the device. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. - * @param[inout] centers Pointer to the output [n_clusters, dim] - * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number of samples in the `dataset` - * @param[in] labels Output predictions [n_rows] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing centers and sizes using - * the weighted average principle. - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device - */ -template -void calc_centers_and_sizes(const raft::resources& handle, - MathT* centers, - CounterT* cluster_sizes, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - bool reset_counters, - MappingOpT mapping_op, - rmm::device_async_resource_ref mr) -{ - auto stream = resource::get_cuda_stream(handle); - - if (!reset_counters) { - raft::linalg::matrixVectorOp( - centers, centers, cluster_sizes, dim, n_clusters, true, false, raft::mul_op(), stream); - } - - rmm::device_uvector workspace(0, stream, mr); - - // If we reset the counters, we can compute directly the new sizes in cluster_sizes. - // If we don't reset, we compute in a temporary buffer and add in a separate step. - rmm::device_uvector temp_cluster_sizes(0, stream, mr); - CounterT* temp_sizes = cluster_sizes; - if (!reset_counters) { - temp_cluster_sizes.resize(n_clusters, stream); - temp_sizes = temp_cluster_sizes.data(); - } - - // Apply mapping only when the data and math types are different. - if constexpr (std::is_same_v) { - raft::linalg::reduce_rows_by_key( - dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } else { - // todo(lsugy): use iterator from KV output of fusedL2NN - cub::TransformInputIterator mapping_itr(dataset, mapping_op); - raft::linalg::reduce_rows_by_key( - mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); - } - - // Compute weight of each cluster - raft::cluster::detail::countLabels(handle, labels, temp_sizes, n_rows, n_clusters, workspace); - - // Add previous sizes if necessary - if (!reset_counters) { - raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); - } - - raft::linalg::matrixVectorOp(centers, - centers, - cluster_sizes, - dim, - n_clusters, - true, - false, - raft::div_checkzero_op(), - stream); -} - -/** Computes the L2 norm of the dataset, converting to MathT if necessary */ -template -void compute_norm(const raft::resources& handle, - MathT* dataset_norm, - const T* dataset, - IdxT dim, - IdxT n_rows, - MappingOpT mapping_op, - std::optional mr = std::nullopt) -{ - common::nvtx::range fun_scope("compute_norm"); - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector mapped_dataset( - 0, stream, mr.value_or(resource::get_workspace_resource(handle))); - - const MathT* dataset_ptr = nullptr; - - if (std::is_same_v) { - dataset_ptr = reinterpret_cast(dataset); - } else { - mapped_dataset.resize(n_rows * dim, stream); - - linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); - - dataset_ptr = static_cast(mapped_dataset.data()); - } - - raft::linalg::rowNorm( - dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); -} - -/** - * @brief Predict labels for the dataset. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] - * @param[in] n_clusters Number of clusters/centers - * @param[in] dim Dimensionality of the data - * @param[in] dataset Pointer to the data [n_rows, dim] - * @param[in] n_rows Number samples in the `dataset` - * @param[out] labels Output predictions [n_rows] - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] mr (optional) memory resource to use for temporary allocations - * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] - */ -template -void predict(const raft::resources& handle, - const kmeans_balanced_params& params, - const MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - LabelT* labels, - MappingOpT mapping_op, - std::optional mr = std::nullopt, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - common::nvtx::range fun_scope( - "predict(%zu, %u)", static_cast(n_rows), n_clusters); - auto mem_res = mr.value_or(resource::get_workspace_resource(handle)); - auto [max_minibatch_size, _mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - rmm::device_uvector cur_dataset( - std::is_same_v ? 0 : max_minibatch_size * dim, stream, mem_res); - bool need_compute_norm = - dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded); - rmm::device_uvector cur_dataset_norm( - need_compute_norm ? max_minibatch_size : 0, stream, mem_res); - const MathT* dataset_norm_ptr = nullptr; - auto cur_dataset_ptr = cur_dataset.data(); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - - if constexpr (std::is_same_v) { - cur_dataset_ptr = const_cast(dataset + offset * dim); - } else { - linalg::unaryOp( - cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); - } - - // Compute the norm now if it hasn't been pre-computed. - if (need_compute_norm) { - compute_norm( - handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res); - dataset_norm_ptr = cur_dataset_norm.data(); - } else if (dataset_norm != nullptr) { - dataset_norm_ptr = dataset_norm + offset; - } - - predict_core(handle, - params, - centers, - n_clusters, - dim, - cur_dataset_ptr, - dataset_norm_ptr, - minibatch_size, - labels + offset, - mem_res); - } -} - -template -__launch_bounds__((WarpSize * BlockDimY)) RAFT_KERNEL - adjust_centers_kernel(MathT* centers, // [n_clusters, dim] - IdxT n_clusters, - IdxT dim, - const T* dataset, // [n_rows, dim] - IdxT n_rows, - const LabelT* labels, // [n_rows] - const CounterT* cluster_sizes, // [n_clusters] - MathT threshold, - IdxT average, - IdxT seed, - IdxT* count, - MappingOpT mapping_op) -{ - IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); - if (l >= n_clusters) return; - auto csize = static_cast(cluster_sizes[l]); - // skip big clusters - if (csize > static_cast(average * threshold)) return; - - // choose a "random" i that belongs to a rather large cluster - IdxT i; - IdxT j = laneId(); - if (j == 0) { - do { - auto old = atomicAdd(count, IdxT{1}); - i = (seed * (old + 1)) % n_rows; - } while (static_cast(cluster_sizes[labels[i]]) < average); - } - i = raft::shfl(i, 0); - - // Adjust the center of the selected smaller cluster to gravitate towards - // a sample from the selected larger cluster. - const IdxT li = static_cast(labels[i]); - // Weight of the current center for the weighted average. - // We dump it for anomalously small clusters, but keep constant otherwise. - const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); - // Weight for the datapoint used to shift the center. - const MathT wd = 1.0; - for (; j < dim; j += WarpSize) { - MathT val = 0; - val += wc * centers[j + dim * li]; - val += wd * mapping_op(dataset[j + dim * i]); - val /= wc + wd; - centers[j + dim * l] = val; - } -} - -/** - * @brief Adjust centers for clusters that have small number of entries. - * - * For each cluster, where the cluster size is not bigger than a threshold, the center is moved - * towards a data point that belongs to a large cluster. - * - * NB: if this function returns `true`, you should update the labels. - * - * NB: all pointers must be on the device side. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[inout] centers cluster centers [n_clusters, dim] - * @param[in] n_clusters number of rows in `centers` - * @param[in] dim number of columns in `centers` and `dataset` - * @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim] - * @param[in] n_rows number of rows in `dataset` - * @param[in] labels a host pointer to the cluster indices [n_rows] - * @param[in] cluster_sizes number of rows in each cluster [n_clusters] - * @param[in] threshold defines a criterion for adjusting a cluster - * (cluster_sizes <= average_size * threshold) - * 0 <= threshold < 1 - * @param[in] mapping_op Mapping operation from T to MathT - * @param[in] stream CUDA stream - * @param[inout] device_memory memory resource to use for temporary allocations - * - * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). - */ -template -auto adjust_centers(MathT* centers, - IdxT n_clusters, - IdxT dim, - const T* dataset, - IdxT n_rows, - const LabelT* labels, - const CounterT* cluster_sizes, - MathT threshold, - MappingOpT mapping_op, - rmm::cuda_stream_view stream, - rmm::device_async_resource_ref device_memory) -> bool -{ - common::nvtx::range fun_scope( - "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); - if (n_clusters == 0) { return false; } - constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, - 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, - 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, - 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static IdxT i = 0; - static IdxT i_primes = 0; - - bool adjusted = false; - IdxT average = n_rows / n_clusters; - IdxT ofst; - do { - i_primes = (i_primes + 1) % kPrimes.size(); - ofst = kPrimes[i_primes]; - } while (n_rows % ofst == 0); - - constexpr uint32_t kBlockDimY = 4; - const dim3 block_dim(WarpSize, kBlockDimY, 1); - const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); - rmm::device_scalar update_count(0, stream, device_memory); - adjust_centers_kernel<<>>(centers, - n_clusters, - dim, - dataset, - n_rows, - labels, - cluster_sizes, - threshold, - average, - ofst, - update_count.data(), - mapping_op); - adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync - - return adjusted; -} - -/** - * @brief Expectation-maximization-balancing combined in an iterative process. - * - * Note, the `cluster_centers` is assumed to be already initialized here. - * Thus, this function can be used for fine-tuning existing clusters; - * to train from scratch, use `build_clusters` function below. - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam CounterT counter type supported by CUDA's native atomicAdd - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle - * @param[in] params Structure containing the hyper-parameters - * @param[in] n_iters Requested number of iterations (can differ from params.n_iter!) - * @param[in] dim Dimensionality of the dataset - * @param[in] dataset Pointer to a managed row-major array [n_rows, dim] - * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] - * @param[in] n_rows Number of rows in the dataset - * @param[in] n_cluster Requested number of clusters - * @param[inout] cluster_centers Pointer to a managed row-major array [n_clusters, dim] - * @param[out] cluster_labels Pointer to a managed row-major array [n_rows] - * @param[out] cluster_sizes Pointer to a managed row-major array [n_clusters] - * @param[in] balancing_pullback - * if the cluster centers are rebalanced on this number of iterations, - * one extra iteration is performed (this could happen several times) (default should be `2`). - * In other words, the first and then every `ballancing_pullback`-th rebalancing operation adds - * one more iteration to the main cycle. - * @param[in] balancing_threshold - * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` - * on a given iteration (default should be `~ 0.25`). - * @param[in] mapping_op Mapping operation from T to MathT - * @param[inout] device_memory - * A memory resource for device allocations (makes sense to provide a memory pool here) - */ -template -void balancing_em_iters(const raft::resources& handle, - const kmeans_balanced_params& params, - uint32_t n_iters, - IdxT dim, - const T* dataset, - const MathT* dataset_norm, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - uint32_t balancing_pullback, - MathT balancing_threshold, - MappingOpT mapping_op, - rmm::device_async_resource_ref device_memory) -{ - auto stream = resource::get_cuda_stream(handle); - uint32_t balancing_counter = balancing_pullback; - for (uint32_t iter = 0; iter < n_iters; iter++) { - // Balancing step - move the centers around to equalize cluster sizes - // (but not on the first iteration) - if (iter > 0 && adjust_centers(cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - cluster_sizes, - balancing_threshold, - mapping_op, - stream, - device_memory)) { - if (balancing_counter++ >= balancing_pullback) { - balancing_counter -= balancing_pullback; - n_iters++; - } - } - switch (params.metric) { - // For some metrics, cluster calculation and adjustment tends to favor zero center vectors. - // To avoid converging to zero, we normalize the center vectors on every iteration. - case raft::distance::DistanceType::InnerProduct: - case raft::distance::DistanceType::CosineExpanded: - case raft::distance::DistanceType::CorrelationExpanded: { - auto clusters_in_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - auto clusters_out_view = raft::make_device_matrix_view( - cluster_centers, n_clusters, dim); - raft::linalg::row_normalize( - handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm); - break; - } - default: break; - } - // E: Expectation step - predict labels - predict(handle, - params, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - mapping_op, - device_memory, - dataset_norm); - // M: Maximization step - calculate optimal cluster centers - calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - } -} - -/** Randomly initialize cluster centers and then call `balancing_em_iters`. */ -template -void build_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - IdxT n_clusters, - MathT* cluster_centers, - LabelT* cluster_labels, - CounterT* cluster_sizes, - MappingOpT mapping_op, - rmm::device_async_resource_ref device_memory, - const MathT* dataset_norm = nullptr) -{ - auto stream = resource::get_cuda_stream(handle); - - // "randomly" initialize labels - auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); - linalg::map_offset( - handle, - labels_view, - raft::compose_op(raft::cast_op(), raft::mod_const_op(n_clusters))); - - // update centers to match the initialized labels. - calc_centers_and_sizes(handle, - cluster_centers, - cluster_sizes, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - true, - mapping_op, - device_memory); - - // run EM - balancing_em_iters(handle, - params, - params.n_iters, - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - cluster_labels, - cluster_sizes, - 2, - MathT{0.25}, - mapping_op, - device_memory); -} - -/** Calculate how many fine clusters should belong to each mesocluster. */ -template -inline auto arrange_fine_clusters(IdxT n_clusters, - IdxT n_mesoclusters, - IdxT n_rows, - const CounterT* mesocluster_sizes) -{ - std::vector fine_clusters_nums(n_mesoclusters); - std::vector fine_clusters_csum(n_mesoclusters + 1); - fine_clusters_csum[0] = 0; - - IdxT n_lists_rem = n_clusters; - IdxT n_nonempty_ms_rem = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - n_nonempty_ms_rem += mesocluster_sizes[i] > CounterT{0} ? 1 : 0; - } - IdxT n_rows_rem = n_rows; - CounterT mesocluster_size_sum = 0; - CounterT mesocluster_size_max = 0; - IdxT fine_clusters_nums_max = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - if (i < n_mesoclusters - 1) { - // Although the algorithm is meant to produce balanced clusters, when something - // goes wrong, we may get empty clusters (e.g. during development/debugging). - // The code below ensures a proportional arrangement of fine cluster numbers - // per mesocluster, even if some clusters are empty. - if (mesocluster_sizes[i] == 0) { - fine_clusters_nums[i] = 0; - } else { - n_nonempty_ms_rem--; - auto s = static_cast( - static_cast(n_lists_rem * mesocluster_sizes[i]) / n_rows_rem + .5); - s = std::min(s, n_lists_rem - n_nonempty_ms_rem); - fine_clusters_nums[i] = std::max(s, IdxT{1}); - } - } else { - fine_clusters_nums[i] = n_lists_rem; - } - n_lists_rem -= fine_clusters_nums[i]; - n_rows_rem -= mesocluster_sizes[i]; - mesocluster_size_max = max(mesocluster_size_max, mesocluster_sizes[i]); - mesocluster_size_sum += mesocluster_sizes[i]; - fine_clusters_nums_max = max(fine_clusters_nums_max, fine_clusters_nums[i]); - fine_clusters_csum[i + 1] = fine_clusters_csum[i] + fine_clusters_nums[i]; - } - - RAFT_EXPECTS(static_cast(mesocluster_size_sum) == n_rows, - "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", - static_cast(mesocluster_size_sum), - static_cast(n_rows)); - RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, - "fine cluster numbers do not add up (%zu) to the total number of clusters (%zu)", - static_cast(fine_clusters_csum[n_mesoclusters]), - static_cast(n_clusters)); - - return std::make_tuple(static_cast(mesocluster_size_max), - fine_clusters_nums_max, - std::move(fine_clusters_nums), - std::move(fine_clusters_csum)); -} - -/** - * Given the (coarse) mesoclusters and the distribution of fine clusters within them, - * build the fine clusters. - * - * Processing one mesocluster at a time: - * 1. Copy mesocluster data into a separate buffer - * 2. Predict fine cluster - * 3. Refince the fine cluster centers - * - * As a result, the fine clusters are what is returned by `build_hierarchical`; - * this function returns the total number of fine clusters, which can be checked to be - * the same as the requested number of clusters. - * - * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; - * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data - * is ignored and a warning is reported. - */ -template -auto build_fine_clusters(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset_mptr, - const MathT* dataset_norm_mptr, - const LabelT* labels_mptr, - IdxT n_rows, - const IdxT* fine_clusters_nums, - const IdxT* fine_clusters_csum, - const CounterT* mesocluster_sizes, - IdxT n_mesoclusters, - IdxT mesocluster_size_max, - IdxT fine_clusters_nums_max, - MathT* cluster_centers, - MappingOpT mapping_op, - rmm::device_async_resource_ref managed_memory, - rmm::device_async_resource_ref device_memory) -> IdxT -{ - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); - rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); - rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); - auto mc_trainset_ids = mc_trainset_ids_buf.data(); - auto mc_trainset = mc_trainset_buf.data(); - auto mc_trainset_norm = mc_trainset_norm_buf.data(); - - // label (cluster ID) of each vector - rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); - - rmm::device_uvector mc_trainset_ccenters( - fine_clusters_nums_max * dim, stream, device_memory); - // number of vectors in each cluster - rmm::device_uvector mc_trainset_csizes_tmp( - fine_clusters_nums_max, stream, device_memory); - - // Training clusters in each meso-cluster - IdxT n_clusters_done = 0; - for (IdxT i = 0; i < n_mesoclusters; i++) { - IdxT k = 0; - for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { - if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } - } - if (k != static_cast(mesocluster_sizes[i])) - RAFT_LOG_WARN("Incorrect mesocluster size at %d. %zu vs %zu", - static_cast(i), - static_cast(k), - static_cast(mesocluster_sizes[i])); - if (k == 0) { - RAFT_LOG_DEBUG("Empty cluster %d", i); - RAFT_EXPECTS(fine_clusters_nums[i] == 0, - "Number of fine clusters must be zero for the empty mesocluster (got %d)", - static_cast(fine_clusters_nums[i])); - continue; - } else { - RAFT_EXPECTS(fine_clusters_nums[i] > 0, - "Number of fine clusters must be non-zero for a non-empty mesocluster"); - } - - cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); - raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); - if (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - thrust::gather(resource::get_thrust_policy(handle), - mc_trainset_ids, - mc_trainset_ids + k, - dataset_norm_mptr, - mc_trainset_norm); - } - - build_clusters(handle, - params, - dim, - mc_trainset, - k, - fine_clusters_nums[i], - mc_trainset_ccenters.data(), - mc_trainset_labels.data(), - mc_trainset_csizes_tmp.data(), - mapping_op, - device_memory, - mc_trainset_norm); - - raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), - mc_trainset_ccenters.data(), - fine_clusters_nums[i] * dim, - stream); - resource::sync_stream(handle, stream); - n_clusters_done += fine_clusters_nums[i]; - } - return n_clusters_done; -} - -/** - * @brief Hierarchical balanced k-means - * - * @tparam T element type - * @tparam MathT type of the centroids and mapped data - * @tparam IdxT index type - * @tparam LabelT label type - * @tparam MappingOpT type of the mapping operation - * - * @param[in] handle The raft handle. - * @param[in] params Structure containing the hyper-parameters - * @param dim number of columns in `centers` and `dataset` - * @param[in] dataset a device pointer to the source dataset [n_rows, dim] - * @param n_rows number of rows in the input - * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] - * @param n_cluster - * @param metric the distance type - * @param mapping_op Mapping operation from T to MathT - * @param stream - */ -template -void build_hierarchical(const raft::resources& handle, - const kmeans_balanced_params& params, - IdxT dim, - const T* dataset, - IdxT n_rows, - MathT* cluster_centers, - IdxT n_clusters, - MappingOpT mapping_op) -{ - auto stream = resource::get_cuda_stream(handle); - using LabelT = uint32_t; - - common::nvtx::range fun_scope( - "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); - - IdxT n_mesoclusters = std::min(n_clusters, static_cast(std::sqrt(n_clusters) + 0.5)); - RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); - - // TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf. - rmm::mr::managed_memory_resource managed_memory; - rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle); - auto [max_minibatch_size, mem_per_row] = - calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); - - // Precompute the L2 norm of the dataset if relevant. - const MathT* dataset_norm = nullptr; - rmm::device_uvector dataset_norm_buf(0, stream, device_memory); - if (params.metric == raft::distance::DistanceType::L2Expanded || - params.metric == raft::distance::DistanceType::L2SqrtExpanded) { - dataset_norm_buf.resize(n_rows, stream); - for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { - IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); - compute_norm(handle, - dataset_norm_buf.data() + offset, - dataset + dim * offset, - dim, - minibatch_size, - mapping_op, - device_memory); - } - dataset_norm = (const MathT*)dataset_norm_buf.data(); - } - - /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively - * supported by atomicAdd: find a supported CounterT based on the IdxT. */ - typedef typename std::conditional_t - CounterT; - - // build coarse clusters (mesoclusters) - rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); - rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); - { - rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); - build_clusters(handle, - params, - dim, - dataset, - n_rows, - n_mesoclusters, - mesocluster_centers_buf.data(), - mesocluster_labels_buf.data(), - mesocluster_sizes_buf.data(), - mapping_op, - device_memory, - dataset_norm); - } - - auto mesocluster_sizes = mesocluster_sizes_buf.data(); - auto mesocluster_labels = mesocluster_labels_buf.data(); - - resource::sync_stream(handle, stream); - - // build fine clusters - auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = - arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); - - const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( - 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); - if (mesocluster_size_max > mesocluster_size_max_balanced) { - RAFT_LOG_WARN( - "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " - "At most %u points will be used for training within each mesocluster. " - "Consider increasing the number of training iterations `n_iters`.", - mesocluster_size_max, - mesocluster_size_max_balanced, - mesocluster_size_max_balanced); - RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); - RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); - mesocluster_size_max = mesocluster_size_max_balanced; - } - - auto n_clusters_done = build_fine_clusters(handle, - params, - dim, - dataset, - dataset_norm, - mesocluster_labels, - n_rows, - fine_clusters_nums.data(), - fine_clusters_csum.data(), - mesocluster_sizes, - n_mesoclusters, - mesocluster_size_max, - fine_clusters_nums_max, - cluster_centers, - mapping_op, - &managed_memory, - device_memory); - RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); - - rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); - - // Fine-tuning k-means for all clusters - // - // (*) Since the likely cluster centroids have been calculated hierarchically already, the number - // of iterations for fine-tuning kmeans for whole clusters should be reduced. However, there is a - // possibility that the clusters could be unbalanced here, in which case the actual number of - // iterations would be increased. - // - balancing_em_iters(handle, - params, - std::max(params.n_iters / 10, 2), - dim, - dataset, - dataset_norm, - n_rows, - n_clusters, - cluster_centers, - labels.data(), - cluster_sizes.data(), - 5, - MathT{0.2}, - mapping_op, - device_memory); -} - -} // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh deleted file mode 100644 index 8263aa4615..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ /dev/null @@ -1,663 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { - -template -struct SamplingOp { - DataT* rnd; - uint8_t* flag; - DataT cluster_cost; - double oversampling_factor; - IndexT n_clusters; - - CUB_RUNTIME_FUNCTION __forceinline__ - SamplingOp(DataT c, double l, IndexT k, DataT* rand, uint8_t* ptr) - : cluster_cost(c), oversampling_factor(l), n_clusters(k), rnd(rand), flag(ptr) - { - } - - __host__ __device__ __forceinline__ bool operator()( - const raft::KeyValuePair& a) const - { - DataT prob_threshold = (DataT)rnd[a.key]; - - DataT prob_x = ((oversampling_factor * n_clusters * a.value) / cluster_cost); - - return !flag[a.key] && (prob_x > prob_threshold); - } -}; - -template -struct KeyValueIndexOp { - __host__ __device__ __forceinline__ IndexT - operator()(const raft::KeyValuePair& a) const - { - return a.key; - } -}; - -// Computes the intensity histogram from a sequence of labels -template -void countLabels(raft::resources const& handle, - SampleIteratorT labels, - CounterT* count, - IndexT n_samples, - IndexT n_clusters, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - // CUB::DeviceHistogram requires a signed index type - typedef typename std::make_signed_t CubIndexT; - - CubIndexT num_levels = n_clusters + 1; - CubIndexT lower_level = 0; - CubIndexT upper_level = n_clusters; - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(workspace.data(), - temp_storage_bytes, - labels, - count, - num_levels, - lower_level, - upper_level, - static_cast(n_samples), - stream)); -} - -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(&wt_sum, wt_aggr.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - if (wt_sum != n_samples) { - RAFT_LOG_DEBUG( - "[Warning!] KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::unaryOp(weight.data_handle(), - weight.data_handle(), - n_samples, - raft::mul_const_op{scale}, - stream); - } -} - -template -IndexT getDataBatchSize(int batch_samples, IndexT n_samples) -{ - auto minVal = std::min(static_cast(batch_samples), n_samples); - return (minVal == 0) ? n_samples : minVal; -} - -template -IndexT getCentroidsBatchSize(int batch_centroids, IndexT n_local_clusters) -{ - auto minVal = std::min(static_cast(batch_centroids), n_local_clusters); - return (minVal == 0) ? n_local_clusters : minVal; -} - -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - MainOpT main_op, - ReductionOpT reduction_op) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - - cub::TransformInputIterator itr(minClusterDistance.data_handle(), - main_op); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(nullptr, - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Reduce(workspace.data(), - temp_storage_bytes, - itr, - clusterCost.data_handle(), - minClusterDistance.size(), - reduction_op, - OutputT(), - stream)); -} - -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_local_samples = X.extent(0); - auto n_features = X.extent(1); - - auto nSelected = raft::make_device_scalar(handle, 0); - cub::ArgIndexInputIterator ip_itr(minClusterDistance.data_handle()); - auto sampledMinClusterDistance = - raft::make_device_vector, IndexT>(handle, n_local_samples); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceSelect::If(nullptr, - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceSelect::If(workspace.data(), - temp_storage_bytes, - ip_itr, - sampledMinClusterDistance.data_handle(), - nSelected.data_handle(), - n_local_samples, - select_op, - stream)); - - IndexT nPtsSampledInRank = 0; - raft::copy(&nPtsSampledInRank, nSelected.data_handle(), 1, stream); - resource::sync_stream(handle, stream); - - uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); - thrust::for_each_n(resource::get_thrust_policy(handle), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - [=] __device__(raft::KeyValuePair val) { - rawPtr_isSampleCentroid[val.key] = 1; - }); - - inRankCp.resize(nPtsSampledInRank * n_features, stream); - - raft::matrix::gather((DataT*)X.data_handle(), - X.extent(1), - X.extent(0), - sampledMinClusterDistance.data_handle(), - nPtsSampledInRank, - inRankCp.data(), - raft::key_op{}, - stream); -} - -// calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', -// result will be stored in 'pairwiseDistance[n x k]' -template -void pairwise_distance_kmeans(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view pairwiseDistance, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric) -{ - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - ASSERT(X.extent(1) == centroids.extent(1), - "# features in dataset and centroids are different (must be same)"); - - raft::distance::pairwise_distance(handle, - X.data_handle(), - centroids.data_handle(), - pairwiseDistance.data_handle(), - n_samples, - n_clusters, - n_features, - workspace, - metric); -} - -// shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores -// in 'out' does not modify the input -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = in.extent(0); - auto n_features = in.extent(1); - - auto indices = raft::make_device_vector(handle, n_samples); - - // shuffle indices on device - raft::random::permute(indices.data_handle(), - nullptr, - nullptr, - (IndexT)in.extent(1), - (IndexT)in.extent(0), - true, - stream); - - raft::matrix::gather((DataT*)in.data_handle(), - in.extent(1), - in.extent(0), - indices.data_handle(), - static_cast(n_samples_to_gather), - out.data_handle(), - stream); -} - -// Calculates a pair for every sample in input 'X' where key is an -// index to an sample in 'centroids' (index of the nearest centroid) and 'value' -// is the distance between the sample and the 'centroid[key]' -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - raft::KeyValuePair initial_value(0, std::numeric_limits::max()); - - thrust::fill(resource::get_thrust_policy(handle), - minClusterAndDistance.data_handle(), - minClusterAndDistance.data_handle() + minClusterAndDistance.size(), - initial_value); - - // tile over the input dataset - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + (dIdx * n_features), ns, n_features); - - // minClusterAndDistanceView [ns x n_clusters] - auto minClusterAndDistanceView = - raft::make_device_vector_view, IndexT>( - minClusterAndDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(int)) * ns, stream); - - // todo(lsugy): remove cIdx - raft::distance::fusedL2NNMinReduce, IndexT>( - minClusterAndDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != raft::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + (cIdx * n_features), nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - // argmin reduction returning pair - // calculates the closest centroid and the distance to the closest - // centroid - raft::linalg::coalescedReduction( - minClusterAndDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - initial_value, - stream, - true, - [=] __device__(const DataT val, const IndexT i) { - raft::KeyValuePair pair; - pair.key = cIdx + i; - pair.value = val; - return pair; - }, - raft::argmin_op{}, - raft::identity_op{}); - } - } - } -} - -template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - bool is_fused = metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded; - auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); - auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); - - if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::rowNorm(L2NormBuf_OR_DistBuf.data(), - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - raft::linalg::L2Norm, - true, - stream); - } else { - L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); - } - - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer - auto pairwiseDistance = raft::make_device_matrix_view( - L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); - - thrust::fill(resource::get_thrust_policy(handle), - minClusterDistance.data_handle(), - minClusterDistance.data_handle() + minClusterDistance.size(), - std::numeric_limits::max()); - - // tile over the input data and calculate distance matrix [n_samples x - // n_clusters] - for (IndexT dIdx = 0; dIdx < n_samples; dIdx += dataBatchSize) { - // # of samples for the current batch - auto ns = std::min((IndexT)dataBatchSize, n_samples - dIdx); - - // datasetView [ns x n_features] - view representing the current batch of - // input dataset - auto datasetView = raft::make_device_matrix_view( - X.data_handle() + dIdx * n_features, ns, n_features); - - // minClusterDistanceView [ns x n_clusters] - auto minClusterDistanceView = - raft::make_device_vector_view(minClusterDistance.data_handle() + dIdx, ns); - - auto L2NormXView = - raft::make_device_vector_view(L2NormX.data_handle() + dIdx, ns); - - if (is_fused) { - workspace.resize((sizeof(IndexT)) * ns, stream); - - raft::distance::fusedL2NNMinReduce( - minClusterDistanceView.data_handle(), - datasetView.data_handle(), - centroids.data_handle(), - L2NormXView.data_handle(), - centroidsNorm.data_handle(), - ns, - n_clusters, - n_features, - (void*)workspace.data(), - metric != raft::distance::DistanceType::L2Expanded, - false, - stream); - } else { - // tile over the centroids - for (IndexT cIdx = 0; cIdx < n_clusters; cIdx += centroidsBatchSize) { - // # of centroids for the current batch - auto nc = std::min((IndexT)centroidsBatchSize, n_clusters - cIdx); - - // centroidsView [nc x n_features] - view representing the current batch - // of centroids - auto centroidsView = raft::make_device_matrix_view( - centroids.data_handle() + cIdx * n_features, nc, n_features); - - // pairwiseDistanceView [ns x nc] - view representing the pairwise - // distance for current batch - auto pairwiseDistanceView = - raft::make_device_matrix_view(pairwiseDistance.data_handle(), ns, nc); - - // calculate pairwise distance between current tile of cluster centroids - // and input dataset - pairwise_distance_kmeans( - handle, datasetView, centroidsView, pairwiseDistanceView, workspace, metric); - - raft::linalg::coalescedReduction(minClusterDistanceView.data_handle(), - pairwiseDistanceView.data_handle(), - pairwiseDistanceView.extent(1), - pairwiseDistanceView.extent(0), - std::numeric_limits::max(), - stream, - true, - raft::identity_op{}, - raft::min_op{}, - raft::identity_op{}); - } - } - } -} - -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - cudaStream_t stream = resource::get_cuda_stream(handle); - auto n_samples = X.extent(0); - auto n_features = X.extent(1); - auto n_clusters = centroids.extent(0); - - // stores (key, value) pair corresponding to each sample where - // - key is the index of nearest cluster - // - value is the distance to the nearest cluster - auto minClusterAndDistance = - raft::make_device_vector, IndexT>(handle, n_samples); - - // temporary buffer to store distance matrix, destructor releases the resource - rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); - - // computes minClusterAndDistance[0:n_samples) where minClusterAndDistance[i] - // is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - detail::minClusterAndDistanceCompute(handle, - X, - (raft::device_matrix_view)centroids, - minClusterAndDistance.view(), - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - // Using TransformInputIteratorT to dereference an array of raft::KeyValuePair - // and converting them to just return the Key to be used in reduce_rows_by_key - // prims - detail::KeyValueIndexOp conversion_op; - cub::TransformInputIterator, - raft::KeyValuePair*> - itr(minClusterAndDistance.data_handle(), conversion_op); - - // count # of samples in each cluster - countLabels(handle, - itr, - sampleCountInCluster.data_handle(), - (IndexT)n_samples, - (IndexT)n_clusters, - workspace); -} -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh deleted file mode 100644 index e89f5480e3..0000000000 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ /dev/null @@ -1,1001 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/** - * Note: This file is deprecated and will be removed in a future release - * Please use include/raft/cluster/kmeans.cuh instead - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace raft { -namespace cluster { -namespace detail { -// ========================================================= -// Useful grid settings -// ========================================================= - -constexpr unsigned int BLOCK_SIZE = 1024; -constexpr unsigned int WARP_SIZE = 32; -constexpr unsigned int BSIZE_DIV_WSIZE = (BLOCK_SIZE / WARP_SIZE); - -// ========================================================= -// CUDA kernels -// ========================================================= - -/** - * @brief Compute distances between observation vectors and centroids - * Block dimensions should be (warpSize, 1, - * blockSize/warpSize). Ideally, the grid is large enough so there - * are d threads in the x-direction, k threads in the y-direction, - * and n threads in the z-direction. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, d*n entries) Observation matrix. Matrix is - * stored column-major and each column is an observation - * vector. Matrix dimensions are d x n. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Output, n*k entries) Distance matrix. Matrix is - * stored column-major and the (i,j)-entry is the square of the - * Euclidean distance between the ith observation vector and jth - * centroid. Matrix dimensions are n x k. Entries must be - * initialized to zero. - */ -template -RAFT_KERNEL computeDistances(index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists) -{ - // Loop index - index_type_t i; - - // Block indices - index_type_t bidx; - // Global indices - index_type_t gidx, gidy, gidz; - - // Private memory - value_type_t centroid_private, dist_private; - - // Global x-index indicates index of vector entry - bidx = blockIdx.x; - while (bidx * blockDim.x < d) { - gidx = threadIdx.x + bidx * blockDim.x; - - // Global y-index indicates centroid - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Load centroid coordinate from global memory - centroid_private = (gidx < d) ? centroids[IDX(gidx, gidy, d)] : 0; - - // Global z-index indicates observation vector - gidz = threadIdx.z + blockIdx.z * blockDim.z; - while (gidz < n) { - // Load observation vector coordinate from global memory - dist_private = (gidx < d) ? obs[IDX(gidx, gidz, d)] : 0; - - // Compute contribution of current entry to distance - dist_private = centroid_private - dist_private; - dist_private = dist_private * dist_private; - - // Perform reduction on warp - for (i = WARP_SIZE / 2; i > 0; i /= 2) - dist_private += __shfl_down_sync(warp_full_mask(), dist_private, i, 2 * i); - - // Write result to global memory - if (threadIdx.x == 0) atomicAdd(dists + IDX(gidz, gidy, n), dist_private); - - // Move to another observation vector - gidz += blockDim.z * gridDim.z; - } - - // Move to another centroid - gidy += blockDim.y * gridDim.y; - } - - // Move to another vector entry - bidx += gridDim.x; - } -} - -/** - * @brief Find closest centroid to observation vectors. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param centroids (Input, d*k entries) Centroid matrix. Matrix is - * stored column-major and each column is a centroid. Matrix - * dimensions are d x k. - * @param dists (Input/output, n*k entries) Distance matrix. Matrix - * is stored column-major and the (i,j)-entry is the square of - * the Euclidean distance between the ith observation vector and - * jth centroid. Matrix dimensions are n x k. On exit, the first - * n entries give the square of the Euclidean distance between - * observation vectors and closest centroids. - * @param codes (Output, n entries) Cluster assignments. - * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. - */ -template -RAFT_KERNEL minDistances(index_type_t n, - index_type_t k, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) -{ - // Loop index - index_type_t i, j; - - // Current matrix entry - value_type_t dist_curr; - - // Smallest entry in row - value_type_t dist_min; - index_type_t code_min; - - // Each row in observation matrix is processed by a thread - i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - // Find minimum entry in row - code_min = 0; - dist_min = dists[IDX(i, 0, n)]; - for (j = 1; j < k; ++j) { - dist_curr = dists[IDX(i, j, n)]; - code_min = (dist_curr < dist_min) ? j : code_min; - dist_min = (dist_curr < dist_min) ? dist_curr : dist_min; - } - - // Transfer result to global memory - dists[i] = dist_min; - codes[i] = code_min; - - // Increment cluster sizes - atomicAdd(clusterSizes + code_min, 1); - - // Move to another row - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Check if newly computed distances are smaller than old distances. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param n Number of observation vectors. - * @param dists_old (Input/output, n entries) Distances between - * observation vectors and closest centroids. On exit, entries - * are replaced by entries in 'dists_new' if the corresponding - * observation vectors are closest to the new centroid. - * @param dists_new (Input, n entries) Distance between observation - * vectors and new centroid. - * @param codes_old (Input/output, n entries) Cluster - * assignments. On exit, entries are replaced with 'code_new' if - * the corresponding observation vectors are closest to the new - * centroid. - * @param code_new Index associated with new centroid. - */ -template -RAFT_KERNEL minDistances2(index_type_t n, - value_type_t* __restrict__ dists_old, - const value_type_t* __restrict__ dists_new, - index_type_t* __restrict__ codes_old, - index_type_t code_new) -{ - // Loop index - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - - // Distances - value_type_t dist_old_private; - value_type_t dist_new_private; - - // Each row is processed by a thread - while (i < n) { - // Get old and new distances - dist_old_private = dists_old[i]; - dist_new_private = dists_new[i]; - - // Update if new distance is smaller than old distance - if (dist_new_private < dist_old_private) { - dists_old[i] = dist_new_private; - codes_old[i] = code_new; - } - - // Move to another row - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Compute size of k-means clusters. - * Block and grid dimensions should be 1-dimensional. Ideally the - * grid is large enough so there are n threads. - * @tparam index_type_t the type of data used for indexing. - * @param n Number of observation vectors. - * @param k Number of clusters. - * @param codes (Input, n entries) Cluster assignments. - * @param clusterSizes (Output, k entries) Number of points in each - * cluster. Entries must be initialized to zero. - */ -template -RAFT_KERNEL computeClusterSizes(index_type_t n, - const index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes) -{ - index_type_t i = threadIdx.x + blockIdx.x * blockDim.x; - while (i < n) { - atomicAdd(clusterSizes + codes[i], 1); - i += blockDim.x * gridDim.x; - } -} - -/** - * @brief Divide rows of centroid matrix by cluster sizes. - * Divides the ith column of the sum matrix by the size of the ith - * cluster. If the sum matrix has been initialized so that the ith - * row is the sum of all observation vectors in the ith cluster, - * this kernel produces cluster centroids. The grid and block - * dimensions should be 2-dimensional. Ideally the grid is large - * enough so there are d threads in the x-direction and k threads - * in the y-direction. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param clusterSizes (Input, k entries) Number of points in each - * cluster. - * @param centroids (Input/output, d*k entries) Sum matrix. Matrix - * is stored column-major and matrix dimensions are d x k. The - * ith column is the sum of all observation vectors in the ith - * cluster. On exit, the matrix is the centroid matrix (each - * column is the mean position of a cluster). - */ -template -RAFT_KERNEL divideCentroids(index_type_t d, - index_type_t k, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids) -{ - // Global indices - index_type_t gidx, gidy; - - // Current cluster size - index_type_t clusterSize_private; - - // Observation vector is determined by global y-index - gidy = threadIdx.y + blockIdx.y * blockDim.y; - while (gidy < k) { - // Get cluster size from global memory - clusterSize_private = clusterSizes[gidy]; - - // Add vector entries to centroid matrix - // vector entris are determined by global x-index - gidx = threadIdx.x + blockIdx.x * blockDim.x; - while (gidx < d) { - centroids[IDX(gidx, gidy, d)] /= clusterSize_private; - gidx += blockDim.x * gridDim.x; - } - - // Move to another centroid - gidy += blockDim.y * gridDim.y; - } -} - -// ========================================================= -// Helper functions -// ========================================================= - -/** - * @brief Randomly choose new centroids. - * Centroid is randomly chosen with k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param rand Random number drawn uniformly from [0,1). - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are n x d. - * @param dists (Input, device memory, 2*n entries) Workspace. The - * first n entries should be the distance between observation - * vectors and the closest centroid. - * @param centroid (Output, device memory, d entries) Centroid - * coordinates. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int chooseNewCentroid(raft::resources const& handle, - index_type_t n, - index_type_t d, - value_type_t rand, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ dists, - value_type_t* __restrict__ centroid) -{ - // Cumulative sum of distances - value_type_t* distsCumSum = dists + n; - // Residual sum of squares - value_type_t distsSum{0}; - // Observation vector that is chosen as new centroid - index_type_t obsIndex; - - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Compute cumulative sum of distances - thrust::inclusive_scan(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - thrust::device_pointer_cast(distsCumSum)); - RAFT_CHECK_CUDA(stream); - RAFT_CUDA_TRY(cudaMemcpyAsync( - &distsSum, distsCumSum + n - 1, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - - // Randomly choose observation vector - // Probabilities are proportional to square of distance to closest - // centroid (see k-means++ algorithm) - // - // seg-faults due to Thrust bug - // on binary-search-like algorithms - // when run with stream dependent - // execution policies; fixed on Thrust GitHub - // hence replace w/ linear interpolation, - // until the Thrust issue gets resolved: - // - // obsIndex = (thrust::lower_bound( - // thrust_exec_policy, thrust::device_pointer_cast(distsCumSum), - // thrust::device_pointer_cast(distsCumSum + n), distsSum * rand) - - // thrust::device_pointer_cast(distsCumSum)); - // - // linear interpolation logic: - //{ - value_type_t minSum{0}; - RAFT_CUDA_TRY( - cudaMemcpyAsync(&minSum, distsCumSum, sizeof(value_type_t), cudaMemcpyDeviceToHost, stream)); - RAFT_CHECK_CUDA(stream); - - if (distsSum > minSum) { - value_type_t vIndex = static_cast(n - 1); - obsIndex = static_cast(vIndex * (distsSum * rand - minSum) / (distsSum - minSum)); - } else { - obsIndex = 0; - } - //} - - RAFT_CHECK_CUDA(stream); - obsIndex = std::max(obsIndex, static_cast(0)); - obsIndex = std::min(obsIndex, n - 1); - - // Record new centroid position - RAFT_CUDA_TRY(cudaMemcpyAsync(centroid, - obs + IDX(0, obsIndex, d), - d * sizeof(value_type_t), - cudaMemcpyDeviceToDevice, - stream)); - - return 0; -} - -/** - * @brief Choose initial cluster centroids for k-means algorithm. - * Centroids are randomly chosen with k-means++ algorithm - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param dists (Output, device memory, 2*n entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int initializeCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - value_type_t* __restrict__ centroids, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ dists, - unsigned long long seed) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Loop index - index_type_t i; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - constexpr unsigned grid_lower_bound{65535}; - - // ------------------------------------------------------- - // Implementation - // ------------------------------------------------------- - - // Initialize grid dimensions - dim3 blockDim_warp{WARP_SIZE, 1, BSIZE_DIV_WSIZE}; - - // CUDA grid dimensions - dim3 gridDim_warp{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound)}; - - // CUDA grid dimensions - dim3 gridDim_block{std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound), 1, 1}; - - // Assign observation vectors to code 0 - RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - - // Choose first centroid - thrust::fill(thrust_exec_policy, - thrust::device_pointer_cast(dists), - thrust::device_pointer_cast(dists + n), - 1); - RAFT_CHECK_CUDA(stream); - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids)) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from first centroid - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Choose remaining centroids - for (i = 1; i < k; ++i) { - // Choose ith centroid - if (chooseNewCentroid(handle, n, d, uniformDist(rng), obs, dists, centroids + IDX(0, i, d))) - WARNING("error in k-means++ (could not pick centroid)"); - - // Compute distances from ith centroid - RAFT_CUDA_TRY(cudaMemsetAsync(dists + n, 0, n * sizeof(value_type_t), stream)); - computeDistances<<>>( - n, d, 1, obs, centroids + IDX(0, i, d), dists + n); - RAFT_CHECK_CUDA(stream); - - // Recompute minimum distances - minDistances2<<>>(n, dists, dists + n, codes, i); - RAFT_CHECK_CUDA(stream); - } - - // Compute cluster sizes - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, k * sizeof(index_type_t), stream)); - computeClusterSizes<<>>(n, codes, clusterSizes); - RAFT_CHECK_CUDA(stream); - - return 0; -} - -/** - * @brief Find cluster centroids closest to observation vectors. - * Distance is measured with Euclidean norm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param centroids (Input, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param dists (Output, device memory, n*k entries) Workspace. On - * exit, the first n entries give the square of the Euclidean - * distance between observation vectors and the closest centroid. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares of assignment. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int assignCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const value_type_t* __restrict__ centroids, - value_type_t* __restrict__ dists, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* residual_host) -{ - auto stream = resource::get_cuda_stream(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Compute distance between centroids and observation vectors - RAFT_CUDA_TRY(cudaMemsetAsync(dists, 0, n * k * sizeof(value_type_t), stream)); - - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim; - constexpr unsigned grid_lower_bound{65535}; - gridDim.x = std::min(ceildiv(d, WARP_SIZE), grid_lower_bound); - gridDim.y = std::min(static_cast(k), grid_lower_bound); - gridDim.z = std::min(ceildiv(n, BSIZE_DIV_WSIZE), grid_lower_bound); - - computeDistances<<>>(n, d, k, obs, centroids, dists); - RAFT_CHECK_CUDA(stream); - - // Find centroid closest to each observation vector - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes, 0, k * sizeof(index_type_t), stream)); - blockDim.x = BLOCK_SIZE; - blockDim.y = 1; - blockDim.z = 1; - gridDim.x = std::min(ceildiv(n, BLOCK_SIZE), grid_lower_bound); - gridDim.y = 1; - gridDim.z = 1; - minDistances<<>>(n, k, dists, codes, clusterSizes); - RAFT_CHECK_CUDA(stream); - - // Compute residual sum of squares - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(dists), thrust::device_pointer_cast(dists + n)); - - return 0; -} - -/** - * @brief Update cluster centroids for k-means algorithm. - * All clusters are assumed to be non-empty. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Input, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Input, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param work (Output, device memory, n*d entries) Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @return Zero if successful. Otherwise non-zero. - */ -template -static int updateCentroids(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - const value_type_t* __restrict__ obs, - const index_type_t* __restrict__ codes, - const index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Useful constants - const value_type_t one = 1; - const value_type_t zero = 0; - - constexpr unsigned grid_lower_bound{65535}; - - auto stream = resource::get_cuda_stream(handle); - auto cublas_h = resource::get_cublas_handle(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Device memory - thrust::device_ptr obs_copy(work); - thrust::device_ptr codes_copy(work_int); - thrust::device_ptr rows(work_int + d * n); - - // Take transpose of observation matrix - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgeam(cublas_h, - CUBLAS_OP_T, - CUBLAS_OP_N, - n, - d, - &one, - obs, - d, - &zero, - (value_type_t*)NULL, - n, - thrust::raw_pointer_cast(obs_copy), - n, - stream)); - - // Cluster assigned to each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::modulus()); - RAFT_CHECK_CUDA(stream); - thrust::gather( - thrust_exec_policy, rows, rows + d * n, thrust::device_pointer_cast(codes), codes_copy); - RAFT_CHECK_CUDA(stream); - - // Row associated with each observation matrix entry - thrust::sequence(thrust_exec_policy, rows, rows + d * n); - RAFT_CHECK_CUDA(stream); - thrust::transform(thrust_exec_policy, - rows, - rows + d * n, - thrust::make_constant_iterator(n), - rows, - thrust::divides()); - RAFT_CHECK_CUDA(stream); - - // Sort and reduce to add observation vectors in same cluster - thrust::stable_sort_by_key(thrust_exec_policy, - codes_copy, - codes_copy + d * n, - make_zip_iterator(make_tuple(obs_copy, rows))); - RAFT_CHECK_CUDA(stream); - thrust::reduce_by_key(thrust_exec_policy, - rows, - rows + d * n, - obs_copy, - codes_copy, // Output to codes_copy is ignored - thrust::device_pointer_cast(centroids)); - RAFT_CHECK_CUDA(stream); - - // Divide sums by cluster size to get centroid matrix - // - // CUDA grid dimensions - dim3 blockDim{WARP_SIZE, BLOCK_SIZE / WARP_SIZE, 1}; - - // CUDA grid dimensions - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - std::min(ceildiv(k, BSIZE_DIV_WSIZE), grid_lower_bound), - 1}; - - divideCentroids<<>>(d, k, clusterSizes, centroids); - RAFT_CHECK_CUDA(stream); - - return 0; -} - -// ========================================================= -// k-means algorithm -// ========================================================= - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param clusterSizes (Output, device memory, k entries) Number of - * points in each cluster. - * @param centroids (Output, device memory, d*k entries) Centroid - * matrix. Matrix is stored column-major and each column is a - * centroid. Matrix dimensions are d x k. - * @param work (Output, device memory, n*max(k,d) entries) - * Workspace. - * @param work_int (Output, device memory, 2*d*n entries) - * Workspace. - * @param residual_host (Output, host memory, 1 entry) Residual sum - * of squares (sum of squares of distances between observation - * vectors and centroids). - * @param iters_host (Output, host memory, 1 entry) Number of - * k-means iterations. - * @param seed random seed to be used. - * @return error flag. - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - index_type_t* __restrict__ clusterSizes, - value_type_t* __restrict__ centroids, - value_type_t* __restrict__ work, - index_type_t* __restrict__ work_int, - value_type_t* residual_host, - index_type_t* iters_host, - unsigned long long seed) -{ - // ------------------------------------------------------- - // Variable declarations - // ------------------------------------------------------- - - // Current iteration - index_type_t iter; - - constexpr unsigned grid_lower_bound{65535}; - - // Residual sum of squares at previous iteration - value_type_t residualPrev = 0; - - // Random number generator - thrust::default_random_engine rng(seed); - thrust::uniform_real_distribution uniformDist(0, 1); - - // ------------------------------------------------------- - // Initialization - // ------------------------------------------------------- - - auto stream = resource::get_cuda_stream(handle); - auto cublas_h = resource::get_cublas_handle(handle); - auto thrust_exec_policy = resource::get_thrust_policy(handle); - - // Trivial cases - if (k == 1) { - RAFT_CUDA_TRY(cudaMemsetAsync(codes, 0, n * sizeof(index_type_t), stream)); - RAFT_CUDA_TRY( - cudaMemcpyAsync(clusterSizes, &n, sizeof(index_type_t), cudaMemcpyHostToDevice, stream)); - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not compute k-means centroids"); - - dim3 blockDim{WARP_SIZE, 1, BLOCK_SIZE / WARP_SIZE}; - - dim3 gridDim{std::min(ceildiv(d, WARP_SIZE), grid_lower_bound), - 1, - std::min(ceildiv(n, BLOCK_SIZE / WARP_SIZE), grid_lower_bound)}; - - RAFT_CUDA_TRY(cudaMemsetAsync(work, 0, n * k * sizeof(value_type_t), stream)); - computeDistances<<>>(n, d, 1, obs, centroids, work); - RAFT_CHECK_CUDA(stream); - *residual_host = thrust::reduce( - thrust_exec_policy, thrust::device_pointer_cast(work), thrust::device_pointer_cast(work + n)); - RAFT_CHECK_CUDA(stream); - return 0; - } - if (n <= k) { - thrust::sequence(thrust_exec_policy, - thrust::device_pointer_cast(codes), - thrust::device_pointer_cast(codes + n)); - RAFT_CHECK_CUDA(stream); - thrust::fill_n(thrust_exec_policy, thrust::device_pointer_cast(clusterSizes), n, 1); - RAFT_CHECK_CUDA(stream); - - if (n < k) - RAFT_CUDA_TRY(cudaMemsetAsync(clusterSizes + n, 0, (k - n) * sizeof(index_type_t), stream)); - RAFT_CUDA_TRY(cudaMemcpyAsync( - centroids, obs, d * n * sizeof(value_type_t), cudaMemcpyDeviceToDevice, stream)); - *residual_host = 0; - return 0; - } - - // Initialize cuBLAS - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY( - raft::linalg::detail::cublassetpointermode(cublas_h, CUBLAS_POINTER_MODE_HOST, stream)); - - // ------------------------------------------------------- - // k-means++ algorithm - // ------------------------------------------------------- - - // Choose initial cluster centroids - if (initializeCentroids(handle, n, d, k, obs, centroids, codes, clusterSizes, work, seed)) - WARNING("could not initialize k-means centroids"); - - // Apply k-means iteration until convergence - for (iter = 0; iter < maxiter; ++iter) { - // Update cluster centroids - if (updateCentroids(handle, n, d, k, obs, codes, clusterSizes, centroids, work, work_int)) - WARNING("could not update k-means centroids"); - - // Determine centroid closest to each observation - residualPrev = *residual_host; - if (assignCentroids(handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - - // Reinitialize empty clusters with new centroids - index_type_t emptyCentroid = (thrust::find(thrust_exec_policy, - thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - - // FIXME: emptyCentroid never reaches k (infinite loop) under certain - // conditions, such as if obs is corrupt (as seen as a result of a - // DataFrame column of NULL edge vals used to create the Graph) - while (emptyCentroid < k) { - if (chooseNewCentroid( - handle, n, d, uniformDist(rng), obs, work, centroids + IDX(0, emptyCentroid, d))) - WARNING("could not replace empty centroid"); - if (assignCentroids( - handle, n, d, k, obs, centroids, work, codes, clusterSizes, residual_host)) - WARNING("could not assign observation vectors to k-means clusters"); - emptyCentroid = (thrust::find(thrust_exec_policy, - thrust::device_pointer_cast(clusterSizes), - thrust::device_pointer_cast(clusterSizes + k), - 0) - - thrust::device_pointer_cast(clusterSizes)); - RAFT_CHECK_CUDA(stream); - } - - // Check for convergence - if (std::fabs(residualPrev - (*residual_host)) / n < tol) { - ++iter; - break; - } - } - - // Warning if k-means has failed to converge - if (std::fabs(residualPrev - (*residual_host)) / n >= tol) WARNING("k-means failed to converge"); - - *iters_host = iter; - return 0; -} - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - // Check that parameters are valid - RAFT_EXPECTS(n > 0, "invalid parameter (n<1)"); - RAFT_EXPECTS(d > 0, "invalid parameter (d<1)"); - RAFT_EXPECTS(k > 0, "invalid parameter (k<1)"); - RAFT_EXPECTS(tol > 0, "invalid parameter (tol<=0)"); - RAFT_EXPECTS(maxiter >= 0, "invalid parameter (maxiter<0)"); - - // Allocate memory - raft::spectral::matrix::vector_t clusterSizes(handle, k); - raft::spectral::matrix::vector_t centroids(handle, d * k); - raft::spectral::matrix::vector_t work(handle, n * std::max(k, d)); - raft::spectral::matrix::vector_t work_int(handle, 2 * d * n); - - // Perform k-means - return kmeans(handle, - n, - d, - k, - tol, - maxiter, - obs, - codes, - clusterSizes.raw(), - centroids.raw(), - work.raw(), - work_int.raw(), - &residual, - &iters, - seed); -} - -} // namespace detail -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh deleted file mode 100644 index 55becc8e15..0000000000 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ /dev/null @@ -1,207 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace raft::cluster::detail { - -template -void merge_msts(sparse::solver::Graph_COO& coo1, - sparse::solver::Graph_COO& coo2, - cudaStream_t stream) -{ - /** Add edges to existing mst **/ - int final_nnz = coo2.n_edges + coo1.n_edges; - - coo1.src.resize(final_nnz, stream); - coo1.dst.resize(final_nnz, stream); - coo1.weights.resize(final_nnz, stream); - - /** - * Construct final edge list - */ - raft::copy_async(coo1.src.data() + coo1.n_edges, coo2.src.data(), coo2.n_edges, stream); - raft::copy_async(coo1.dst.data() + coo1.n_edges, coo2.dst.data(), coo2.n_edges, stream); - raft::copy_async(coo1.weights.data() + coo1.n_edges, coo2.weights.data(), coo2.n_edges, stream); - - coo1.n_edges = final_nnz; -} - -/** - * Connect an unconnected knn graph (one in which mst returns an msf). The - * device buffers underlying the Graph_COO object are modified in-place. - * @tparam value_idx index type - * @tparam value_t floating-point value type - * @param[in] handle raft handle - * @param[in] X original dense data from which knn grpah was constructed - * @param[inout] msf edge list containing the mst result - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[inout] color the color labels array returned from the mst invocation - * @return updated MST edge list - */ -template -void connect_knn_graph( - raft::resources const& handle, - const value_t* X, - sparse::solver::Graph_COO& msf, - size_t m, - size_t n, - value_idx* color, - red_op reduction_op, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) -{ - auto stream = resource::get_cuda_stream(handle); - - raft::sparse::COO connected_edges(stream); - - // default row and column batch sizes are chosen for computing cross component nearest neighbors. - // Reference: PR #1445 - static constexpr size_t default_row_batch_size = 4096; - static constexpr size_t default_col_batch_size = 16; - - raft::sparse::neighbors::cross_component_nn(handle, - connected_edges, - X, - color, - m, - n, - reduction_op, - min(m, default_row_batch_size), - min(n, default_col_batch_size)); - - rmm::device_uvector indptr2(m + 1, stream); - raft::sparse::convert::sorted_coo_to_csr( - connected_edges.rows(), connected_edges.nnz, indptr2.data(), m + 1, stream); - - // On the second call, we hand the MST the original colors - // and the new set of edges and let it restart the optimization process - auto new_mst = - raft::sparse::solver::mst(handle, - indptr2.data(), - connected_edges.cols(), - connected_edges.vals(), - m, - connected_edges.nnz, - color, - stream, - false, - false); - - merge_msts(msf, new_mst, stream); -} - -/** - * Constructs an MST and sorts the resulting edges in ascending - * order by their weight. - * - * Hierarchical clustering heavily relies upon the ordering - * and vertices returned in the MST. If the result of the - * MST was actually a minimum-spanning forest, the CSR - * being passed into the MST is not connected. In such a - * case, this graph will be connected by performing a - * KNN across the components. - * @tparam value_idx - * @tparam value_t - * @param[in] handle raft handle - * @param[in] indptr CSR indptr of connectivities graph - * @param[in] indices CSR indices array of connectivities graph - * @param[in] pw_dists CSR weights array of connectivities graph - * @param[in] m number of rows in X / src vertices in connectivities graph - * @param[in] n number of columns in X - * @param[out] mst_src output src edges - * @param[out] mst_dst output dst edges - * @param[out] mst_weight output weights (distances) - * @param[in] max_iter maximum iterations to run knn graph connection. This - * argument is really just a safeguard against the potential for infinite loops. - */ -template -void build_sorted_mst( - raft::resources const& handle, - const value_t* X, - const value_idx* indptr, - const value_idx* indices, - const value_t* pw_dists, - size_t m, - size_t n, - value_idx* mst_src, - value_idx* mst_dst, - value_t* mst_weight, - value_idx* color, - size_t nnz, - red_op reduction_op, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded, - int max_iter = 10) -{ - auto stream = resource::get_cuda_stream(handle); - - // We want to have MST initialize colors on first call. - auto mst_coo = raft::sparse::solver::mst( - handle, indptr, indices, pw_dists, (value_idx)m, nnz, color, stream, false, true); - - int iters = 1; - int n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - - while (n_components > 1 && iters < max_iter) { - connect_knn_graph(handle, X, mst_coo, m, n, color, reduction_op); - - iters++; - - n_components = raft::sparse::neighbors::get_n_components(color, m, stream); - } - - /** - * The `max_iter` argument was introduced only to prevent the potential for an infinite loop. - * Ideally the log2(n) guarantees of the MST should be enough to connect KNN graphs with a - * massive number of data samples in very few iterations. If it does not, there are 3 likely - * reasons why (in order of their likelihood): - * 1. There is a bug in this code somewhere - * 2. Either the given KNN graph wasn't generated from X or the same metric is not being used - * to generate the 1-nn (currently only L2SqrtExpanded is supported). - * 3. max_iter was not large enough to connect the graph (less likely). - * - * Note that a KNN graph generated from 50 random isotropic balls (with significant overlap) - * was able to be connected in a single iteration. - */ - RAFT_EXPECTS(n_components == 1, - "KNN graph could not be connected in %d iterations. " - "Please verify that the input knn graph is generated from X " - "(and the same distance metric used)," - " or increase 'max_iter'", - max_iter); - - raft::sparse::op::coo_sort_by_weight( - mst_coo.src.data(), mst_coo.dst.data(), mst_coo.weights.data(), mst_coo.n_edges, stream); - - raft::copy_async(mst_src, mst_coo.src.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_dst, mst_coo.dst.data(), mst_coo.n_edges, stream); - raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); -} - -}; // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh deleted file mode 100644 index ccc6472684..0000000000 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::detail { - -static const size_t EMPTY = 0; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - ASSERT(n_clusters <= m, "n_clusters must be less than or equal to the number of data points"); - - auto stream = resource::get_cuda_stream(handle); - - rmm::device_uvector indptr(EMPTY, stream); - rmm::device_uvector indices(EMPTY, stream); - rmm::device_uvector pw_dists(EMPTY, stream); - - /** - * 1. Construct distance graph - */ - detail::get_distance_graph( - handle, X, m, n, metric, indptr, indices, pw_dists, c); - - rmm::device_uvector mst_rows(m - 1, stream); - rmm::device_uvector mst_cols(m - 1, stream); - rmm::device_uvector mst_data(m - 1, stream); - - /** - * 2. Construct MST, sorted by weights - */ - rmm::device_uvector color(m, stream); - raft::sparse::neighbors::FixConnectivitiesRedOp op(m); - detail::build_sorted_mst(handle, - X, - indptr.data(), - indices.data(), - pw_dists.data(), - m, - n, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - color.data(), - indices.size(), - op, - metric); - - pw_dists.release(); - - /** - * Perform hierarchical labeling - */ - size_t n_edges = mst_rows.size(); - - rmm::device_uvector out_delta(n_edges, stream); - rmm::device_uvector out_size(n_edges, stream); - // Create dendrogram - detail::build_dendrogram_host(handle, - mst_rows.data(), - mst_cols.data(), - mst_data.data(), - n_edges, - out->children, - out_delta.data(), - out_size.data()); - detail::extract_flattened_clusters(handle, out->labels, out->children, n_clusters, m); - - out->m = m; - out->n_clusters = n_clusters; - out->n_leaves = m; - out->n_connected_components = 1; -} -}; // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh deleted file mode 100644 index 38318e8ec8..0000000000 --- a/cpp/include/raft/cluster/kmeans.cuh +++ /dev/null @@ -1,1120 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::cluster::kmeans { - -/** - * Functor used for sampling centroids - */ -template -using SamplingOp = detail::SamplingOp; - -/** - * Functor used to extract the index from a KeyValue pair - * storing both index and a distance. - */ -template -using KeyValueIndexOp = detail::KeyValueIndexOp; - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids, - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -[[deprecated("Use cuVS instead")]] void fit( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * - * kmeans::fit(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * ... - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * false, - * labels.view(), - * raft::make_scalar_view(&ineratia)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -[[deprecated("Use cuVS instead")]] void predict( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - detail::kmeans_predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::cluster; - * ... - * raft::raft::resources handle; - * raft::cluster::KMeansParams params; - * int n_features = 15, inertia, n_iter; - * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, X.extent(0)); - * - * kmeans::fit_predict(handle, - * params, - * X, - * std::nullopt, - * centroids.view(), - * labels.view(), - * raft::make_scalar_view(&inertia), - * raft::make_scalar_view(&n_iter)); - * @endcode - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -[[deprecated("Use cuVS instead")]] void fit_predict( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - detail::kmeans_fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. - * [dim = n_samples x n_features] - */ -template -void transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - detail::kmeans_transform(handle, params, X, centroids, X_new); -} - -template -[[deprecated("Use cuVS instead")]] void transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - detail::kmeans_transform( - handle, params, X, centroids, n_samples, n_features, X_new); -} - -/** - * Automatically find the optimal value of k using a binary search. - * This method maximizes the Calinski-Harabasz Index while minimizing the per-cluster inertia. - * - * @code{.cpp} - * #include - * #include - * #include - * - * #include - * - * using namespace raft::cluster; - * - * raft::handle_t handle; - * int n_samples = 100, n_features = 15, n_clusters = 10; - * auto X = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * - * raft::random::make_blobs(handle, X, labels, n_clusters); - * - * auto best_k = raft::make_host_scalar(0); - * auto n_iter = raft::make_host_scalar(0); - * auto inertia = raft::make_host_scalar(0); - * - * kmeans::find_k(handle, X, best_k.view(), inertia.view(), n_iter.view(), n_clusters+1); - * - * @endcode - * - * @tparam idx_t indexing type (should be integral) - * @tparam value_t value type (should be floating point) - * @param handle raft handle - * @param X input observations (shape n_samples, n_dims) - * @param best_k best k found from binary search - * @param inertia inertia of best k found - * @param n_iter number of iterations used to find best k - * @param kmax maximum k to try in search - * @param kmin minimum k to try in search (should be >= 1) - * @param maxiter maximum number of iterations to run - * @param tol tolerance for early stopping convergence - */ -template -void find_k(raft::resources const& handle, - raft::device_matrix_view X, - raft::host_scalar_view best_k, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - idx_t kmax, - idx_t kmin = 1, - idx_t maxiter = 100, - value_t tol = 1e-3) -{ - detail::find_k(handle, X, best_k, inertia, n_iter, kmax, kmin, maxiter, tol); -} - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sample_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - detail::sampleCentroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. - * @tparam ReductionOpT the type of data used for the reduction operation. - * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void cluster_cost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - detail::computeClusterCost( - handle, minClusterDistance, workspace, clusterCost, raft::identity_op{}, reduction_op); -} - -/** - * @brief Update centroids given current centroids and number of points assigned to each centroid. - * This function also produces a vector of RAFT key/value pairs containing the cluster assignment - * for each point and its distance. - * - * @tparam DataT - * @tparam IndexT - * @param[in] handle: Raft handle to use for managing library resources - * @param[in] X: input matrix (size n_samples, n_features) - * @param[in] sample_weights: number of samples currently assigned to each centroid (size n_samples) - * @param[in] centroids: matrix of current centroids (size n_clusters, n_features) - * @param[in] labels: Iterator of labels (can also be a raw pointer) - * @param[out] weight_per_cluster: sum of sample weights per cluster (size n_clusters) - * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) - */ -template -void update_centroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - LabelsIterator labels, - raft::device_vector_view weight_per_cluster, - raft::device_matrix_view new_centroids) -{ - // TODO: Passing these into the algorithm doesn't really present much of a benefit - // because they are being resized anyways. - // ref https://github.com/rapidsai/raft/issues/930 - rmm::device_uvector workspace(0, resource::get_cuda_stream(handle)); - - detail::update_centroids( - handle, X, sample_weights, centroids, labels, weight_per_cluster, new_centroids, workspace); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric Distance metric to use - * @param[in] batch_samples batch size for input data samples - * @param[in] batch_centroids batch size for input centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_distance(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterDistanceCompute(handle, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] metric distance metric - * @param[in] batch_samples batch size of data samples - * @param[in] batch_centroids batch size of centroids - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void min_cluster_and_distance( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - raft::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) -{ - detail::minClusterAndDistanceCompute(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - metric, - batch_samples, - batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffle_and_gather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - detail::shuffleAndGather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void count_samples_in_cluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - detail::countSamplesInCluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/** - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - * - * @see "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void init_plus_plus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace) -{ - detail::kmeansPlusPlus(handle, params, X, centroids, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view sample_weights, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - detail::kmeans_fit_main( - handle, params, X, sample_weights, centroids, inertia, n_iter, workspace); -} - -}; // end namespace raft::cluster::kmeans - -namespace raft::cluster { - -/** - * Note: All of the functions below in raft::cluster are deprecated and will - * be removed in a future release. Please use raft::cluster::kmeans instead. - */ - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); -} - -template -void kmeans_fit(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit( - handle, params, X, sample_weight, centroids, n_samples, n_features, inertia, n_iter); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X New data to predict. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[in] centroids Cluster centroids. The data must be in - * row-major format. - * [dim = n_clusters x n_features] - * @param[in] normalize_weight True if the weights should be normalized - * @param[out] labels Index of the cluster each sample in X - * belongs to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to - * their closest cluster center. - */ -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - bool normalize_weight, - raft::host_scalar_view inertia) -{ - kmeans::predict( - handle, params, X, sample_weight, centroids, labels, normalize_weight, inertia); -} - -template -void kmeans_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - bool normalize_weight, - DataT& inertia) -{ - kmeans::predict(handle, - params, - X, - sample_weight, - centroids, - n_samples, - n_features, - labels, - normalize_weight, - inertia); -} - -/** - * @brief Compute k-means clustering and predicts cluster index for each sample - * in the input. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must be - * in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Optional weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids Optional - * [in] When init is InitMethod::Array, use - * centroids as the initial cluster centers - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] labels Index of the cluster each sample in X belongs - * to. - * [len = n_samples] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - */ -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - std::optional> sample_weight, - std::optional> centroids, - raft::device_vector_view labels, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, labels, inertia, n_iter); -} - -template -void kmeans_fit_predict(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* sample_weight, - DataT* centroids, - IndexT n_samples, - IndexT n_features, - IndexT* labels, - DataT& inertia, - IndexT& n_iter) -{ - kmeans::fit_predict( - handle, params, X, sample_weight, centroids, n_samples, n_features, labels, inertia, n_iter); -} - -/** - * @brief Transform X to a cluster-distance space. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Cluster centroids. The data must be in row-major format. - * [dim = n_clusters x n_features] - * @param[out] X_new X transformed in the new space. - * [dim = n_samples x n_features] - */ -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_matrix_view X_new) -{ - kmeans::transform(handle, params, X, centroids, X_new); -} - -template -void kmeans_transform(raft::resources const& handle, - const KMeansParams& params, - const DataT* X, - const DataT* centroids, - IndexT n_samples, - IndexT n_features, - DataT* X_new) -{ - kmeans::transform(handle, params, X, centroids, n_samples, n_features, X_new); -} - -template -using SamplingOp = kmeans::SamplingOp; - -template -using KeyValueIndexOp = kmeans::KeyValueIndexOp; - -/** - * @brief Select centroids according to a sampling operation - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] isSampleCentroid Flag the sample chosen as initial centroid - * [dim = n_samples] - * @param[in] select_op The sampling operation used to select the centroids - * @param[out] inRankCp The sampled centroids - * [dim = n_selected_centroids x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void sampleCentroids(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_vector_view minClusterDistance, - raft::device_vector_view isSampleCentroid, - SamplingOp& select_op, - rmm::device_uvector& inRankCp, - rmm::device_uvector& workspace) -{ - kmeans::sample_centroids( - handle, X, minClusterDistance, isSampleCentroid, select_op, inRankCp, workspace); -} - -/** - * @brief Compute cluster cost - * - * @tparam DataT the type of data used for weights, distances. - * @tparam ReductionOpT the type of data used for the reduction operation. - * - * @param[in] handle The raft handle - * @param[in] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] clusterCost Resulting cluster cost - * @param[in] reduction_op The reduction operation used for the cost - * - */ -template -void computeClusterCost(raft::resources const& handle, - raft::device_vector_view minClusterDistance, - rmm::device_uvector& workspace, - raft::device_scalar_view clusterCost, - ReductionOpT reduction_op) -{ - kmeans::cluster_cost(handle, minClusterDistance, workspace, clusterCost, reduction_op); -} - -/** - * @brief Compute distance for every sample to it's nearest centroid - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterDistance Distance for every sample to it's nearest centroid - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterDistanceCompute(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_distance(handle, - X, - centroids, - minClusterDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Calculates a pair for every sample in input 'X' where key is an - * index of one of the 'centroids' (index of the nearest centroid) and 'value' - * is the distance between the sample and the 'centroid[key]' - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[out] minClusterAndDistance Distance vector that contains for every sample, the nearest - * centroid and it's distance - * [dim = n_samples] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[out] L2NormBuf_OR_DistBuf Resizable buffer to store L2 norm of centroids or distance - * matrix - * @param[in] workspace Temporary workspace buffer which can get resized - * - */ -template -void minClusterAndDistanceCompute( - raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view, IndexT> minClusterAndDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - rmm::device_uvector& workspace) -{ - kmeans::min_cluster_and_distance(handle, - X, - centroids, - minClusterAndDistance, - L2NormX, - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); -} - -/** - * @brief Shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores - * in 'out' does not modify the input - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] in The data to shuffle and gather - * [dim = n_samples x n_features] - * @param[out] out The sampled data - * [dim = n_samples_to_gather x n_features] - * @param[in] n_samples_to_gather Number of sample to gather - * @param[in] seed Seed for the shuffle - * - */ -template -void shuffleAndGather(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out, - uint32_t n_samples_to_gather, - uint64_t seed) -{ - kmeans::shuffle_and_gather(handle, in, out, n_samples_to_gather, seed); -} - -/** - * @brief Count the number of samples in each cluster - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[in] L2NormX L2 norm of X : ||x||^2 - * [dim = n_samples] - * @param[in] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - * @param[out] sampleCountInCluster The count for each centroid - * [dim = n_cluster] - * - */ -template -void countSamplesInCluster(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view L2NormX, - raft::device_matrix_view centroids, - rmm::device_uvector& workspace, - raft::device_vector_view sampleCountInCluster) -{ - kmeans::count_samples_in_cluster( - handle, params, X, L2NormX, centroids, workspace, sampleCountInCluster); -} - -/* - * @brief Selects 'n_clusters' samples from the input X using kmeans++ algorithm. - - * @note This is the algorithm described in - * "k-means++: the advantages of careful seeding". 2007, Arthur, D. and Vassilvitskii, S. - * ACM-SIAM symposium on Discrete algorithms. - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle - * @param[in] params The parameters for KMeans - * @param[in] X The data in row-major format - * [dim = n_samples x n_features] - * @param[out] centroids Centroids data - * [dim = n_cluster x n_features] - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeansPlusPlus(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_matrix_view centroidsRawData, - rmm::device_uvector& workspace) -{ - kmeans::init_plus_plus(handle, params, X, centroidsRawData, workspace); -} - -/* - * @brief Main function used to fit KMeans (after cluster initialization) - * - * @tparam DataT the type of data used for weights, distances. - * @tparam IndexT the type of data used for indexing. - * - * @param[in] handle The raft handle. - * @param[in] params Parameters for KMeans model. - * @param[in] X Training instances to cluster. The data must - * be in row-major format. - * [dim = n_samples x n_features] - * @param[in] sample_weight Weights for each observation in X. - * [len = n_samples] - * @param[inout] centroids [in] Initial cluster centers. - * [out] The generated centroids from the - * kmeans algorithm are stored at the address - * pointed by 'centroids'. - * [dim = n_clusters x n_features] - * @param[out] inertia Sum of squared distances of samples to their - * closest cluster center. - * @param[out] n_iter Number of iterations run. - * @param[in] workspace Temporary workspace buffer which can get resized - */ -template -void kmeans_fit_main(raft::resources const& handle, - const KMeansParams& params, - raft::device_matrix_view X, - raft::device_vector_view weight, - raft::device_matrix_view centroidsRawData, - raft::host_scalar_view inertia, - raft::host_scalar_view n_iter, - rmm::device_uvector& workspace) -{ - kmeans::fit_main( - handle, params, X, weight, centroidsRawData, inertia, n_iter, workspace); -} -}; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh deleted file mode 100644 index 7479047fce..0000000000 --- a/cpp/include/raft/cluster/kmeans_balanced.cuh +++ /dev/null @@ -1,371 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include - -namespace raft::cluster::kmeans_balanced { - -/** - * @brief Find clusters of balanced sizes with a hierarchical k-means algorithm. - * - * This variant of the k-means algorithm first clusters the dataset in mesoclusters, then clusters - * the subsets associated to each mesocluster into fine clusters, and finally runs a few k-means - * iterations over the whole dataset and with all the centroids to obtain the final clusters. - * - * Each k-means iteration applies expectation-maximization-balancing: - * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a - * cluster is below a threshold, the center is moved towards a bigger cluster. - * - Expectation: predict the labels (i.e find closest cluster centroid to each point) - * - Maximization: calculate optimal centroids (i.e find the center of gravity of each cluster) - * - * The number of mesoclusters is chosen by rounding the square root of the number of clusters. E.g - * for 512 clusters, we would have 23 mesoclusters. The number of fine clusters per mesocluster is - * chosen proportionally to the number of points in each mesocluster. - * - * This variant of k-means uses random initialization and a fixed number of iterations, though - * iterations can be repeated if the balancing step moved the centroids. - * - * Additionally, this algorithm supports quantized datasets in arbitrary types but the core part of - * the algorithm will work with a floating-point type, hence a conversion function can be provided - * to map the data type to the math type. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The generated centroids [dim = n_clusters x n_features] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void fit(const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(centroids.extent(0) > IndexT{0} && centroids.extent(0) <= X.extent(0), - "The number of centroids must be strictly positive and cannot exceed the number of " - "points in the training dataset."); - - detail::build_hierarchical(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.data_handle(), - centroids.extent(0), - mapping_op); -} - -/** - * @brief Predict the closest cluster each sample in X belongs to. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto labels = raft::make_device_vector(handle, n_rows); - * raft::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Dataset for which to infer the closest clusters. - * [dim = n_samples x n_features] - * @param[in] centroids The input centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void predict( - const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= - static_cast(std::numeric_limits::max()), - "The chosen index type cannot represent all indices for the given dataset"); - RAFT_EXPECTS(static_cast(centroids.extent(0)) <= - static_cast(std::numeric_limits::max()), - "The chosen label type cannot represent all cluster labels"); - - detail::predict(handle, - params, - centroids.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - mapping_op); -} - -/** - * @brief Compute hierarchical balanced k-means clustering and predict cluster index for each sample - * in the input. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_rows); - * raft::cluster::kmeans_balanced::fit_predict( - * handle, params, X, centroids.view(), labels.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic - * datatype. If DataT and MathT are the same, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void fit_predict( - const raft::resources& handle, - kmeans_balanced_params const& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - MappingOpT mapping_op = raft::identity_op()) -{ - auto centroids_const = raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - raft::cluster::kmeans_balanced::fit(handle, params, X, centroids, mapping_op); - raft::cluster::kmeans_balanced::predict(handle, params, X, centroids_const, labels, mapping_op); -} - -namespace helpers { - -/** - * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations - * - * This is essentially the non-hierarchical balanced k-means algorithm which is used by the - * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine - * clusters. - * - * @code{.cpp} - * #include - * #include - * #include - * ... - * raft::handle_t handle; - * raft::cluster::kmeans_balanced_params params; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * raft::cluster::kmeans_balanced::build_clusters( - * handle, params, X, centroids.view(), labels.view(), sizes.view()); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] params Structure containing the hyper-parameters - * @param[in] X Training instances to cluster. The data must be in row-major format. - * [dim = n_samples x n_features] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] labels The output labels [dim = n_samples] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. - * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] - */ -template -[[deprecated("Use cuVS instead")]] void build_clusters( - const raft::resources& handle, - const kmeans_balanced_params& params, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view labels, - raft::device_vector_view cluster_sizes, - MappingOpT mapping_op = raft::identity_op(), - std::optional> X_norm = std::nullopt) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::build_clusters(handle, - params, - X.extent(1), - X.data_handle(), - X.extent(0), - centroids.extent(0), - centroids.data_handle(), - labels.data_handle(), - cluster_sizes.data_handle(), - mapping_op, - resource::get_workspace_resource(handle), - X_norm.has_value() ? X_norm.value().data_handle() : nullptr); -} - -/** - * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. - * - * Let `S_i = {x_k | x_k \in X & labels[k] == i}` be the vectors in the dataset with label i. - * - * On exit, - * `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`, - * where `w_i = reset_counters ? 0 : cluster_size[i]`. - * - * In other words, the updated cluster centers are a weighted average of the existing cluster - * center, and the coordinates of the points labeled with i. _This allows calling this function - * multiple times with different datasets with the same effect as if calling this function once - * on the combined dataset_. - * - * @code{.cpp} - * #include - * #include - * ... - * raft::handle_t handle; - * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); - * auto sizes = raft::make_device_vector(handle, n_clusters); - * raft::cluster::kmeans_balanced::calc_centers_and_sizes( - * handle, X, labels, centroids.view(), sizes.view(), true); - * @endcode - * - * @tparam DataT Type of the input data. - * @tparam MathT Type of the centroids and mapped data. - * @tparam IndexT Type used for indexing. - * @tparam LabelT Type of the output labels. - * @tparam CounterT Counter type supported by CUDA's native atomicAdd. - * @tparam MappingOpT Type of the mapping function. - * @param[in] handle The raft resources - * @param[in] X Dataset for which to calculate cluster centers. The data must be in - * row-major format. [dim = n_samples x n_features] - * @param[in] labels The input labels [dim = n_samples] - * @param[out] centroids The output centroids [dim = n_clusters x n_features] - * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] - * @param[in] reset_counters Whether to clear the output arrays before calculating. - * When set to `false`, this function may be used to update existing - * centers and sizes using the weighted average principle. - * @param[in] mapping_op (optional) Functor to convert from the input datatype to the - * arithmetic datatype. If DataT == MathT, this must be the identity. - */ -template -[[deprecated("Use cuVS instead")]] void calc_centers_and_sizes( - const raft::resources& handle, - raft::device_matrix_view X, - raft::device_vector_view labels, - raft::device_matrix_view centroids, - raft::device_vector_view cluster_sizes, - bool reset_counters = true, - MappingOpT mapping_op = raft::identity_op()) -{ - RAFT_EXPECTS(X.extent(0) == labels.extent(0), - "Number of rows in dataset and labels are different"); - RAFT_EXPECTS(X.extent(1) == centroids.extent(1), - "Number of features in dataset and centroids are different"); - RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), - "Number of rows in centroids and clusyer_sizes are different"); - - detail::calc_centers_and_sizes(handle, - centroids.data_handle(), - cluster_sizes.data_handle(), - centroids.extent(0), - X.extent(1), - X.data_handle(), - X.extent(0), - labels.data_handle(), - reset_counters, - mapping_op, - resource::get_workspace_resource(handle)); -} - -} // namespace helpers - -} // namespace raft::cluster::kmeans_balanced diff --git a/cpp/include/raft/cluster/kmeans_balanced_types.hpp b/cpp/include/raft/cluster/kmeans_balanced_types.hpp deleted file mode 100644 index 11b77e288a..0000000000 --- a/cpp/include/raft/cluster/kmeans_balanced_types.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -namespace raft::cluster::kmeans_balanced { - -/** - * Simple object to specify hyper-parameters to the balanced k-means algorithm. - * - * The following metrics are currently supported in k-means balanced: - * - InnerProduct - * - L2Expanded - * - L2SqrtExpanded - */ -struct kmeans_balanced_params : kmeans_base_params { - /** - * Number of training iterations - */ - uint32_t n_iters = 20; -}; - -} // namespace raft::cluster::kmeans_balanced - -namespace raft::cluster { - -using kmeans_balanced::kmeans_balanced_params; - -} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_deprecated.cuh b/cpp/include/raft/cluster/kmeans_deprecated.cuh deleted file mode 100644 index 11f964eef5..0000000000 --- a/cpp/include/raft/cluster/kmeans_deprecated.cuh +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -namespace raft { -namespace cluster { - -/** - * @brief Find clusters with k-means algorithm. - * Initial centroids are chosen with k-means++ algorithm. Empty - * clusters are reinitialized by choosing new centroids with - * k-means++ algorithm. - * @tparam index_type_t the type of data used for indexing. - * @tparam value_type_t the type of data used for weights, distances. - * @param handle the raft handle. - * @param n Number of observation vectors. - * @param d Dimension of observation vectors. - * @param k Number of clusters. - * @param tol Tolerance for convergence. k-means stops when the - * change in residual divided by n is less than tol. - * @param maxiter Maximum number of k-means iterations. - * @param obs (Input, device memory, d*n entries) Observation - * matrix. Matrix is stored column-major and each column is an - * observation vector. Matrix dimensions are d x n. - * @param codes (Output, device memory, n entries) Cluster - * assignments. - * @param residual On exit, residual sum of squares (sum of squares - * of distances between observation vectors and centroids). - * @param iters on exit, number of k-means iterations. - * @param seed random seed to be used. - * @return error flag - */ -template -int kmeans(raft::resources const& handle, - index_type_t n, - index_type_t d, - index_type_t k, - value_type_t tol, - index_type_t maxiter, - const value_type_t* __restrict__ obs, - index_type_t* __restrict__ codes, - value_type_t& residual, - index_type_t& iters, - unsigned long long seed = 123456) -{ - return detail::kmeans( - handle, n, d, k, tol, maxiter, obs, codes, residual, iters, seed); -} -} // namespace cluster -} // namespace raft diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp deleted file mode 100644 index 4d956ad7a0..0000000000 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -namespace raft::cluster { - -/** Base structure for parameters that are common to all k-means algorithms */ -struct kmeans_base_params { - /** - * Metric to use for distance computation. The supported metrics can vary per algorithm. - */ - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; -}; - -} // namespace raft::cluster - -namespace raft::cluster::kmeans { - -/** - * Simple object to specify hyper-parameters to the kmeans algorithm. - */ -struct KMeansParams : kmeans_base_params { - enum InitMethod { - - /** - * Sample the centroids using the kmeans++ strategy - */ - KMeansPlusPlus, - - /** - * Sample the centroids uniformly at random - */ - Random, - - /** - * User provides the array of initial centroids - */ - Array - }; - - /** - * The number of clusters to form as well as the number of centroids to generate (default:8). - */ - int n_clusters = 8; - - /** - * Method for initialization, defaults to k-means++: - * - InitMethod::KMeansPlusPlus (k-means++): Use scalable k-means++ algorithm - * to select the initial cluster centers. - * - InitMethod::Random (random): Choose 'n_clusters' observations (rows) at - * random from the input data for the initial centroids. - * - InitMethod::Array (ndarray): Use 'centroids' as initial cluster centers. - */ - InitMethod init = KMeansPlusPlus; - - /** - * Maximum number of iterations of the k-means algorithm for a single run. - */ - int max_iter = 300; - - /** - * Relative tolerance with regards to inertia to declare convergence. - */ - double tol = 1e-4; - - /** - * verbosity level. - */ - int verbosity = RAFT_LEVEL_INFO; - - /** - * Seed to the random number generator. - */ - raft::random::RngState rng_state{0}; - - /** - * Number of instance k-means algorithm will be run with different seeds. - */ - int n_init = 1; - - /** - * Oversampling factor for use in the k-means|| algorithm - */ - double oversampling_factor = 2.0; - - // batch_samples and batch_centroids are used to tile 1NN computation which is - // useful to optimize/control the memory footprint - // Default tile is [batch_samples x n_clusters] i.e. when batch_centroids is 0 - // then don't tile the centroids - int batch_samples = 1 << 15; - - /** - * if 0 then batch_centroids = n_clusters - */ - int batch_centroids = 0; // - - bool inertia_check = false; -}; - -} // namespace raft::cluster::kmeans - -namespace raft::cluster { - -using kmeans::KMeansParams; - -} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh deleted file mode 100644 index 067445c542..0000000000 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ /dev/null @@ -1,113 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include - -namespace raft::cluster { - -/** - * Note: All of the functions below in the raft::cluster namespace are deprecated - * and will be removed in a future release. Please use raft::cluster::hierarchy - * instead. - */ - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[in] m number of rows in X - * @param[in] n number of columns in X - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[out] out struct containing output dendrogram and cluster assignments - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control - * of k. The algorithm will set `k = log(n) + c` - * @param[in] n_clusters number of clusters to assign data samples - */ -template -[[deprecated("Use cuVS instead")]] void single_linkage(raft::resources const& handle, - const value_t* X, - size_t m, - size_t n, - raft::distance::DistanceType metric, - linkage_output* out, - int c, - size_t n_clusters) -{ - detail::single_linkage( - handle, X, m, n, metric, out, c, n_clusters); -} -}; // namespace raft::cluster - -namespace raft::cluster::hierarchy { - -constexpr int DEFAULT_CONST_C = 15; - -/** - * Single-linkage clustering, capable of constructing a KNN graph to - * scale the algorithm beyond the n^2 memory consumption of implementations - * that use the fully-connected graph of pairwise distances by connecting - * a knn graph when k is not large enough to connect it. - - * @tparam value_idx - * @tparam value_t - * @tparam dist_type method to use for constructing connectivities graph - * @param[in] handle raft handle - * @param[in] X dense input matrix in row-major layout - * @param[out] dendrogram output dendrogram (size [n_rows - 1] * 2) - * @param[out] labels output labels vector (size n_rows) - * @param[in] metric distance metrix to use when constructing connectivities graph - * @param[in] n_clusters number of clusters to assign data samples - * @param[in] c a constant used when constructing connectivities from knn graph. Allows the indirect - control of k. The algorithm will set `k = log(n) + c` - */ -template -[[deprecated("Use cuVS instead")]] void single_linkage( - raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view dendrogram, - raft::device_vector_view labels, - raft::distance::DistanceType metric, - size_t n_clusters, - std::optional c = std::make_optional(DEFAULT_CONST_C)) -{ - linkage_output out_arrs; - out_arrs.children = dendrogram.data_handle(); - out_arrs.labels = labels.data_handle(); - - raft::cluster::single_linkage( - handle, - X.data_handle(), - static_cast(X.extent(0)), - static_cast(X.extent(1)), - metric, - &out_arrs, - c.has_value() ? c.value() : DEFAULT_CONST_C, - n_clusters); -} -}; // namespace raft::cluster::hierarchy diff --git a/cpp/include/raft/cluster/single_linkage_types.hpp b/cpp/include/raft/cluster/single_linkage_types.hpp deleted file mode 100644 index cd815622bf..0000000000 --- a/cpp/include/raft/cluster/single_linkage_types.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::cluster::hierarchy { - -/** - * Determines the method for computing the minimum spanning tree (MST) - */ -enum LinkageDistance { - - /** - * Use a pairwise distance matrix as input to the mst. This - * is very fast and the best option for fairly small datasets (~50k data points) - */ - PAIRWISE = 0, - - /** - * Construct a KNN graph as input to the mst and provide additional - * edges if the mst does not converge. This is slower but scales - * to very large datasets. - */ - KNN_GRAPH = 1 -}; - -}; // end namespace raft::cluster::hierarchy - -// The code below is now considered legacy -namespace raft::cluster { - -using hierarchy::LinkageDistance; - -/** - * Simple container object for consolidating linkage results. This closely - * mirrors the trained instance variables populated in - * Scikit-learn's AgglomerativeClustering estimator. - * @tparam value_idx - * @tparam value_t - */ -template -class linkage_output { - public: - idx_t m; - idx_t n_clusters; - - idx_t n_leaves; - idx_t n_connected_components; - - // TODO: These will be made private in a future release - idx_t* labels; // size: m - idx_t* children; // size: (m-1, 2) - - raft::device_vector_view get_labels() - { - return raft::make_device_vector_view(labels, m); - } - - raft::device_matrix_view get_children() - { - return raft::make_device_matrix_view(children, m - 1, 2); - } -}; - -class linkage_output_int : public linkage_output {}; -class linkage_output_int64 : public linkage_output {}; - -}; // namespace raft::cluster diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh deleted file mode 100644 index e85b05575f..0000000000 --- a/cpp/include/raft/cluster/specializations.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif diff --git a/cpp/include/raft/distance/detail/compress_to_bits.cuh b/cpp/include/raft/distance/detail/compress_to_bits.cuh deleted file mode 100644 index 5ffb717c42..0000000000 --- a/cpp/include/raft/distance/detail/compress_to_bits.cuh +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace raft::distance::detail { - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for masked_l2_nn. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. - */ -template ::value>> -RAFT_KERNEL compress_to_bits_kernel( - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - constexpr int bits_per_element = 8 * sizeof(T); - constexpr int tile_dim_m = bits_per_element; - constexpr int nthreads = 128; - constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector - - // Tile in shared memory is transposed - __shared__ bool smem[tile_dim_n][tile_dim_m]; - - const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); - const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); - - for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { - const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); - const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); - - if (in.extent(0) <= tile_idx_m) { break; } - // Fill shared memory tile - bool reg_buf[tile_dim_m]; -#pragma unroll - for (int i = 0; i < tile_dim_m; ++i) { - const int in_m = tile_idx_m + i; - const int in_n = tile_idx_n + threadIdx.x; - bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); - reg_buf[i] = in_bounds ? in(in_m, in_n) : false; - smem[threadIdx.x][i] = reg_buf[i]; - } - __syncthreads(); - - // Drain memory tile into single output element out_elem. - T out_elem{0}; -#pragma unroll - for (int j = 0; j < tile_dim_n; ++j) { - if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } - } - __syncthreads(); - - // Write output. - int out_m = tile_idx_m / bits_per_element; - int out_n = tile_idx_n + threadIdx.x; - - if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } - } -} - -/** - * @brief Compress 2D boolean matrix to bitfield - * - * Utility kernel for masked_l2_nn. - * - * @tparam T - * - * @parameter[in] in An `m x n` boolean matrix. Row major. - * @parameter[out] out An `(m / bits_per_elem) x n` matrix with elements of - * type T, where T is of size `bits_per_elem` bits. - * Note: the division (`/`) is a ceilDiv. - */ -template ::value>> -void compress_to_bits(raft::resources const& handle, - raft::device_matrix_view in, - raft::device_matrix_view out) -{ - auto stream = resource::get_cuda_stream(handle); - constexpr int bits_per_element = 8 * sizeof(T); - - RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), - "Number of output rows must be ceildiv(input rows, bits_per_elem)"); - RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); - - const int num_SMs = raft::getMultiProcessorCount(); - int blocks_per_sm = 0; - constexpr int num_threads = 128; - constexpr int dyn_smem_size = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &blocks_per_sm, compress_to_bits_kernel, num_threads, dyn_smem_size)); - - dim3 grid(num_SMs * blocks_per_sm); - dim3 block(128); - compress_to_bits_kernel<<>>(in, out); - RAFT_CUDA_TRY(cudaGetLastError()); -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh deleted file mode 100644 index a39dbf6700..0000000000 --- a/cpp/include/raft/distance/detail/distance.cuh +++ /dev/null @@ -1,873 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief: A tag type for overload resolution based on DistanceType - * - * It is not possible to partially specialize function templates on a single - * parameter. Instead, it is often easier to use a combination of conventional - * method overloading and a parameter with a specific tag type. The following - * type is used to help method overloading based on the DistanceType enum. - */ -template -using distance_tag = std::integral_constant; - -/** - * @brief Implement pairwise_matrix for specific distance - * - * There are multiple overloads for this function, one for each distance type. - * They are implemented below. The documentation of this function serves as - * documentation for all functions. The following overloads are defined: - * - * - DistanceType::Canberra: - * - DistanceType::CorrelationExpanded: - * - DistanceType::CosineExpanded: - * - DistanceType::DiceExpanded: - * - DistanceType::HammingUnexpanded: - * - DistanceType::HellingerExpanded: - * - DistanceType::JensenShannon: - * - DistanceType::KLDivergence: - * - DistanceType::L1: - * - DistanceType::L2Expanded: - * - DistanceType::L2SqrtExpanded: - * - DistanceType::L2Unexpanded: - * - DistanceType::L2SqrtUnexpanded: - * - DistanceType::Linf: - * - DistanceType::LpUnexpanded: - * - DistanceType::RusselRaoExpanded: - * - * @tparam DataT Input data type - * @tparam AccT Accumulation data type - * @tparam OutT Output data type - * @tparam FinOpT Type of final operation - * @tparam IdxT Index type - * - * @param handle RAFT resources handle - * @param distance_type A tag type to indicate which distance is calculated. - * @param x First set of points - * @param y Second set of points - * @param out Output distance matrix - * @param m Number of points in x - * @param n Number of points in y - * @param k Dimensionality of points in x, y - * @param workspace Temporary workspace needed for computations - * @param worksize Number of bytes of the workspace - * @param is_row_major Whether the matrices are row-major or col-major - * @param metric_arg The `p` argument for Lp. - */ -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, // unused - size_t worksize, // unused - FinOpT fin_op, - bool is_row_major, - DataT metric_arg) // unused -{ - ops::canberra_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // unused -{ - ASSERT(!(worksize < 2 * (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - AccT* x_norm = workspace; - AccT* y_norm = workspace; - AccT* sq_x_norm = workspace; - AccT* sq_y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if (x == y && is_row_major) { - raft::linalg::reduce(x_norm, - x, - k, - std::max(m, n), - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - sq_x_norm += std::max(m, n); - sq_y_norm = sq_x_norm; - raft::linalg::rowNorm( - sq_x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream); - } else { - y_norm += m; - raft::linalg::reduce(x_norm, - x, - k, - m, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - raft::linalg::reduce(y_norm, - y, - k, - n, - (AccT)0, - is_row_major, - true, - stream, - false, - raft::identity_op(), - raft::add_op()); - - sq_x_norm += (m + n); - sq_y_norm = sq_x_norm + m; - raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); - } - - using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); - pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // unused -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if (x == y && is_row_major) { - raft::linalg::rowNorm( - x_norm, x, k, std::max(m, n), raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } else { - y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); - } - - ops::cosine_distance_op distance_op{}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // unused -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if (x == y && is_row_major) { - raft::linalg::reduce(x_norm, - x, - k, - std::max(m, n), - (AccT)0, - is_row_major, - true, - stream, - false, - raft::nz_op(), - raft::add_op()); - } else { - y_norm += m; - raft::linalg::reduce( - x_norm, x, k, m, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); - raft::linalg::reduce( - y_norm, y, k, n, (AccT)0, is_row_major, true, stream, false, raft::nz_op(), raft::add_op()); - } - - ops::dice_distance_op distance_op{}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::hamming_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - raft::linalg::gemm(handle, - out, - const_cast(x), - const_cast(y), - m, - n, - k, - !is_row_major, - !is_row_major, - is_row_major, - stream); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - // First sqrt x and y - const auto raft_sqrt = raft::linalg::unaryOp; - - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } - - // Then calculate Hellinger distance - ops::hellinger_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - // Finally revert sqrt of x and y - raft_sqrt((DataT*)x, x, m * k, raft::sqrt_op{}, stream); - if (x != y) { raft_sqrt((DataT*)y, y, n * k, raft::sqrt_op{}, stream); } - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::jensen_shannon_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto unaryOp_lambda = [] __device__(DataT input) { - const bool x_zero = (input == 0); - return (!x_zero) * raft::log(input + x_zero); - }; - - auto unaryOp_lambda_reverse = [] __device__(DataT input) { - // reverse previous log (x) back to x using (e ^ log(x)) - const bool x_zero = (input == 0); - return (!x_zero) * raft::exp(input); - }; - - if (x != y) { - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda, stream); - } - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op distance_op{is_row_major, x == y}; - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); - - if (x != y) { - // Now reverse previous log (x) back to x using (e ^ log(x)) - raft::linalg::unaryOp( - (DataT*)y, y, n * k, unaryOp_lambda_reverse, stream); - } -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::l1_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl_l2_expanded( // NOTE: different name - bool perform_sqrt, // dispatch on sqrt - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutT) > 1) && (sizeof(AccT) != sizeof(OutT))), - "OutT can be uint8_t, float, double," - "if sizeof(OutT) > 1 then sizeof(AccT) == sizeof(OutT)."); - - ASSERT(!(worksize < (m + n) * sizeof(AccT)), "workspace size error"); - ASSERT(workspace != nullptr, "workspace is null"); - - DataT* x_norm = workspace; - DataT* y_norm = workspace; - // TODO: Column major case looks to have lower accuracy for X == Y, - // perhaps the use of stridedSummationKernel could be causing this, - // need to investigate and fix. - if ((x == y) && is_row_major) { - raft::linalg::rowNorm(x_norm, - x, - k, - std::max(m, n), - raft::linalg::L2Norm, - is_row_major, - stream, - raft::identity_op{}); - } else { - y_norm += m; - raft::linalg::rowNorm( - x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - raft::linalg::rowNorm( - y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); - } - - ops::l2_exp_distance_op distance_op{perform_sqrt}; - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = false; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT* workspace, - size_t worksize, - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = true; - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - distance_impl_l2_expanded( - perform_sqrt, x, y, out, m, n, k, workspace, worksize, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = false; - ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - bool perform_sqrt = true; - ops::l2_unexp_distance_op l2_op(perform_sqrt); - - // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::l_inf_distance_op distance_op{}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT metric_arg) -{ - ops::lp_unexp_distance_op distance_op{metric_arg}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -template -void distance_impl(raft::resources const& handle, - distance_tag distance_type, - const DataT* x, - const DataT* y, - OutT* out, - IdxT m, - IdxT n, - IdxT k, - AccT*, // workspace unused - size_t, // worksize unused - FinOpT fin_op, - bool is_row_major, - DataT) // metric_arg unused -{ - ops::russel_rao_distance_op distance_op{k}; - - const DataT* x_norm = nullptr; - const DataT* y_norm = nullptr; - - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); -} - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type - * - * @param x first set of points - * @param y second set of points - * @param out output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* out, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - // raft distance support inputs as float/double and output as uint8_t/float/double. - static_assert(!((sizeof(OutType) > 1) && (sizeof(AccType) != sizeof(OutType))), - "OutType can be uint8_t, float, double," - "if sizeof(OutType) > 1 then sizeof(AccType) == sizeof(OutType)."); - - distance_impl( - handle, - distance_tag{}, - x, - y, - out, - m, - n, - k, - reinterpret_cast(workspace), - worksize, - fin_op, - isRowMajor, - metric_arg); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param stream cuda stream - * @param isRowMajor whether the matrices are row-major or col-major - */ -template -void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* out, - Index_ m, - Index_ n, - Index_ k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) -{ - auto fin_op = raft::identity_op(); - - distance( - handle, x, y, out, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified distanceType doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) -{ - size_t worksize = 0; - constexpr bool is_allocated = - (distanceType <= raft::distance::DistanceType::CosineExpanded) || - (distanceType == raft::distance::DistanceType::CorrelationExpanded) || - (distanceType == raft::distance::DistanceType::DiceExpanded); - constexpr int numOfBuffers = - (distanceType == raft::distance::DistanceType::CorrelationExpanded) ? 2 : 1; - - if (is_allocated) { - // TODO : when X == Y allocate std::max(m, n) instead of m + n when column major input - // accuracy issue is resolved until then we allocate as m + n. - worksize += numOfBuffers * m * sizeof(AccType); - worksize += numOfBuffers * n * sizeof(AccType); - } - - return worksize; -} - -}; // namespace detail -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh b/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh deleted file mode 100644 index 84eb3c705b..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/all_ops.cuh +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// Defines a named requirement "has_cutlass_op" -#include - -// The distance operations: -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh deleted file mode 100644 index eaf37b7e9c..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::abs -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief The canberra distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sum_k |x_ik - y_kj| / ( |x_ik| + |y_kj| ) - */ -template -struct canberra_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - const auto add = raft::abs(x) + raft::abs(y); - // deal with potential for 0 in denominator by - // forcing 0/1 instead - acc += ((add != 0) * diff / (add + (add == 0))); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh deleted file mode 100644 index 4fc4bb8297..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ /dev/null @@ -1,126 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -/** @brief The correlation distance - * - * It computes the following equation: - * - * d(x, y) = ((x - mean(x)) ⋅ (y - mean(y))) - * / - * (|| x - mean(x) ||_2 || y - mean(y) ||_2) - */ -template -struct correlation_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const DataT* x2n; - const DataT* y2n; - IdxT m; - IdxT n; - IdxT k; - - correlation_distance_op( - bool is_row_major, const DataT* x2n_, const DataT* y2n_, IdxT m_, IdxT n_, IdxT k_) noexcept - : x2n(x2n_), y2n(y2n_), m(m_), n(n_), k(k_) - { - // The distance op is typically created before the row-major/col-major - // swapping has been done. So we do it here. - if (!is_row_major) { - std::swap(x2n, y2n); - std::swap(m, n); - } - } - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - // Note how we can sneakily get a pointer to shared memory here, to store - // more data. If the implementation of PairwiseDistanceMatKernel ever - // changes, this will be where we find the bugs. - extern __shared__ char smem[]; - - DataT regx2n[Policy::AccRowsPerTh], regy2n[Policy::AccColsPerTh]; - - DataT* sx2Norm = - (DataT*)(&smem[Policy::SmemSize + (Policy::Mblk + Policy::Nblk) * sizeof(DataT)]); - DataT* sy2Norm = (&sx2Norm[Policy::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * Policy::Nblk) { - for (int i = threadIdx.x; i < Policy::Mblk; i += Policy::Nthreads) { - auto idx = gridStrideY + i; - sx2Norm[i] = idx < m ? x2n[idx] : 0; - } - } - - for (int i = threadIdx.x; i < Policy::Nblk; i += Policy::Nthreads) { - auto idx = gridStrideX + i; - sy2Norm[i] = idx < n ? y2n[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - regx2n[i] = sx2Norm[i * Policy::AccThRows + (threadIdx.x / Policy::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < Policy::AccColsPerTh; ++i) { - regy2n[i] = sy2Norm[i * Policy::AccThCols + (threadIdx.x % Policy::AccThCols)]; - } - -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - auto numer = k * acc[i][j] - (regxn[i] * regyn[j]); - auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); - auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - - acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh deleted file mode 100644 index 0883136c9f..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -// Epilogue operator for CUTLASS based kernel -template -struct cosine_cutlass_op { - __device__ cosine_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the expanded cosine distance matrix calculation - * - * It computes the following equation: - * - * d(x, y) = 1 - (x ⋅ y) / ( ||x||_2 ||y||_2) - */ -template -struct cosine_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = 1.0 - (acc[i][j] / (regxn[i] * regyn[j])); - } - } - } - - constexpr cosine_cutlass_op get_cutlass_op() const - { - return cosine_cutlass_op(); - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh deleted file mode 100644 index 68e843c6f5..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // std::false_type -#include // std::declval - -namespace raft::distance::detail::ops { - -// This file defines the named requirement "has_cutlass_op" that can be used to -// determine if a distance operation has a CUTLASS op that can be used to pass -// to CUTLASS. Examples of distance operations that satisfy this requirement are -// cosine_distance_op and l2_exp_distance_op. - -// Primary template handles types that do not support CUTLASS. -// This pattern is described in: -// https://en.cppreference.com/w/cpp/types/void_t -template -struct has_cutlass_op : std::false_type {}; - -// Specialization recognizes types that do support CUTLASS -template -struct has_cutlass_op().get_cutlass_op())>> - : std::true_type {}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/dice.cuh b/cpp/include/raft/distance/detail/distance_ops/dice.cuh deleted file mode 100644 index 362ba7eab7..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/dice.cuh +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -// Epilogue operator for CUTLASS based kernel -template -struct dice_cutlass_op { - __device__ dice_cutlass_op() noexcept {} - __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept - { - return static_cast(1.0) - static_cast(2 * accVal / (aNorm + bNorm)); - } - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the expanded dice distance matrix calculation - * - * It computes the following equation: - * - * d(x, y) = 1 - 2*(x ⋅ y) / ( Σ(x) + Σ(y) ) - */ -template -struct dice_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - acc += (x != DataT(0) ? DataT(1) : DataT(0)) * (y != DataT(0) ? DataT(1) : DataT(0)); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = 1.0 - (2 * acc[i][j] / (regxn[i] + regyn[j])); - } - } - } - - constexpr dice_cutlass_op get_cutlass_op() const - { - return dice_cutlass_op(); - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh deleted file mode 100644 index 475b8892e9..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the Hamming Unexpanded distance matrix calculation - * It computes the following equation: - * - * c_ij = sum_k (x_ik != y_kj) / k - */ -template -struct hamming_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - IdxT k; - - hamming_distance_op(IdxT k_) noexcept : k(k_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += (x != y); }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - const DataT one_over_k = DataT(1.0) / k; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] *= one_over_k; - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh deleted file mode 100644 index 0489b45854..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ /dev/null @@ -1,77 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the Hellinger distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sqrt(1 - sum_k sqrt(x_ik * y_kj)) - * - */ -template -struct hellinger_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - // This is sqrt(x) * sqrt(y). - const auto product = x * y; - acc += product; - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative - const auto finalVal = (1 - acc[i][j]); - const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::sqrt(rectifier * finalVal); - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh deleted file mode 100644 index e46c63734c..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // raft::log -#include // DI - -namespace raft::distance::detail::ops { - -// Describes the computation the jensen_shannon distance - -/** - * @brief the Jensen Shannon distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sqrt(0.5 * sum( -x_i * (log(0.5 * (x_i + y_i)) - log(x_i)) - * + (-y_i * (log(0.5 * (x_i + y_i)) - log(y_i))))) - */ -template -struct jensen_shannon_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const DataT m = 0.5f * (x + y); - const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::log(m + m_zero); - - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(0.5 * acc[i][j]); - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh deleted file mode 100644 index d083c5ddcc..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // raft::log -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the KL Divergence distance matrix calculation - * - * It computes the following equation: - * - * c_ij = 0.5 * sum(x * log (x / y)); - */ -template -struct kl_divergence_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const bool is_row_major; - const bool x_equal_y; - - kl_divergence_op(bool row_major_, bool x_equal_y_ = false) noexcept - : is_row_major(row_major_), x_equal_y(x_equal_y_) - { - } - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - // TODO: make sure that these branches get hoisted out of main loop.. Could - // be quite expensive otherwise. - if (x_equal_y) { - if (is_row_major) { - const bool x_zero = (x == 0); - const bool y_zero = (y == 0); - acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); - } else { - const bool y_zero = (y == 0); - const bool x_zero = (x == 0); - acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); - } - } else { - if (is_row_major) { - const bool x_zero = (x == 0); - acc += x * (raft::log(x + x_zero) - y); - } else { - const bool y_zero = (y == 0); - acc += y * (raft::log(y + y_zero) - x); - } - } - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = (0.5f * acc[i][j]); - } - } - } -}; -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh deleted file mode 100644 index 7e86fd3603..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the L1 distance matrix calculation - * - * It computes the following equation: - * - * c_ij = sum_k abs(x_ik - y_kj) - */ -template -struct l1_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Do not load norms of data, the computation of L1 distance does not use them. - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += raft::abs(x - y); }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - }; -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh deleted file mode 100644 index a218c85a0a..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include // DI - -namespace raft::distance::detail::ops { - -/** - * Reserve 1 digit of precision from each floating-point type - * for round-off error tolerance. - * @tparam DataT - */ -template -__device__ constexpr DataT get_clamp_precision() -{ - switch (sizeof(DataT)) { - case 2: return 1e-3; - case 4: return 1e-6; - case 8: return 1e-15; - default: return 0; - } -} - -// Epilogue operator for CUTLASS based kernel -template -struct l2_exp_cutlass_op { - bool sqrt; - - __device__ l2_exp_cutlass_op() noexcept : sqrt(false) {} - __device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {} - inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept - { - AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - - /** - * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product (accVal) - * can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal instead. - */ - outVal = outVal * !((outVal * outVal < get_clamp_precision()) * (aNorm == bNorm)); - return sqrt ? raft::sqrt(outVal * (outVal > 0)) : outVal; - } - - __device__ AccT operator()(DataT aData) const noexcept { return aData; } -}; - -/** - * @brief the expanded euclidean distance matrix calculation - * - * It computes the following equation: - * - * c_ij = - 2 sum_k x_ik * y_kj + ||x_i.||_2 + ||y_.j||_2 - * - */ -template -struct l2_exp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - const bool sqrt; - - l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} - - // Load norms of input data - static constexpr bool use_norms = true; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - DataT accVal = acc[i][j]; - DataT val = regxn[i] + regyn[j] - (DataT)2.0 * accVal; - - /** - * Self-neighboring points should have (aNorm == bNorm) == accVal and the dot product - * (accVal) can sometimes have round-off errors, which will cause (aNorm == bNorm) ~ accVal - * instead. - */ - acc[i][j] = - val * (val > 0) * !((val * val < get_clamp_precision()) * (regxn[i] == regyn[j])); - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - } - - constexpr l2_exp_cutlass_op get_cutlass_op() const - { - return l2_exp_cutlass_op(sqrt); - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh deleted file mode 100644 index 62c212ee8f..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the unexpanded euclidean distance matrix calculation - * - * It computes the following equation: - * - * c_ij = optional_sqrt ( sum_k (x_ik - y_kj)^2 ) - */ -template -struct l2_unexp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - bool sqrt; - - l2_unexp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} - - // Do not load norms of data, the computation of L1 distance does not use them. - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = x - y; - acc += diff * diff; - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - if (sqrt) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - }; -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh deleted file mode 100644 index 88853a3083..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the L_inf (Chebyshev) distance matrix calculation - * - * It computes the following equation: - * - * c_ij = max_k | x_ik - y_kj | - */ -template -struct l_inf_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - acc = raft::max(acc, diff); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - return; - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh deleted file mode 100644 index 290f4af1b4..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once -#include // raft::pow, raft::abs -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the unexpanded Lp (Minkowski) distance matrix calculation - * - * It computes the following equation: - * - * c_ij = (sum_k |x_ik - y_jk|^p)^(1/p) - */ -template -struct lp_unexp_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - DataT p; - - lp_unexp_distance_op(DataT p_) noexcept : p(p_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = true; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const - { - const auto diff = raft::abs(x - y); - acc += raft::pow(diff, p); - }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - const auto one_over_p = 1.0f / p; -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = raft::pow(acc[i][j], one_over_p); - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh deleted file mode 100644 index 63dbf350d1..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ /dev/null @@ -1,74 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -/** - * @brief the Russell Rao distance matrix calculation - * - * It computes the following equation: - * - * c_ij = (k - (sum_k x_ik * y_kj)) / k - */ -template -struct russel_rao_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - IdxT k; - const float one_over_k; - - russel_rao_distance_op(IdxT k_) noexcept : k(k_), one_over_k(1.0f / k_) {} - - // Load norms of input data - static constexpr bool use_norms = false; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { acc += x * y; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = (k - acc[i][j]) * one_over_k; - } - } - } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh deleted file mode 100644 index 4320068361..0000000000 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // DI - -namespace raft::distance::detail::ops { - -// Describes the computation the template distance -// -// Fill in the TODO items. - -template -struct template_distance_op { - using DataT = DataType; - using AccT = AccType; - using IdxT = IdxType; - - TODO member; - - template_distance_op(TODO member_) noexcept : member(member_) {} - - // Load norms of input data - static constexpr bool use_norms = TODO; - // Whether the core function requires so many instructions that it makes sense - // to reduce loop unrolling, etc. We do this to keep compile times in check. - static constexpr bool expensive_inner_loop = false; - - // Size of shared memory. This is normally decided by the kernel policy, but - // some ops such as correlation_distance_op use more. - template - static constexpr size_t shared_mem_size() - { - return Policy::SmemSize + TODO; - } - - DI void core(AccT& acc, DataT& x, DataT& y) const { TODO; }; - - template - DI void epilog(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT* regxn, - DataT* regyn, - IdxT gridStrideX, - IdxT gridStrideY) const - { - TODO; - } - - // If exist, returns a cutlass op that performs the same operation. - // See cosine and l2_exp distance ops for an example. - constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } -}; - -} // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/fused_distance_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn.cuh deleted file mode 100644 index 4fbfdc8755..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn.cuh +++ /dev/null @@ -1,97 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::identity_op -#include // ops::l2_exp_distance_op -#include -#include -#include -#include -#include -#include // PairwiseDistances -#include -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { - -namespace detail { - -template -void fusedDistanceNNImpl(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - raft::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedDistanceNN. - typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; - - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - switch (metric) { - case raft::distance::DistanceType::CosineExpanded: - fusedCosineNN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, stream); - break; - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: - // initOutBuffer is take care by fusedDistanceNNImpl() so we set it false to fusedL2NNImpl. - fusedL2NNImpl( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, false, stream); - break; - default: assert("only cosine/l2 metric is supported with fusedDistanceNN\n"); break; - } -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h b/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h deleted file mode 100644 index 186715851b..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/custom_epilogue_with_broadcast.h +++ /dev/null @@ -1,668 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ - -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - - The epilogue rearranges the result of a matrix product through shared memory to match canonical - tensor layouts in global memory. Epilogues support conversion and reduction operations. - -This file contains a customized version of EpilogueWithBroadcast from CUTLASS 2.9.1 -(https://github.com/NVIDIA/cutlass/blob/v2.9.1/include/cutlass/epilogue/threadblock/epilogue_with_broadcast.h) - -Changes: -- customized the compute_source_needed_() and apply_output_operator_() to suit the needs of per row -reduction -*/ - -#pragma once - -#if defined(__CUDACC_RTC__) -#include -#include -#else -#include - -#include -#endif - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -struct EpilogueWithBroadcastOpBaseCustom { - using ElementOutput = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = StoreZ; - - /// If true, the 'T' tensor is stored - static bool const kStoreT = StoreT; - - /// Parameters structure - required - struct Params {}; - - // - // Methods - // - - /// Constructor from Params - EpilogueWithBroadcastOpBaseCustom(Params const& params_) {} - - /// Determine if the source is needed. May return false if - bool is_source_needed() const { return true; } - - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -/// Epilogue operator with bias vector broadcast over columns. -/// -/// Computes the following: -/// -/// -/// Z, T = OutputOp(AB, C, Broadcast) -/// -/// if (ElementwiseOp::kStoreZ) { -/// store(converted_u); -/// } -/// -/// if (ElementwiseOp::kStoreT) { -/// store(v); -/// } -/// -template < - typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) - typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) - int PartitionsK, ///< Number of partitions of the K dimension - typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors (z) - typename TensorTileIterator_, ///< Additional tile iterator for tensor-valued operands (t) - typename ElementVector_, ///< Pointer to broadcast vector - typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators - typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM - typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM - typename OutputOp_, ///< Output operator - concept is EpilogueWithBroadcastOp - typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: - ///< MatrixShape) - int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity - int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large - (!IsEpilogueFunctorHeavy::value)> -class EpilogueWithBroadcastCustom : public EpilogueBase { - public: - using Base = EpilogueBase; - - using Shape = Shape_; - using WarpMmaOperator = WarpMmaOperator_; - static int const kPartitionsK = PartitionsK; - using OutputTileIterator = OutputTileIterator_; - using TensorTileIterator = TensorTileIterator_; - using ElementVector = ElementVector_; - using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; - using WarpTileIterator = WarpTileIterator_; - using SharedLoadIterator = SharedLoadIterator_; - using OutputOp = OutputOp_; - using Padding = Padding_; - - using Layout = layout::RowMajor; - using LongIndex = typename Layout::LongIndex; - - /// The complete warp-level accumulator tile - using AccumulatorTile = typename Base::AccumulatorTile; - - /// Accumulator element - using ElementAccumulator = typename WarpTileIterator::Element; - - /// Compute data type produced by the output op - using ElementCompute = typename OutputOp::ElementCompute; - - /// Compute fragment - using FragmentCompute = Array; - - /// Thread map used by output tile iterators - using ThreadMap = typename OutputTileIterator::ThreadMap; - - /// Fragment object used to store the broadcast values - using BroadcastFragment = - Array; - - /// Output element - using ElementOutput = typename OutputTileIterator::Element; - - /// Data type of additional tensor - using ElementTensor = typename TensorTileIterator::Element; - - /// Output access size - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - /// Tensor reference to destination tensor - using TensorRef = typename OutputTileIterator::TensorRef; - - /// Tensor reference to sync tensor - using SyncTensorRef = typename cutlass::TensorRef; - - /// Const tensor reference to source tensor - using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; - - /// Array type used to output - using OutputAccessType = - Array; - - /// Array type used by output functor - using AccumulatorAccessType = - Array; - - /// Array type used by output functor - using ComputeAccessType = Array; - - /// Tensor access type - using TensorAccessType = Array; - - /// Number of warps - using WarpCount = typename Base::WarpCount; - - /// Shared memory allocation from epilogue base class - using BaseSharedStorage = typename Base::SharedStorage; - - static int constexpr kSmemTiles = - Base::kFragmentsPerIteration > 1 ? Base::kFragmentsPerIteration : kPartitionsK; - static int constexpr kSmemPointerOffset = Base::SharedStorage::StorageShape::kCount / kSmemTiles; - - /// Used for the broadcast - struct BroadcastDetail { - /// Number of threads per warp - static int const kWarpSize = 32; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar column indices handled by each thread - static int const kColumnsPerThread = - ThreadMap::Iterations::kColumn * ThreadMap::kElementsPerAccess; - - /// Number of distinct scalar row indices handled by each thread - static int const kRowsPerThread = - ThreadMap::Iterations::kCount / ThreadMap::Iterations::kColumn; - - /// Number of threads per threadblock - static int const kThreadCount = kWarpSize * WarpCount::kCount; - - /// Number of distinct threads per row of output tile - static int const kThreadsPerRow = (Shape::kN / kColumnsPerThread); - - /// Number of distinct threads which must be reduced during the final reduction phase within the - /// threadblock. - static int const kThreadRows = kThreadCount / kThreadsPerRow; - - /// I'm not sure what I meant here. - static int const kThreadAccessesPerRow = - const_max(1, (Shape::kN + kThreadCount - 1) / kThreadCount); - - /// Shape of the shared memory allocation for the epilogue - using StorageShape = MatrixShape; - - /// Debug printing - CUTLASS_DEVICE - static void print() - { -#if 0 - printf("BroadcastDetail {\n"); - printf( - " kColumnsPerThread: %d\nkRowsPerThread: %d\n,kThreadCount: %d\nkThreadsPerRow: %d\n" - "kThreadRows: %d\nThreadAccessesPerRow: %d\nStorageShape: %d x %d (count: %d)\n", - kColumnsPerThread, - kRowsPerThread, - kThreadCount, - kThreadsPerRow, - kThreadRows, - kThreadAccessesPerRow, - StorageShape::kRow, - StorageShape::kColumn, - StorageShape::kCount - ); - printf("};\n"); -#endif - } - }; - - /// Shared storage structure (shadows base) with additional SMEM buffer for reduction - struct SharedStorage { - union { - BaseSharedStorage base; - }; - - CUTLASS_HOST_DEVICE - SharedStorage() {} - }; - - public: - static_assert(SharedLoadIterator::Fragment::kElements == TensorTileIterator::Fragment::kElements, - "Mismatch between shared load iterator and output tile iterator."); - - static_assert(OutputTileIterator::kElementsPerAccess, - "OutputTileIterator::kElementsPerAccess must not be zero."); - - static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), - "Divisibility"); - - private: - /// Loads fragment from shared memory aligned with output tensor - SharedLoadIterator shared_load_iterator_; - - /// Thread index within the threadblock - int thread_idx_; - - public: - /// Constructor - CUTLASS_DEVICE - EpilogueWithBroadcastCustom(SharedStorage& shared_storage, ///< Shared storage object - int thread_idx, ///< ID of a thread within the threadblock - int warp_idx, ///< ID of warp within threadblock - int lane_idx ///< Id of thread within warp - ) - : Base(shared_storage.base, thread_idx, warp_idx, lane_idx), - shared_load_iterator_(shared_storage.base.reference(), thread_idx), - thread_idx_(thread_idx) - { - } - - /// Streams the result to global memory - CUTLASS_DEVICE - void operator()( - OutputOp const& output_op, ///< Output operator - ElementVector const* broadcast_ptr, ///< Broadcast vector - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator source_iterator, ///< Tile iterator for source accumulator matrix - TensorTileIterator - tensor_iterator, ///< Threadblock tile iterator for additional tensor operand - MatrixCoord const& - problem_size = ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord(Shape::kM, Shape::kN), - MatrixCoord const& - threadblock_offset = ///< Threadblock's initial offset within the problem size space - MatrixCoord()) - { - BroadcastFragment broadcast_fragment; - - load_broadcast_fragment_(broadcast_fragment, broadcast_ptr, problem_size, threadblock_offset); - - compute_source_needed_( - output_op, broadcast_fragment, accumulators, source_iterator, tensor_iterator); - } - - private: - CUTLASS_DEVICE - void load_broadcast_fragment_( - BroadcastFragment& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - ElementVector const* broadcast_ptr, ///< Broadcast vector - MatrixCoord const& - problem_size, ///< Problem size needed to guard against out-of-bounds accesses - MatrixCoord const& - threadblock_offset ///< Threadblock's initial offset within the problem size space - ) - { - broadcast_fragment.clear(); - - // If no pointer is supplied, set with all zeros and avoid memory accesses - if (!broadcast_ptr) { return; } - - int thread_initial_column = ThreadMap::initial_offset(thread_idx_).column(); - - int thread_column_idx = threadblock_offset.column() + thread_initial_column; - broadcast_ptr += thread_initial_column; - - NumericArrayConverter - converter; - using AccessType = AlignedArray; - using ComputeFragmentType = Array; - - ComputeFragmentType* frag_ptr = reinterpret_cast(&broadcast_fragment); - - CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < ThreadMap::Iterations::kColumn; ++j) { - AccessType loaded; - - loaded.clear(); - - if (thread_column_idx < problem_size.column()) { - loaded = *reinterpret_cast(broadcast_ptr); - } - - ComputeFragmentType cvt = converter(loaded); - frag_ptr[j] = cvt; - - thread_column_idx += ThreadMap::Delta::kColumn; - broadcast_ptr += ThreadMap::Delta::kColumn; - } - } - - template - struct acc2smem_source_not_needed; - - template - struct acc2smem_source_not_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator& warp_tile_iterator) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - CUTLASS_PRAGMA_UNROLL - for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { - typename AccumulatorFragmentIterator::Fragment accum_fragment; - - accum_fragment_iterator.load(accum_fragment); - ++accum_fragment_iterator; - - warp_tile_iterator.store(accum_fragment); - if (p < Base::kFragmentsPerIteration - 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); - } - } - - if (Base::kFragmentsPerIteration > 1) { - warp_tile_iterator.add_pointer_offset(kSmemPointerOffset * - (1 - Base::kFragmentsPerIteration)); - } - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const& iterator_begin, - WarpTileIterator& warp_tile_iterator) - { - int dummy[] = { - (pos == (Seq * Base::kFragmentsPerIteration)) && - (helper(iterator_begin, warp_tile_iterator), 0)...}; - - CUTLASS_UNUSED(dummy[0]); - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_not_needed_( - OutputOp const& output_op, ///< Output operator - BroadcastFragment const& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - OutputTileIterator destination_iterator, ///< Tile iterator for destination - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) - { - } - - template - struct acc2smem_source_needed; - - template - struct acc2smem_source_needed> { - template - CUTLASS_DEVICE static void helper(AccumulatorFragmentIterator accum_fragment_iterator, - WarpTileIterator& warp_tile_iterator) - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < Advance; i++) { - ++accum_fragment_iterator; - } - - typename AccumulatorFragmentIterator::Fragment accum_fragment; - accum_fragment_iterator.load(accum_fragment); - warp_tile_iterator.store(accum_fragment); - } - - CUTLASS_DEVICE - static void push(size_t pos, - AccumulatorFragmentIterator const& iterator_begin, - WarpTileIterator& warp_tile_iterator) - { - int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; - } - }; - - /// Streams the result to global memory - CUTLASS_DEVICE - void compute_source_needed_( - OutputOp const& output_op, ///< Output operator - BroadcastFragment const& - broadcast_fragment, ///< Fragment containing the accumulated partial reduction over columns - AccumulatorTile const& accumulators, ///< Complete warp-level accumulator tile - OutputTileIterator - source_iterator, ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) - TensorTileIterator tensor_iterator ///< Threadblock tile iterator for additioanl tensor operand - ) - { - typename OutputTileIterator::Fragment source_fragment; - source_fragment.clear(); - - // - // Iterator over warp-level accumulator fragment - // - - AccumulatorFragmentIterator accum_fragment_iterator(accumulators); - - // - // Iterate over accumulator tile - // - -#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) - for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { - // - // Convert and store fragment - // - - //__syncthreads(); - - acc2smem_source_needed>::push( - iter, accum_fragment_iterator, this->warp_tile_iterator_); - - __syncthreads(); - - // - // Load fragments from shared memory - // - - typename SharedLoadIterator::Fragment aligned_accum_fragment[kPartitionsK]; - - shared_load_iterator_.load(aligned_accum_fragment[0]); - - // - // Apply output operation - // - - typename TensorTileIterator::Fragment frag_T; - - // - // Load the source - // - - source_iterator.load(source_fragment); - ++source_iterator; - - apply_output_operator_( - frag_T, output_op, aligned_accum_fragment[0], source_fragment, broadcast_fragment); - - // - // Conditionally store fragments - // - if (OutputOp::kStoreT) { - tensor_iterator.store(frag_T); - ++tensor_iterator; - } - } - tensor_iterator.dumpToGmem(); - } - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_(typename TensorTileIterator::Fragment& frag_T, - OutputOp const& output_op, - typename SharedLoadIterator::Fragment const& frag_AB, - typename OutputTileIterator::Fragment const& frag_C, - BroadcastFragment const& frag_Broadcast) - { - using AccessTypeT = Array; - using AccessTypeBroadcast = Array; - - AccessTypeT* frag_T_ptr = reinterpret_cast(&frag_T); - - AccumulatorAccessType const* frag_AB_ptr = - reinterpret_cast(&frag_AB); - - OutputAccessType const* frag_C_ptr = reinterpret_cast(&frag_C); - - AccessTypeBroadcast const* frag_Broadcast_ptr = - reinterpret_cast(&frag_Broadcast); - - int const kOutputOpIterations = - TensorTileIterator::Fragment::kElements / TensorTileIterator::kElementsPerAccess; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kOutputOpIterations; ++i) { - output_op(frag_T_ptr[i], - frag_AB_ptr[i], - frag_C_ptr[(i / ThreadMap::Iterations::kColumn)], - frag_Broadcast_ptr[i % ThreadMap::Iterations::kColumn]); - } - } - - /// Helper to invoke the output functor over each vector of output - CUTLASS_DEVICE - void apply_output_operator_source_not_needed_( - typename OutputTileIterator::Fragment& frag_Z, - typename TensorTileIterator::Fragment& frag_T, - OutputOp const& output_op, - typename SharedLoadIterator::Fragment const& frag_AB, - BroadcastFragment const& frag_Broadcast) - { - } -}; - -//////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh deleted file mode 100644 index b2fc5e0cc7..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/cutlass_base.cuh +++ /dev/null @@ -1,177 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#pragma GCC diagnostic ignored "-Wtautological-compare" - -// We define CUTLASS_NAMESPACE in case -// RAFT cmake is not used -#ifndef CUTLASS_NAMESPACE -#define cutlass raft_cutlass -#endif - -#include // FusedDistanceNNEpilogueElementwise -#include // FusedDistanceNNGemm -#include // getMultiProcessorCount -#include // RAFT_CUTLASS_TRY - -#include - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template -RAFT_KERNEL initBinMutexKernel(cuda::binary_semaphore* mut, IdxT m) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - - if (tid < m) { mut[tid].release(); } -} - -template -void cutlassFusedDistanceNN(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - int* mutexes, - CGReduceOpT cg_reduce_op, - DistanceFn dist_op, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - cudaStream_t stream) -{ - using EpilogueOutputOp = cutlass::epilogue::thread::FusedDistanceNNEpilogueElementwise< - DataT, // ElementC_ - AccT, // ElementAccumulator_ - DataT, // ElementCompute_ - AccT, // ElementZ_ - OutT, // ElementT_ - // 128 / cutlass::sizeof_bits::value, - 1, // Elements per access 1 - DistanceFn, - CGReduceOpT, - ReduceOpT, - KVPReduceOpT>; - constexpr int batch_count = 1; - - rmm::device_uvector> bin_mutex(m, stream); - - int blks_ = (m / 256) + 1; - - initBinMutexKernel<<>>(bin_mutex.data(), m); - - typename EpilogueOutputOp::Params epilog_op_param( - dist_op, cg_reduce_op, redOp, pairRedOp, mutexes, bin_mutex.data()); - - // Number of pipelines you want to use - constexpr int NumStages = 3; - // Alignment - constexpr int Alignment = VecLen; - - // default initialize problem size with row major inputs - auto problem_size = cutlass::gemm::GemmCoord(m, n, k); - - constexpr bool isRowMajor = true; - - using fusedDistanceNNKernel = - typename cutlass::gemm::kernel::FusedDistanceNNGemm::GemmKernel; - - using fusedDistanceNN = cutlass::gemm::device::GemmGrouped; - - int num_blocks_per_sm = fusedDistanceNN::maximum_active_blocks(); - int num_sms = raft::getMultiProcessorCount(); - int full_wave = num_blocks_per_sm * num_sms; - constexpr int mmaShapeM = fusedDistanceNNKernel::Mma::Shape::kM; - constexpr int mmaShapeN = fusedDistanceNNKernel::Mma::Shape::kN; - int columnTiles = (problem_size.n() - 1 + mmaShapeN) / mmaShapeN; - int rowTiles = (problem_size.m() - 1 + mmaShapeM) / mmaShapeM; - int totalTiles = columnTiles * rowTiles; - int thread_blocks = - rowTiles < full_wave ? (totalTiles < full_wave ? totalTiles : full_wave) : rowTiles; - - typename fusedDistanceNN::Arguments arguments{ - problem_size, - batch_count, // num of problems. - thread_blocks, - epilog_op_param, - x, - y, - xn, // C matrix eq vector param, which here is A norm - (DataT*)yn, // this is broadcast vec, which is required to be non-const param - dOutput, // Output distance matrix - (int64_t)lda, // stride A - (int64_t)ldb, // stride B - (int64_t)1, // stride A norm - (int64_t)ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = fusedDistanceNN::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - fusedDistanceNN fusedDistanceNN_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(fusedDistanceNN_op.can_implement(arguments)); - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(fusedDistanceNN_op.initialize(arguments, workspace.data(), stream)); - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(fusedDistanceNN_op.run(stream)); -} - -}; // namespace detail -}; // namespace distance -}; // namespace raft - -#pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh deleted file mode 100644 index 22848d30b0..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue.cuh +++ /dev/null @@ -1,134 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) - -This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec -and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise -operation. --- A norm load is provided PredicatedTileIteratorNormVec --- B norm load is provided by EpilogueWithBroadcast --- elementwise operation is provided by OutputOp -*/ - -#pragma once - -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. -template -struct FusedDistanceNNEpilogue { - /// Use defaults related to the existing epilogue - using Base = - DefaultEpilogueTensorOp; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using RowNormTileIterator = cutlass::epilogue::threadblock:: - PredicatedTileIteratorNormVecSmem; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIteratorReducedVec< - typename Base::OutputTileThreadMap, - ElementTensor, - LayoutT, - typename OutputOp::Params>; - - /// Define the epilogue - using Epilogue = cutlass::epilogue::threadblock::EpilogueWithBroadcastCustom< - Shape, - WarpMmaTensorOp, - PartitionsK, - RowNormTileIterator, - OutputTileIterator, - ElementVector, - typename Base::AccumulatorFragmentIterator, - typename Base::WarpTileIterator, - typename Base::SharedLoadIterator, - OutputOp, - typename Base::Padding, - Base::kFragmentsPerIteration>; -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh deleted file mode 100644 index e69b2486df..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/epilogue_elementwise.cuh +++ /dev/null @@ -1,220 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -/*! \file - \brief Functor performing distance operations used by epilogues of pairwise distance - * kernels. -* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 -* customized for applying elementwise distance formula on accumulated GEMM value -* and applying user-defined operation which can convert distance values to key-value pair. -* . -*/ - -#pragma once - -#include - -#include -#include -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -class FusedDistanceNNEpilogueElementwise { - public: - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using DistanceOp = DistanceOp_; - using CGReduceOp = CGReduceOp_; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using OutValT = typename CGReduceOp::AccTypeT; - using FragmentT = Array; - - using FragmentOutput = FragmentZ; - - static bool const kIsHeavy = true; // ElementwiseOp::kIsHeavy; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = false; // We don't store anything in Z, - - /// If true, the 'T' tensor is stored - static bool const kStoreT = true; // this is our final output storage. - - /// Host-constructable parameters structure - struct Params { - CGReduceOp_ cg_reduce_op; - DistanceOp_ dist_op_; - KVPReduceOpT_ pair_redop_; - ReduceOpT_ red_op_; - int* mutexes_; - cuda::binary_semaphore* bin_mutex_; - using CGReduceT = CGReduceOp_; - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, - CGReduceOp cg_reduce_op, - ReduceOpT_ red_op, - KVPReduceOpT_ pair_redop, - int* mutexes, - cuda::binary_semaphore* bin_mutex) - : cg_reduce_op(cg_reduce_op), - dist_op_(dist_op), - pair_redop_(pair_redop), - red_op_(red_op), - mutexes_(mutexes), - bin_mutex_(bin_mutex) - { - } - - CUTLASS_HOST_DEVICE - Params() {} - }; - - private: - // - // Data members - // - DistanceOp_ elementwise_op; - KVPReduceOpT_ pair_redop; - - public: - ReduceOpT_ red_op; - - // - // Methods - // - - /// Constructor from Params - CUTLASS_HOST_DEVICE - FusedDistanceNNEpilogueElementwise(Params const& params) - : elementwise_op(params.dist_op_), pair_redop(params.pair_redop_), red_op(params.red_op_) - { - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const - { - // we use for making sure C matrix is used for A mat norm. - return true; - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_C = - NumericArrayConverter()(frag_C); - FragmentCompute result_Z; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - ElementCompute res_Z = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - frag_T[i] = res_Z; - } - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh deleted file mode 100644 index f29c8b4d4c..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fused_cosine_nn.cuh +++ /dev/null @@ -1,136 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::identity_op -#include // ops::l2_exp_distance_op -#include -#include -#include -#include // PairwiseDistances -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { - -namespace detail { - -template -void fusedCosineNN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedL2NN. - typedef Policy P; - - dim3 blk(P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; - - namespace arch = raft::util::arch; - using AccT = DataT; - ops::cosine_distance_op distance_op{}; - - raft::identity_op fin_op{}; - - auto kernel = fusedDistanceNNkernel; - - // Get pointer to fp32 SIMT kernel to determine the runtime architecture of the - // current system. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using cosineOp = raft::distance::detail::ops::cosine_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; - kvp_cg_min_reduce_op_ cg_reduce_op; - cosineOp cosine_dist_op; - - IdxT lda, ldb, ldd; - lda = k, ldb = k, ldd = n; - - cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - cosine_dist_op, - redOp, - pairRedOp, - stream); - } else { - // If device less than SM_80, use fp32 SIMT kernel. - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); - - kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); - } -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh deleted file mode 100644 index 65475e73c7..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/fused_l2_nn.cuh +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::identity_op -#include // ops::l2_exp_distance_op -#include -#include -#include -#include // PairwiseDistances -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { - -namespace detail { - -template -void fusedL2NNImpl(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedL2NN. - typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; - - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - namespace arch = raft::util::arch; - using AccT = DataT; - ops::l2_exp_distance_op distance_op{sqrt}; - - raft::identity_op fin_op{}; - - auto kernel = fusedDistanceNNkernel; - - // Get pointer to fp32 SIMT kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; - kvp_cg_min_reduce_op_ cg_reduce_op; - L2Op L2_dist_op(sqrt); - - IdxT lda, ldb, ldd; - lda = k, ldb = k, ldd = n; - - cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - L2_dist_op, - redOp, - pairRedOp, - stream); - } else { - // If device less than SM_80, use fp32 SIMT kernel. - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); - - kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); - } -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h deleted file mode 100644 index 42de4860a0..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/gemm.h +++ /dev/null @@ -1,409 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// -/* - * This configuration is used for float inputs with veclen(kAlignmentA/B) = 2 or 4, - * ideal threadblock tile shape is 32x256x16 for such cases as there is no - * registers spills for it. - * - */ -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // This struct is specialized for fp32/3xTF32 - - /// Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 32, N = 256, K = 16 - // this is more performant but note that for veclen = 1 - // this shape has register spills - using ThreadblockShape = cutlass::gemm::GemmShape<32, 256, 16>; - - // <- threadblock tile M = 32, N = 128, K = 16 - // this shape has high occupancy but less perf - // this is less performant but this shape has *no* register spills - // for any veclens(1, 2, 4) - // using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - // <- warp tile M = 64, N = 64, K = 16 - // this is more performant for veclen 2,4. - using WarpShape = cutlass::gemm::GemmShape<32, 64, 16>; - - // this shape has high occupancy but less perf used for 32x128x16 - // using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - // <- MMA Op tile M = 16, N = 8, K = 4 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -/* - * This configuration is used for float inputs with veclen(kAlignmentA/B) = 1, - * ideal threadblock tile shape is 32x128x16 for such cases as there is no - * registers spills for it. - * - */ -template < - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // This struct is specialized for fp32/3xTF32 - using ElementA_ = float; - using ElementB_ = float; - - /// Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 32, N = 128, K = 16 - // this shape has high occupancy and no register spills for veclen = 1. - using ThreadblockShape = cutlass::gemm::GemmShape<32, 128, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - // <- warp tile M = 32, N = 32, K = 16 - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - // <- MMA Op tile M = 16, N = 8, K = 4 - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 4>; - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - // using Operator = cutlass::arch::OpMultiplyAdd; // this runs only 1xTF32 for float inputs - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -template < - /// Layout type for A matrix operand - int kAlignmentA, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct FusedDistanceNNGemm { - // Threadblock-level tile size (concept: GemmShape) - // <- threadblock tile M = 64, N = 64, K = 16 - using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 16>; - // using ThreadblockShape = cutlass::gemm::GemmShape<16, 128, 16>; - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - // <- warp tile M = 32, N = 32, K = 16 - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; - // using WarpShape = cutlass::gemm::GemmShape<16, 32, 16>; - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - // Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAdd; - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::FusedDistanceNNEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = FusedDistanceNNPersistent; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh deleted file mode 100644 index e056c5d397..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/helper_structs.cuh +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::identity_op -#include // ops::l2_exp_distance_op -#include -#include -#include // PairwiseDistances -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl -#include - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { - -namespace detail { - -template -struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - - DI void operator()(LabelT rid, KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - DI void operator()(LabelT rid, volatile KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rid, DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, volatile DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void operator()(LabelT rid, volatile DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const - { - out->value = maxVal; - out->key = 0xfffffff0; - } - - DI void init_key(DataT& out, LabelT idx) const { return; } - DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } - - DI DataT get_value(KVP& out) const { return out.value; } - DI DataT get_value(DataT& out) const { return out; } -}; - -template -struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { redOp.init(min + tid, maxVal); } -} - -template -void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) -{ - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); -} - -// cg::reduce functor for FusedDistanceNN used in its cutlass version -// to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template -struct kvp_cg_min_reduce_op { - typedef typename raft::KeyValuePair KVP; - - __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; - - using AccTypeT = AccType; - using IndexT = Index; - // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - - __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - - __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } -}; - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h b/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h deleted file mode 100644 index f1a7c728e9..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/persistent_gemm.h +++ /dev/null @@ -1,512 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Problem visitor for grouped GEMMs -This file contains heavily customized version of GemmGrouped from CUTLASS 2.10.0 -(https://github.com/NVIDIA/cutlass/blob/v2.10.0/include/cutlass/gemm/kernel/gemm_grouped.h) - -Changes: -- adds support for only single problem size to be launched persistently - where each threablock processes more than one tile of the same problem. -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template -struct FusedDistanceNNPersistent { - public: - using Mma = Mma_; - using Epilogue = Epilogue_; - using EpilogueOutputOp = typename Epilogue::OutputOp; - using ThreadblockSwizzle = ThreadblockSwizzle_; - static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; - static bool const kTransposed = Transposed; - - // Optional transpose - using MapArguments = kernel::detail::MapArguments; - - // Public-facing type definitions related to operand element type, layout, and complex conjugate - // operation. Must interact with the 'kTransposed' notion. - using ElementA = typename MapArguments::ElementA; - using LayoutA = typename MapArguments::LayoutA; - using ElementB = typename MapArguments::ElementB; - using LayoutB = typename MapArguments::LayoutB; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename MapArguments::LayoutC; - - static ComplexTransform const kTransformA = MapArguments::kTransformA; - static ComplexTransform const kTransformB = MapArguments::kTransformB; - - // Type definitions about the mainloop. - using Operator = typename Mma::Operator; - using OperatorClass = typename Mma::Operator::OperatorClass; - using ThreadblockShape = typename Mma::Shape; - using WarpShape = typename Mma::Operator::Shape; - using InstructionShape = typename Mma::Policy::Operator::InstructionShape; - using ArchTag = typename Mma::ArchTag; - - static int const kStages = Mma::kStages; - static int const kAlignmentA = MapArguments::kAlignmentA; - static int const kAlignmentB = MapArguments::kAlignmentB; - static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; - - /// Warp count (concept: GemmShape) - using WarpCount = typename Mma::WarpCount; - static int const kThreadCount = 32 * WarpCount::kCount; - - using ProblemVisitor = GemmGroupedProblemVisitor; - - // - // Structures - // - - struct temp_problem_visitor { - int problem_count; - - CUTLASS_HOST_DEVICE temp_problem_visitor() : problem_count(0){}; - CUTLASS_HOST_DEVICE temp_problem_visitor(int problem_count_) : problem_count(problem_count_){}; - }; - - /// Argument structure - struct Arguments { - // - // Data members - // - GemmCoord problem_sizes; - temp_problem_visitor problem_visitor; - int problem_count; - int threadblock_count; - - typename EpilogueOutputOp::Params output_op; - - void const* ptr_A; - void const* ptr_B; - void const* ptr_C; - void* ptr_Vector; - void* ptr_Tensor; - - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; - typename LayoutC::Stride::Index ldt; - - // Only used by device-level operator - GemmCoord* host_problem_sizes; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() - : threadblock_count(0), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_Vector(nullptr), - ptr_Tensor(nullptr), - lda(0), - ldb(0), - ldc(0), - ldt(0), - host_problem_sizes(nullptr) - { - } - - /// Ctor - CUTLASS_HOST_DEVICE - Arguments(GemmCoord problem_sizes, - int problem_count, - int threadblock_count, - typename EpilogueOutputOp::Params output_op, - void const* ptr_A, - void const* ptr_B, - void const* ptr_C, - void* ptr_Vector, - void* ptr_Tensor, - typename LayoutA::Stride::Index lda, - typename LayoutB::Stride::Index ldb, - typename LayoutC::Stride::Index ldc, - typename LayoutC::Stride::Index ldt, - GemmCoord* host_problem_sizes = nullptr) - : problem_sizes(problem_sizes), - threadblock_count(threadblock_count), - output_op(output_op), - ptr_A(ptr_A), - ptr_B(ptr_B), - ptr_C(ptr_C), - ptr_Vector(ptr_Vector), - ptr_Tensor(ptr_Tensor), - lda(lda), - ldb(ldb), - ldc(ldc), - ldt(ldt), - host_problem_sizes(host_problem_sizes) - { - problem_visitor.problem_count = problem_count; - } - }; - - // - // Structure for precomputing values in host memory and passing to kernels - // - - /// Parameters structure - struct Params { - temp_problem_visitor problem_visitor; - int threadblock_count; - - typename Mma::IteratorA::Params params_A; - typename Mma::IteratorB::Params params_B; - typename Epilogue::OutputTileIterator::Params params_C; - typename Epilogue::TensorTileIterator::Params params_Tensor; - - typename EpilogueOutputOp::Params output_op; - - void* ptr_A; - void* ptr_B; - void* ptr_C; - void* ptr_Vector; - void* ptr_Tensor; - - GemmCoord problem_size; - typename LayoutA::Stride::Index lda; - typename LayoutB::Stride::Index ldb; - typename LayoutC::Stride::Index ldc; - typename LayoutC::Stride::Index ldt; - - // - // Methods - // - - CUTLASS_HOST_DEVICE - Params() - : params_A(0), - params_B(0), - params_C(0), - ptr_A(nullptr), - ptr_B(nullptr), - ptr_C(nullptr), - ptr_Vector(nullptr), - ptr_Tensor(nullptr), - lda(0), - ldb(0), - ldc(0), - ldt(0) - { - } - - CUTLASS_HOST_DEVICE - Params(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - : problem_size(args.problem_sizes), - threadblock_count(args.threadblock_count), - output_op(args.output_op), - params_A(args.lda), - params_B(args.ldb), - params_C(args.ldc), - // Here we pass additional user args via args.output_op - // to the reduction output tile iterator - params_Tensor(args.ldt, args.output_op), - ptr_A(const_cast(args.ptr_A)), - ptr_B(const_cast(args.ptr_B)), - ptr_C(const_cast(args.ptr_C)), - ptr_Vector(args.ptr_Vector), - ptr_Tensor(args.ptr_Tensor), - lda(args.lda), - ldb(args.ldb), - ldc(args.ldc), - ldt(args.ldt) - { - problem_visitor.problem_count = args.problem_visitor.problem_count; - } - - CUTLASS_HOST_DEVICE - void update(Arguments const& args, void* workspace = nullptr, int tile_count = 0) - { - threadblock_count = args.threadblock_count; - output_op = args.output_op; - ptr_A = const_cast(args.ptr_A); - ptr_B = const_cast(args.ptr_B); - ptr_C = const_cast(args.ptr_C); - ptr_Vector = args.ptr_Vector; - ptr_Tensor = args.ptr_Tensor; - lda = args.lda; - ldb = args.ldb; - ldc = args.ldc; - ldt = args.ldt; - - problem_size = args.problem_sizes; - } - }; - - /// Shared memory storage structure - struct SharedStorage { - union { - typename Mma::SharedStorage main_loop; - typename Epilogue::SharedStorage epilogue; - } kernel; - - typename Epilogue::TensorTileIterator::SharedStorage reduced_store; - typename Epilogue::OutputTileIterator::SharedStorage rownorm_store; - }; - - public: - // - // Methods - // - - CUTLASS_DEVICE - FusedDistanceNNPersistent() {} - - /// Determines whether kernel satisfies alignment - static Status can_implement(cutlass::gemm::GemmCoord const& problem_size) - { - return Status::kSuccess; - } - - static Status can_implement(Arguments const& args) { return Status::kSuccess; } - - static size_t get_extra_workspace_size(Arguments const& args, - cutlass::gemm::GemmCoord const& grid_tiled_shape) - { - return 0; - } - - CUTLASS_DEVICE - static uint32_t tile_count(const cutlass::MatrixCoord& grid) - { - return grid.row() * grid.column(); - } - - /// Get the grid shape - CUTLASS_DEVICE - static cutlass::MatrixCoord grid_shape(const cutlass::gemm::GemmCoord& problem) - { - return cutlass::MatrixCoord(((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), - ((problem.n() - 1 + ThreadblockShape::kN) / ThreadblockShape::kN)); - } - - /// Executes one GEMM - CUTLASS_DEVICE - void operator()(Params const& params, SharedStorage& shared_storage) - { -#if __CUDA_ARCH__ >= 800 - // - // These types shadow the type-level definitions and support the ability to implement - // a 'transposed' GEMM that computes the transposed problems. - // - using ElementA = typename Mma::IteratorA::Element; - using LayoutA = typename Mma::IteratorA::Layout; - using ElementB = typename Mma::IteratorB::Element; - using LayoutB = typename Mma::IteratorB::Layout; - using ElementC = typename Epilogue::OutputTileIterator::Element; - using LayoutC = typename Epilogue::OutputTileIterator::Layout; - - const GemmCoord& problem_size = params.problem_size; - const auto grid_shape_ = grid_shape(problem_size); - const uint32_t problem_chunk = (tile_count(grid_shape_) - 1 + gridDim.x) / gridDim.x; - const uint32_t problem_chunk_end = blockIdx.x * problem_chunk + problem_chunk; - typename LayoutB::Index column = - ((blockIdx.x * problem_chunk) % grid_shape_.column()) * Mma::Shape::kN; - - typename LayoutB::Index row = - ((blockIdx.x * problem_chunk) / grid_shape_.column()) * Mma::Shape::kM; - if (column) { - shared_storage.reduced_store.initSmem(params.output_op); - shared_storage.rownorm_store.initSmem(params.ptr_C, problem_size.m(), row, sizeof(ElementC)); - } - - // Outer 'persistent' loop to iterate over tiles - for (uint32_t tile_idx = blockIdx.x * problem_chunk; tile_idx < problem_chunk_end; tile_idx++) { - const auto grid_shape_ = grid_shape(problem_size); - cutlass::MatrixCoord threadblock_offset( - int(tile_idx / grid_shape_.column()) * Mma::Shape::kM, - int(tile_idx % grid_shape_.column()) * Mma::Shape::kN); - - const bool isNextTile = ((tile_idx + 1) < problem_chunk_end); - const bool doesRowChange = - ((threadblock_offset.column() + Mma::Shape::kN) >= problem_size.n()); - const bool do_gmem_reduce = (doesRowChange || !isNextTile) ? true : false; - - ElementA* ptr_A = static_cast(params.ptr_A); - ElementB* ptr_B = static_cast(params.ptr_B); - - // Compute initial location in logical coordinates - cutlass::MatrixCoord tb_offset_A{threadblock_offset.row(), 0}; - cutlass::MatrixCoord tb_offset_B{0, threadblock_offset.column()}; - - // Compute position within threadblock - int thread_idx = threadIdx.x; - - // Construct iterators to A and B operands - typename Mma::IteratorA iterator_A( - params.params_A, ptr_A, {problem_size.m(), problem_size.k()}, thread_idx, tb_offset_A); - - typename Mma::IteratorB iterator_B( - params.params_B, ptr_B, {problem_size.k(), problem_size.n()}, thread_idx, tb_offset_B); - - // Broadcast the warp_id computed by lane 0 to ensure dependent code - // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); - - int lane_idx = threadIdx.x % 32; - - // - // Matrix multiply phase - // - - // Construct thread-scoped matrix multiply - Mma mma(shared_storage.kernel.main_loop, thread_idx, warp_idx, lane_idx); - - typename Mma::FragmentC accumulators; - - accumulators.clear(); - // Compute threadblock-scoped matrix multiply-add - int gemm_k_iterations = (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; - - // Wait for all threads to finish their epilogue phases from the previous tile. - //__syncthreads(); - - // Compute threadblock-scoped matrix multiply-add - mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); - - // - // Epilogue - // - - EpilogueOutputOp output_op(params.output_op); - - ElementC* ptr_C = static_cast(params.ptr_C); - typename Epilogue::ElementTensor* ptr_Tensor = - static_cast(params.ptr_Tensor); - - // Define the reduction output pointer and move to the appropriate place - typename Epilogue::ElementVector* ptr_Vector = - static_cast(params.ptr_Vector); - - // Tile iterator loading from source tensor. - typename Epilogue::OutputTileIterator iterator_rownorm(shared_storage.rownorm_store, - params.params_C, - ptr_C, - problem_size.mn(), - thread_idx, - threadblock_offset); - - // Additional tensor to load from - typename Epilogue::TensorTileIterator tensor_iterator(shared_storage.reduced_store, - params.params_Tensor, - // Only the final block outputs Tensor - ptr_Tensor, - problem_size.mn(), - thread_idx, - do_gmem_reduce, - threadblock_offset); - - Epilogue epilogue(shared_storage.kernel.epilogue, thread_idx, warp_idx, lane_idx); - - // Execute the epilogue operator to update the destination tensor. - // Move to appropriate location for this output tile - if (ptr_Vector) { ptr_Vector += threadblock_offset.column(); } - - // Execute the epilogue operator to update the destination tensor. - epilogue(output_op, - ptr_Vector, - // iterator_D, - accumulators, - iterator_rownorm, - tensor_iterator, - problem_size.mn(), - threadblock_offset); - } -#endif - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h deleted file mode 100644 index 794cd5eb63..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_normvec_smem.h +++ /dev/null @@ -1,448 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- Only the row index is used to load the data in load_with_byte_offset(). - This way the same normalization data is used across all columns in a row. - -*/ -#pragma once - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorNormVecSmem { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile * ThreadMap::Delta::kRow; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - /// Shared storage allocation needed by the predicated tile - // iterator for storing rowNorm chunk. - struct SharedStorage { - // - // Type definitions - // - using Shape = MatrixShape; - - /// Shape of the shared memory allocation - using StorageShape = MatrixShape; - - // - // Data members - // - // Methods - // - AlignedBuffer storage; - - CUTLASS_DEVICE - Element* data() { return storage.data(); } - - SharedStorage() {} - - CUTLASS_DEVICE - void initSmem(void* pointer, - const Index& num_rows, - const Index& tb_row_offset, - const LongIndex& stride) - { - Element* shared_elem_arr = data(); - uint8_t* first_tile_byte_pointer_ = - reinterpret_cast(pointer) + LongIndex(tb_row_offset) * LongIndex(stride); - const auto gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - bool guard = (tb_row_offset + row) < num_rows; - cutlass::arch::cp_async(shared_elem_arr + row, gmem_ptr + row, guard); - cutlass::arch::cp_async_wait<0>(); - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. - PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t* byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - - /// Scatter indices - int const* indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - - private: - // - // Methods - // - - protected: - SharedStorage& shared_storage_; - - public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorNormVecSmem(SharedStorage& shared_storage, - PredicatedTileIteratorParams const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - TensorCoord& threadblock_offset, - int const* indices = nullptr) - : params_(params), indices_(indices), shared_storage_(shared_storage) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = - ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { - mask_.clear(); - return; - } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - if (threadblock_offset.column() == 0) { - shared_storage_.initSmem(pointer, extent_row_, threadblock_offset.row(), params_.stride); - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - AccessType* frag_ptr = reinterpret_cast(&frag); - - Element* shared_elem_arr = shared_storage_.data(); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - int iter_row = ((row_offset + thread_start_row_) % total_rows); - Element val = shared_elem_arr[iter_row]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - (*frag_ptr)[frag_row_idx + i] = val; - } - } - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorNormVecSmem& operator++() - { - ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h b/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h deleted file mode 100644 index d61018593f..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/predicated_tile_iterator_reduced_vec.h +++ /dev/null @@ -1,608 +0,0 @@ -/*************************************************************************************************** - * Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * SPDX-License-Identifier: BSD-3-Clause - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * 3. Neither the name of the copyright holder nor the names of its - * contributors may be used to endorse or promote products derived from - * this software without specific prior written permission. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - **************************************************************************************************/ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- PredicatedTileIteratorParams() is customized to not stride by layout.stride(0). -- makes use of `SharedStorage` to store reduced values across warps to gmem in coalesced manner. -- customized the store_with_byte_offset() to perform reduction per row and write final value to -gmem. -- customized the Params() struct to take user inputs from epilogueOp params. - -*/ - -#pragma once - -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace cg = cooperative_groups; - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorReducedVec { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - using EpilogueOpParams = EpilogueOpParams_; - using OutIdxT = typename EpilogueOpParams::CGReduceT::IndexT; - using OutValT = typename EpilogueOpParams::CGReduceT::AccTypeT; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - static_assert(!UseCUDAStore, "UseCUDAStore path is not supported"); - - static int const total_rows = ThreadMap::kWarpCount * ThreadMap::Iterations::kRow * - ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * - ThreadMap::Count::kTile * ThreadMap::Delta::kRow; - /// Fragment object - using Fragment = - Array; - - // Memory access size - using AccessType = AlignedArray; - using AccessTypeValT = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - EpilogueOpParams user_param; - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Layout const& layout, EpilogueOpParams const& user_param_) - : PredicatedTileIteratorParams(int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()), - user_param(user_param_) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - // static int const kCount = ThreadMap::Iterations::kColumn; - static int const kCount = ThreadMap::Iterations::kColumn * kElementsPerAccess; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - /// Shared storage allocation needed by the predicated tile - // iterator for reduction. - struct SharedStorage { - // - // Type definitions - // - using Shape = MatrixShape; - - /// Shape of the shared memory allocation for the reduced values store - using StorageShape = MatrixShape; - - // - // Data members - - // - // Methods - // - AlignedBuffer storage; - - CUTLASS_DEVICE - Element* data() { return storage.data(); } - - SharedStorage() {} - - CUTLASS_DEVICE - void initSmem(EpilogueOpParams const& user_params) - { - Element* shared_elem_arr = data(); - constexpr auto maxVal = std::numeric_limits::max(); - - for (int row = threadIdx.x; row < total_rows; row += blockDim.x) { - user_params.red_op_.init(&shared_elem_arr[row], maxVal); - } - } - }; - - template - struct select_reduce { - /// Performs warp level reduction and stores a reduced output to memory - CUTLASS_DEVICE - select_reduce(OutT value, - ValT prev_red_val, - cg_reduce_op_t reduce_op, - cg_group_t cg_warp_group, - OutT& shmem_ptr) - { - if (cg_warp_group.any(reduce_op.isAmin(value, prev_red_val))) { - OutT reduced_val = cg::reduce(cg_warp_group, value, reduce_op); - if (cg_warp_group.thread_rank() == 0) { shmem_ptr = reduced_val; } - } - } - }; - - template - struct select_reduce> { - using ValT = float; - using Ty = raft::KeyValuePair; - /// Performs warp level reduction of key value pair and stores a reduced output to memory - CUTLASS_DEVICE - select_reduce(Ty val_to_red, - float prev_red_val, - cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, - Ty& shmem_ptr) - { - ValT val = val_to_red.value; - - if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { - ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } - } - } - } - }; - - template - struct select_reduce> { - using ValT = double; - using Ty = raft::KeyValuePair; - /// Performs warp level reduction of key value pair and stores a reduced output to memory - CUTLASS_DEVICE - select_reduce(Ty val_to_red, - double prev_red_val, - cg_reduce_op_t cg_reduce_op, - cg_group_t cg_warp_group, - Ty& shmem_ptr) - { - ValT val = val_to_red.value; - - if (cg_warp_group.any(cg_reduce_op.isAmin(val, prev_red_val))) { - ValT reduced_val = cg::reduce(cg_warp_group, val, cg_reduce_op); - bool pred = (reduced_val == val); - auto subTile = cg::binary_partition(cg_warp_group, pred); - if (pred) { - if (subTile.thread_rank() == 0) { shmem_ptr = val_to_red; } - } - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. - Params params_; - - /// Byte-level pointer first tile offset of this threadblock. - volatile uint8_t* first_tile_byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - Index block_start_row_first_tile_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - // mutable int shared_tile_id; - - /// Scatter indices - int const* indices_; - - const int do_gmem_reduction_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(Params::stride) == 8, "Expected 64b strides"); - - protected: - SharedStorage& shared_storage_; - - private: - // - // Methods - // - public: - // - // Methods - // - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorReducedVec(SharedStorage& shared_storage, - Params const& params, - volatile Element* pointer, - TensorCoord extent, - int thread_idx, - const bool do_gmem_reduction, - TensorCoord threadblock_offset = TensorCoord(), - int const* indices = nullptr) - : params_(params), - indices_(indices), - shared_storage_(shared_storage), - do_gmem_reduction_(do_gmem_reduction) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - TensorCoord block_offset = ThreadMap::initial_offset(0) + threadblock_offset; - block_start_row_first_tile_ = block_offset.row(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn * kElementsPerAccess; ++c) { - int columnPerAccess = (c / kElementsPerAccess); - int columnWithinPerAccess = c % kElementsPerAccess; - mask_.predicates[c] = ((thread_offset.column() + ThreadMap::Delta::kColumn * columnPerAccess + - columnWithinPerAccess) < extent.column()); - } - - if (threadblock_offset.column() == 0) { - EpilogueOpParams const& user_params = params_.user_param; - shared_storage_.initSmem(user_params); - } - __syncthreads(); - - // Null pointer performs no accesses - if (!pointer) { mask_.clear(); } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - first_tile_byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(block_offset.row()) * LongIndex(params_.stride); - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - CUTLASS_DEVICE void dumpToGmem() - { - if (block_start_row_first_tile_ >= extent_row_) { return; } - - if (do_gmem_reduction_) { - EpilogueOpParams const& user_params = params_.user_param; - const uint32_t mutex_id = (block_start_row_first_tile_ / total_rows); - const bool useGmemMutex = (gridDim.x != ((extent_row_ - 1 + total_rows) / total_rows)); - int row = threadIdx.x; - Element* shared_elem_arr = shared_storage_.data(); - Element row_local_min; - if (row < total_rows) { row_local_min = shared_elem_arr[row]; } - - // single lock per block for multiple rows - if (useGmemMutex && threadIdx.x == 0) { user_params.bin_mutex_[mutex_id].acquire(); } - __syncthreads(); - - if (row < total_rows) { - volatile Element* gmem_ptr = reinterpret_cast(first_tile_byte_pointer_); - - if ((block_start_row_first_tile_ + row) < extent_row_) { - user_params.red_op_(block_start_row_first_tile_ + row, (gmem_ptr + row), row_local_min); - } - } - - __syncthreads(); - __threadfence(); - - if (useGmemMutex && (threadIdx.x == 0)) { - // release mutex lock. - user_params.bin_mutex_[mutex_id].release(); - } - shared_storage_.initSmem(user_params); - __syncthreads(); - } - } - - /// Destructor - CUTLASS_DEVICE - ~PredicatedTileIteratorReducedVec() {} - - /// Performs reduction and Stores a reduced output to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - AccessTypeValT* frag_ptr = reinterpret_cast(&frag); - - cg::thread_block cta = cg::this_thread_block(); - // tile_width 16 is required if kElementPerAccess > 1 - constexpr int tile_width = (32 / ThreadMap::Delta::kColumn) ? 32 : 16; - cg::thread_block_tile tile32 = cg::tiled_partition(cta); - EpilogueOpParams const& user_params = params_.user_param; - - using cg_reduce_t = decltype(user_params.cg_reduce_op); - using tile32_t = decltype(tile32); - - Element* shared_elem_arr = shared_storage_.data(); - constexpr auto maxVal = std::numeric_limits::max(); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - const OutIdxT row_id = row_offset + thread_start_row_; - bool row_guard = (row_id < extent_row_); - - const int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn * kElementsPerAccess; - Element red_val; - user_params.red_op_.init(&red_val, maxVal); - - if (row_guard) { - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn * kElementsPerAccess; - ++column) { - int columnPerAccess = column / kElementsPerAccess; - int columnWithPerAccess = column % kElementsPerAccess; - bool guard = mask_.predicates[column]; - if (guard) { - const OutIdxT key_id = thread_start_column_ + - ThreadMap::Delta::kColumn * columnPerAccess + - columnWithPerAccess; - const int frag_col_idx = frag_idx + column; - - Element this_val; - user_params.red_op_.init(&this_val, (*frag_ptr)[frag_col_idx]); - user_params.red_op_.init_key(this_val, key_id); - user_params.red_op_(row_id, &red_val, this_val); - } - } - } - const int iter_row = (row_id % total_rows); - const auto prev_red_val = user_params.red_op_.get_value(shared_elem_arr[iter_row]); - if (row_guard) { - // select_reduce doesn't need to use `red_op_` as at the warp level we use cg_reduce_op, - // this satisfies the requirement of mst/single linkage of checking colors buffer. - select_reduce red_obj( - red_val, prev_red_val, user_params.cg_reduce_op, tile32, shared_elem_arr[iter_row]); - } - } - } - } - __syncthreads(); - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment& frag) const { store_with_byte_offset(frag, 0); } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorReducedVec& operator++() - { - ++state_[0]; - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { state_[2] = 0; } - } - } - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh b/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh deleted file mode 100644 index 7417fd5dac..0000000000 --- a/cpp/include/raft/distance/detail/fused_distance_nn/simt_kernel.cuh +++ /dev/null @@ -1,187 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // ops::l2_exp_distance_op -#include // PairwiseDistances -#include // Policy - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { -namespace detail { - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } -} - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedDistanceNNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) -{ -// compile only if below non-ampere arch. -#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; - - typedef KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh deleted file mode 100644 index 5e62045305..0000000000 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ /dev/null @@ -1,386 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::identity_op -#include // ops::l2_exp_distance_op -#include -#include // PairwiseDistances -#include // Policy -#include // raft::util::arch::SM_* -#include // raft::ceildiv, raft::shfl - -#include // size_t -#include // std::numeric_limits - -namespace raft { -namespace distance { - -namespace detail { - -template -struct KVPMinReduceImpl { - typedef raft::KeyValuePair KVP; - DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } - -}; // KVPMinReduce - -template -struct MinAndDistanceReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, KVP* out, const KVP& other) const - { - if (other.value < out->value) { - out->key = other.key; - out->value = other.value; - } - } - - DI void operator()(LabelT rid, DataT* out, const KVP& other) const - { - if (other.value < *out) { *out = other.value; } - } - - DI void operator()(LabelT rid, DataT* out, const DataT& other) const - { - if (other < *out) { *out = other; } - } - - DI void init(DataT* out, DataT maxVal) const { *out = maxVal; } - DI void init(KVP* out, DataT maxVal) const { out->value = maxVal; } - - DI void init_key(DataT& out, LabelT idx) const { return; } - DI void init_key(KVP& out, LabelT idx) const { out.key = idx; } - - DI DataT get_value(KVP& out) const - { - return out.value; - ; - } - DI DataT get_value(DataT& out) const { return out; } -}; - -template -struct MinReduceOpImpl { - typedef typename raft::KeyValuePair KVP; - DI void operator()(LabelT rid, DataT* out, const KVP& other) - { - if (other.value < *out) { *out = other.value; } - } - - DI void init(DataT* out, DataT maxVal) { *out = maxVal; } -}; - -template -RAFT_KERNEL initKernel(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - auto tid = IdxT(blockIdx.x) * blockDim.x + threadIdx.x; - if (tid < m) { redOp.init(min + tid, maxVal); } -} - -template -void initialize(OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp, cudaStream_t stream) -{ - auto blks = raft::ceildiv(m, 256); - initKernel<<>>(min, m, maxVal, redOp); -} - -// TODO: specialize this function for MinAndDistanceReduceOp -// with atomicCAS of 64 bit which will eliminate mutex and shfls -template -DI void updateReducedVal( - int* mutex, OutT* min, KVPair* val, ReduceOpT red_op, IdxT m, IdxT gridStrideY) -{ - const auto lid = threadIdx.x % raft::WarpSize; - const auto accrowid = threadIdx.x / P::AccThCols; - - // Update each output row in order within a warp. This will resolve hang - // issues with pre-Volta architectures -#pragma unroll - for (int j = 0; j < (raft::WarpSize / P::AccThCols); j++) { - if (lid == j * P::AccThCols) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rid = gridStrideY + accrowid + i * P::AccThRows; - if (rid < m) { - auto value = val[i]; - while (atomicCAS(mutex + rid, 0, 1) == 1) - ; - __threadfence(); - red_op(rid, min + rid, value); - __threadfence(); - atomicCAS(mutex + rid, 1, 0); - } - } - } - } -} - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL fusedL2NNkernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - OpT distance_op, - FinalLambda fin_op) -{ -// compile only if below non-ampere arch. -#if __CUDA_ARCH__ < 800 - extern __shared__ char smem[]; - - typedef KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [n, pairRedOp, &val, maxVal] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + gridStrideX; - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT gridStrideY) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - // Actually, the srcLane (lid +j) should be (lid +j) % P:AccThCols, - // but the shfl op applies the modulo internally. - auto tmpkey = raft::shfl(val[i].key, lid + j, P::AccThCols); - auto tmpvalue = raft::shfl(val[i].value, lid + j, P::AccThCols); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + gridStrideY, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, gridStrideY); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {0, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - constexpr bool row_major = true; - constexpr bool write_out = false; - PairwiseDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - nullptr, // Output pointer - smem, - distance_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -#endif -} - -// cg::reduce functor for FusedDistanceNN used in its cutlass version -// to output the min distance value & key(loc id). -// This is used in fused_distance_nn/predicated_tile_iterator_reduced_vec.h -// store_with_byte_offset() passed to cg::reduce() & select_reduce. -template -struct kvp_cg_min_reduce_op { - typedef typename raft::KeyValuePair KVP; - - __host__ __device__ kvp_cg_min_reduce_op() noexcept {}; - - using AccTypeT = AccType; - using IndexT = Index; - // functor signature. - __host__ __device__ KVP operator()(KVP a, KVP b) const { return a.value < b.value ? a : b; } - - __host__ __device__ AccType operator()(AccType a, AccType b) const { return min(a, b); } - - __host__ __device__ bool isAmin(AccType a, AccType b) const { return a < b ? true : false; } -}; - -template -void fusedL2NNImpl(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - int* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // The kernel policy is determined by fusedL2NN. - typedef Policy P; - - dim3 blk(P::Nthreads); - auto nblks = raft::ceildiv(m, P::Nthreads); - constexpr auto maxVal = std::numeric_limits::max(); - typedef KeyValuePair KVPair; - - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); - if (initOutBuffer) { - initKernel - <<>>(min, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - namespace arch = raft::util::arch; - using AccT = DataT; - ops::l2_exp_distance_op distance_op{sqrt}; - - raft::identity_op fin_op{}; - - auto kernel = fusedL2NNkernel; - - // Get pointer to fp32 SIMT kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - using L2Op = raft::distance::detail::ops::l2_exp_cutlass_op; - using kvp_cg_min_reduce_op_ = kvp_cg_min_reduce_op; - kvp_cg_min_reduce_op_ cg_reduce_op; - L2Op L2_dist_op(sqrt); - - IdxT lda, ldb, ldd; - lda = k, ldb = k, ldd = n; - - cutlassFusedDistanceNN(x, - y, - xn, - yn, - m, - n, - k, - lda, - ldb, - ldd, - min, - workspace, - cg_reduce_op, - L2_dist_op, - redOp, - pairRedOp, - stream); - } else { - // If device less than SM_80, use fp32 SIMT kernel. - constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); - - kernel<<>>( - min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); - RAFT_CUDA_TRY(cudaGetLastError()); - } -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh b/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh deleted file mode 100644 index 4ce489bc05..0000000000 --- a/cpp/include/raft/distance/detail/kernels/gram_matrix.cuh +++ /dev/null @@ -1,488 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -// #include -#include -#include -#include -#include - -namespace raft::distance::kernels::detail { - -template -using dense_input_matrix_view_t = raft::device_matrix_view; -template -using dense_output_matrix_view_t = raft::device_matrix_view; -template -using csr_input_matrix_view_t = raft::device_csr_matrix_view; - -/** - * Base class for general Gram matrices - * A Gram matrix is the Hermitian matrix of inner probucts G_ik = - * Here, the inner product is evaluated for all elements from vectors sets X1, - * and X2. - * - * To be more precise, on exit the output buffer will store: - * - if is_row_major == true: out[j+k*n1] = , - * - if is_row_major == false: out[j*n2 + k] = , - * where x1_j is the j-th vector from the x1 set and x2_k is the k-th vector - * from the x2 set. - */ -template -class GramMatrixBase { - protected: - cublasHandle_t cublas_handle; - bool legacy_interface; - - public: - GramMatrixBase() : legacy_interface(false){}; - [[deprecated]] GramMatrixBase(cublasHandle_t cublas_handle) - : cublas_handle(cublas_handle), legacy_interface(true){}; - - virtual ~GramMatrixBase(){}; - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void operator()(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void operator()(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * Vector sets are provided in Matrix format - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void operator()(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1 = nullptr, - math_t* norm_x2 = nullptr) - { - evaluate(handle, x1, x2, out, norm_x1, norm_x2); - } - - // unfortunately, 'evaluate' cannot be templatized as it needs to be virtual - - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - virtual void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - linear(handle, x1, x2, out); - } - - /** Evaluate the Gram matrix for two vector sets using simple dot product. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] virtual void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - linear(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - - /** Convenience function to evaluate the Gram matrix for two vector sets. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out - */ - [[deprecated]] void operator()(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1 = 0, - int ld2 = 0, - int ld_out = 0) - { - ASSERT(legacy_interface, "Legacy interface can only be used with legacy ctor."); - if (ld1 <= 0) { ld1 = is_row_major ? n_cols : n1; } - if (ld2 <= 0) { ld2 = is_row_major ? n_cols : n2; } - if (ld_out <= 0) { ld_out = is_row_major ? n2 : n1; } - evaluate(x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - } - - protected: - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 - * @param ld2 leading dimension of x2 - * @param ld_out leading dimension of out - */ - [[deprecated]] void linear(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - math_t alpha = 1.0; - math_t beta = 0.0; - if (is_row_major) { - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, - CUBLAS_OP_T, - CUBLAS_OP_N, - n2, - n1, - n_cols, - &alpha, - x2, - ld2, - x1, - ld1, - &beta, - out, - ld_out, - stream)); - } else { - // #TODO: Call from public API when ready - RAFT_CUBLAS_TRY(raft::linalg::detail::cublasgemm(cublas_handle, - CUBLAS_OP_N, - CUBLAS_OP_T, - n1, - n2, - n_cols, - &alpha, - x1, - ld1, - x2, - ld2, - &beta, - out, - ld_out, - stream)); - } - } - - protected: - bool get_is_row_major(dense_output_matrix_view_t matrix) - { - return (matrix.stride(1) == 1); - } - - bool get_is_row_major(dense_input_matrix_view_t matrix) - { - return (matrix.stride(1) == 1); - } - - bool get_is_col_major(dense_output_matrix_view_t matrix) - { - return (matrix.stride(0) == 1); - } - - bool get_is_col_major(dense_input_matrix_view_t matrix) - { - return (matrix.stride(0) == 1); - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check is_row_major consistency - bool is_row_major = get_is_row_major(x1) && get_is_row_major(x2) && get_is_row_major(out); - bool is_col_major = get_is_col_major(x1) && get_is_col_major(x2) && get_is_col_major(out); - ASSERT(is_row_major || is_col_major, - "GramMatrix leading dimensions for x1, x2 and out do not match"); - - // check dimensions - int n1 = out.extent(0); - int n2 = out.extent(1); - int n_cols = x1.extent(1); - ASSERT(x1.extent(0) == n1, "GramMatrix input matrix dimensions for x1 and out do not match"); - ASSERT(x2.extent(0) == n2, "GramMatrix input matrix dimensions for x2 and out do not match"); - ASSERT(x2.extent(1) == n_cols, "GramMatrix input matrix dimensions for x1 and x2 do not match"); - - // extract major stride - int ld1 = is_row_major ? x1.stride(0) : x1.stride(1); - int ld2 = is_row_major ? x2.stride(0) : x2.stride(1); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - - math_t alpha = 1.0; - math_t beta = 0.0; - if (is_row_major) { - // #TODO: Use mdspan-based API when stride-capable - // https://github.com/rapidsai/raft/issues/875 - raft::linalg::gemm(handle, - true, - false, - n2, - n1, - n_cols, - &alpha, - x2.data_handle(), - ld2, - x1.data_handle(), - ld1, - &beta, - out.data_handle(), - ld_out, - resource::get_cuda_stream(handle)); - } else { - // #TODO: Use mdspan-based API when stride-capable - // https://github.com/rapidsai/raft/issues/875 - raft::linalg::gemm(handle, - false, - true, - n1, - n2, - n_cols, - &alpha, - x1.data_handle(), - ld1, - x2.data_handle(), - ld2, - &beta, - out.data_handle(), - ld_out, - resource::get_cuda_stream(handle)); - } - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check is_row_major consistency - bool is_row_major = get_is_row_major(x2) && get_is_row_major(out); - bool is_col_major = get_is_col_major(x2) && get_is_col_major(out); - ASSERT(is_row_major || is_col_major, - "GramMatrix leading dimensions for x2 and out do not match"); - - // check dimensions - auto x1_structure = x1.structure_view(); - ASSERT(x1_structure.get_n_rows() == out.extent(0), - "GramMatrix input matrix dimensions for x1 and out do not match"); - ASSERT(x2.extent(0) == out.extent(1), - "GramMatrix input matrix dimensions for x2 and out do not match"); - ASSERT(x2.extent(1) == x1_structure.get_n_cols(), - "GramMatrix input matrix dimensions for x1 and x2 do not match"); - - math_t alpha = 1.0; - math_t beta = 0.0; - - raft::sparse::linalg::spmm(handle, false, true, &alpha, x1, x2, &beta, out); - } - - /** Calculates the Gram matrix using simple dot product between vector sets. - * - * out = x1 * x2 - * - * Can be used as a building block for more complex kernel functions. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - */ - void linear(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out) - { - // check layout consistency (w.r.t. strides a matrix might be both row & col major) - bool is_row_major_nopad = get_is_row_major(out) && out.stride(0) == out.extent(1); - bool is_col_major_nopad = get_is_col_major(out) && out.stride(1) == out.extent(0); - - ASSERT(is_row_major_nopad || is_col_major_nopad, - "Sparse linear Kernel distance does not support ld_out parameter"); - - // switch a,b based on is_row_major - if (is_col_major_nopad) { - auto out_row_major = raft::make_device_matrix_view( - out.data_handle(), out.extent(1), out.extent(0)); - raft::sparse::distance::pairwise_distance( - handle, x2, x1, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); - } else { - auto out_row_major = raft::make_device_matrix_view( - out.data_handle(), out.extent(0), out.extent(1)); - raft::sparse::distance::pairwise_distance( - handle, x1, x2, out_row_major, raft::distance::DistanceType::InnerProduct, 0.0); - } - } -}; - -}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh b/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh deleted file mode 100644 index a64f47cea4..0000000000 --- a/cpp/include/raft/distance/detail/kernels/kernel_factory.cuh +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "gram_matrix.cuh" -#include "kernel_matrices.cuh" - -#include -#include - -namespace raft::distance::kernels::detail { - -template -class KernelFactory { - public: - static GramMatrixBase* create(KernelParams params) - { - GramMatrixBase* res; - // KernelParams is not templated, we convert the parameters to math_t here: - math_t coef0 = params.coef0; - math_t gamma = params.gamma; - switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(); break; - case POLYNOMIAL: res = new PolynomialKernel(params.degree, gamma, coef0); break; - case TANH: res = new TanhKernel(gamma, coef0); break; - case RBF: res = new RBFKernel(gamma); break; - default: throw raft::exception("Kernel not implemented"); - } - return res; - } - - [[deprecated]] static GramMatrixBase* create(KernelParams params, cublasHandle_t handle) - { - GramMatrixBase* res; - // KernelParams is not templated, we convert the parameters to math_t here: - math_t coef0 = params.coef0; - math_t gamma = params.gamma; - switch (params.kernel) { - case LINEAR: res = new GramMatrixBase(handle); break; - case POLYNOMIAL: - res = new PolynomialKernel(params.degree, gamma, coef0, handle); - break; - case TANH: res = new TanhKernel(gamma, coef0, handle); break; - case RBF: res = new RBFKernel(gamma, handle); break; - default: throw raft::exception("Kernel not implemented"); - } - return res; - } -}; - -}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh deleted file mode 100644 index 0d64479d84..0000000000 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ /dev/null @@ -1,777 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "gram_matrix.cuh" - -#include -#include -#include -#include -#include -#include - -namespace raft::distance::kernels::detail { - -/** Epiloge function for polynomial kernel without padding. - * Calculates output = (gain*in + offset)^exponent - * @param inout device vector in column major format, size [len] - * @param len array length - * @param exponent - * @param gain - * @param offset - */ -template -RAFT_KERNEL polynomial_kernel_nopad( - math_t* inout, size_t len, exp_t exponent, math_t gain, math_t offset) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; - tid += blockDim.x * gridDim.x) { - inout[tid] = pow(gain * inout[tid] + offset, exponent); - } -} - -/** Epiloge function for polynomial kernel with padding. - * Calculates output = (gain*input + offset)^exponent - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param exponent - * @param gain - * @param offset - */ -template -RAFT_KERNEL polynomial_kernel( - math_t* inout, int ld, int rows, int cols, exp_t exponent, math_t gain, math_t offset) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = pow(gain * inout[tidx + tidy * ld] + offset, exponent); - } -} - -/** Epiloge function for tanh kernel without padding. - * Calculates output = tanh(gain*input + offset) - * @param inout device vector, size [len] - * @param len length of the input vector - * @param gain - * @param offset - */ -template -RAFT_KERNEL tanh_kernel_nopad(math_t* inout, size_t len, math_t gain, math_t offset) -{ - for (size_t tid = threadIdx.x + blockIdx.x * blockDim.x; tid < len; - tid += blockDim.x * gridDim.x) { - inout[tid] = tanh(gain * inout[tid] + offset); - } -} - -/** Epiloge function for tanh kernel without padding. - * Calculates output = tanh(gain*input + offset) - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param gain - * @param offset - */ -template -RAFT_KERNEL tanh_kernel(math_t* inout, int ld, int rows, int cols, math_t gain, math_t offset) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = tanh(gain * inout[tidx + tidy * ld] + offset); - } -} - -/** Epiloge function for rbf kernel using expansion. - * - * Calculates output_ij = exp(-gain * (norm_x_i + norm_y_j - 2*input_ij)); - * - * Intended usage - * - input is the product of two matrices X and Y input_ij = sum_k X_ik * Y_jk - * - norm_x_i = l2_norm(x_i), where x_i is the i-th row of matrix X - * - norm_y_j = l2_norm(y_j), where y_j is the j-th row of matrix Y - * - * @param inout device vector in column major format, size [ld * cols] - * @param ld leading dimension of the inout buffer - * @param rows number of rows (rows <= ld) - * @param cols number of columns - * @param norm_x l2-norm of X's rows - * @param norm_y l2-norm of Y's rows - * @param gain - */ -template -RAFT_KERNEL rbf_kernel_expanded( - math_t* inout, int ld, int rows, int cols, math_t* norm_x, math_t* norm_y, math_t gain) -{ - for (size_t tidy = threadIdx.y + blockIdx.y * blockDim.y; tidy < cols; - tidy += blockDim.y * gridDim.y) { - math_t norm_y_val = norm_y[tidy]; - for (size_t tidx = threadIdx.x + blockIdx.x * blockDim.x; tidx < rows; - tidx += blockDim.x * gridDim.x) { - inout[tidx + tidy * ld] = - exp(-1.0 * gain * (norm_x[tidx] + norm_y_val - inout[tidx + tidy * ld] * 2)); - } - } -} - -namespace { -std::tuple generateLaunchConfig2dElementwiseOp(int n1, int n2) -{ - dim3 block_shape = dim3(32, 4); - const int num_blocks_x = raft::ceildiv(n1, 32); - const int num_blocks_y = std::min(raft::ceildiv(n2, 32), (1 << 16) - 1); - dim3 grid_shape = dim3(num_blocks_x, num_blocks_y); - return std::make_tuple(grid_shape, block_shape); -} -} // namespace - -/** - * Create a kernel matrix using polynomial kernel function. - */ -template -class PolynomialKernel : public GramMatrixBase { - exp_t exponent; - math_t gain; - math_t offset; - - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) - { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - polynomial_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, exponent, gain, offset); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - polynomial_kernel<<>>( - inout, ld, n1, n2, exponent, gain, offset); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - public: - /** - * Constructs a polynomial kernel object. - * It evaluates the kernel matrix using the following formula: - * K_ij = (gain* + offset)^exponent - * - * @tparam math_t floating point type - * @tparam exp_t type of exponent - * @param exponent - * @param gain - * @param offset - */ - PolynomialKernel(exp_t exponent, math_t gain, math_t offset) - : GramMatrixBase(), exponent(exponent), gain(gain), offset(offset) - { - } - - [[deprecated]] PolynomialKernel(exp_t exponent, math_t gain, math_t offset, cublasHandle_t handle) - : GramMatrixBase(handle), exponent(exponent), gain(gain), offset(offset) - { - } - - /** Evaluate kernel matrix using polynomial kernel. - * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using polynomial kernel. - * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using polynomial kernel. - * - * output[i,k] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - GramMatrixBase::linear( - x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - applyKernel(out, ld_out, n1, n2, is_row_major, stream); - } -}; - -/** - * Create a kernel matrix using tanh kernel function. - */ -template -class TanhKernel : public GramMatrixBase { - math_t gain, offset; - - void applyKernel( - math_t* inout, int ld, int rows, int cols, bool is_row_major, cudaStream_t stream) - { - const int n_minor = is_row_major ? cols : rows; - if (ld == n_minor) { - tanh_kernel_nopad<<((size_t)rows * cols, 128), 128, 0, stream>>>( - inout, rows * cols, gain, offset); - } else { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - tanh_kernel<<>>(inout, ld, n1, n2, gain, offset); - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - public: - /** - * Constructs a tanh kernel object. - * It evaluates the kernel matrix using the following formula: - * K_ij = tanh(gain* + offset) - * - * @tparam math_t floating point type - * @param gain - * @param offset - */ - TanhKernel(math_t gain, math_t offset) : GramMatrixBase(), gain(gain), offset(offset) {} - - [[deprecated]] TanhKernel(math_t gain, math_t offset, cublasHandle_t handle) - : GramMatrixBase(handle), gain(gain), offset(offset) - { - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using tanh kernel. - * - * output_[i + k*n1] = (gain* + offset)^exponent, - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and < , > denotes dot product. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 unused. - * @param norm_x2 unused. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - GramMatrixBase::linear( - x1, n1, n_cols, x2, n2, out, is_row_major, stream, ld1, ld2, ld_out); - applyKernel(out, ld_out, n1, n2, is_row_major, stream); - } -}; - -/** - * Create a kernel matrix using RBF kernel function. - */ -template -class RBFKernel : public GramMatrixBase { - math_t gain; - - void applyKernel(math_t* inout, - int ld, - int rows, - int cols, - math_t* norm_x1, - math_t* norm_x2, - bool is_row_major, - cudaStream_t stream) - { - int n1 = is_row_major ? cols : rows; - int n2 = is_row_major ? rows : cols; - math_t* norm_n1 = is_row_major ? norm_x2 : norm_x1; - math_t* norm_n2 = is_row_major ? norm_x1 : norm_x2; - auto [grid_shape, block_shape] = generateLaunchConfig2dElementwiseOp(n1, n2); - rbf_kernel_expanded<<>>( - inout, ld, n1, n2, norm_n1, norm_n2, gain); - } - - public: - /** - * Constructs a RBF kernel object. - * It evaluates the kernel matrix using the following formula: - * K_ij = exp(-gain*|x1_i- x2_k|^2) - * - * @tparam math_t floating point type - * @param gain - */ - RBFKernel(math_t gain) : GramMatrixBase(), gain(gain) {} - - [[deprecated]] RBFKernel(math_t gain, cublasHandle_t handle) - : GramMatrixBase(handle), gain(gain) - { - } - - void matrixRowNormL2(raft::resources const& handle, - dense_input_matrix_view_t matrix, - math_t* target) - { - bool is_row_major = GramMatrixBase::get_is_row_major(matrix); - int minor = is_row_major ? matrix.extent(1) : matrix.extent(0); - int ld = is_row_major ? matrix.stride(0) : matrix.stride(1); - ASSERT(ld == minor, "RBF Kernel lazy rowNorm compute does not support ld parameter"); - raft::linalg::rowNorm(target, - matrix.data_handle(), - matrix.extent(1), - matrix.extent(0), - raft::linalg::NormType::L2Norm, - is_row_major, - resource::get_cuda_stream(handle)); - } - - void matrixRowNormL2(raft::resources const& handle, - csr_input_matrix_view_t matrix, - math_t* target) - { - auto matrix_structure = matrix.structure_view(); - raft::sparse::linalg::rowNormCsr(handle, - matrix_structure.get_indptr().data(), - matrix.get_elements().data(), - matrix_structure.get_nnz(), - matrix_structure.get_n_rows(), - target, - raft::linalg::NormType::L2Norm); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 dense device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void evaluate(raft::resources const& handle, - dense_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.extent(0), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.extent(0), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 dense device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - dense_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.extent(0), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate kernel matrix using RBF kernel. - * - * output_[i + k*n1] = exp(-gain*|x1_i - x2_k|^2), - * where x1_i is the i-th vector from the x1 set, and x2_k is k-th vector - * in the x2 set, and | | euclidean distance. - * - * @param [in] handle raft handle - * @param [in] x1 csr device matrix view, size [n1*n_cols] - * @param [in] x2 csr device matrix view, size [n2*n_cols] - * @param [out] out dense device matrix view for the Gram matrix, size [n1*n2] - * @param norm_x1 optional L2-norm of x1's rows for computation within RBF. - * @param norm_x2 optional L2-norm of x2's rows for computation within RBF. - */ - void evaluate(raft::resources const& handle, - csr_input_matrix_view_t x1, - csr_input_matrix_view_t x2, - dense_output_matrix_view_t out, - math_t* norm_x1, - math_t* norm_x2) - { - cudaStream_t stream = resource::get_cuda_stream(handle); - - // lazy compute norms if not given - rmm::device_uvector tmp_norm_x1(0, stream); - rmm::device_uvector tmp_norm_x2(0, stream); - if (norm_x1 == nullptr) { - tmp_norm_x1.reserve(x1.structure_view().get_n_rows(), stream); - norm_x1 = tmp_norm_x1.data(); - matrixRowNormL2(handle, x1, norm_x1); - } - if (norm_x2 == nullptr) { - tmp_norm_x2.reserve(x2.structure_view().get_n_rows(), stream); - norm_x2 = tmp_norm_x2.data(); - matrixRowNormL2(handle, x2, norm_x2); - } - - // compute L2expanded - bool is_row_major = GramMatrixBase::get_is_row_major(out); - int ld_out = is_row_major ? out.stride(0) : out.stride(1); - GramMatrixBase::linear(handle, x1, x2, out); - applyKernel(out.data_handle(), - ld_out, - out.extent(0), - out.extent(1), - norm_x1, - norm_x2, - is_row_major, - resource::get_cuda_stream(handle)); - } - - /** Evaluate the Gram matrix using the legacy interface. - * - * @param [in] x1 device array of vectors, size [n1*n_cols] - * @param [in] n1 number vectors in x1 - * @param [in] n_cols number of columns (features) in x1 and x2 - * @param [in] x2 device array of vectors, size [n2*n_cols] - * @param [in] n2 number vectors in x2 - * @param [out] out device buffer to store the Gram matrix, size [n1*n2] - * @param [in] is_row_major whether the input and output matrices are in row - * major format - * @param [in] stream cuda stream - * @param ld1 leading dimension of x1 (usually it is n1) - * @param ld2 leading dimension of x2 (usually it is n2) - * @param ld_out leading dimension of out (usually it is n1) - */ - [[deprecated]] void evaluate(const math_t* x1, - int n1, - int n_cols, - const math_t* x2, - int n2, - math_t* out, - bool is_row_major, - cudaStream_t stream, - int ld1, - int ld2, - int ld_out) - { - ASSERT(GramMatrixBase::legacy_interface, - "Legacy interface can only be used with legacy ctor."); - int minor1 = is_row_major ? n_cols : n1; - int minor2 = is_row_major ? n_cols : n2; - int minor_out = is_row_major ? n2 : n1; - ASSERT(ld1 == minor1, "RBF Kernel distance does not support ld1 parameter"); - ASSERT(ld2 == minor2, "RBF Kernel distance does not support ld2 parameter"); - ASSERT(ld_out == minor_out, "RBF Kernel distance does not support ld_out parameter"); - - math_t gain = this->gain; - using index_t = int64_t; - - rbf_fin_op fin_op{gain}; - - raft::resources handle; - resource::set_cuda_stream(handle, stream); - - raft::distance::distance(handle, - const_cast(x1), - const_cast(x2), - out, - n1, - n2, - n_cols, - NULL, - 0, - fin_op, - is_row_major); - } -}; - -}; // end namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh b/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh deleted file mode 100644 index cd19675477..0000000000 --- a/cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -/* - * This file defines rbf_fin_op, which is used in GramMatrixBase. - * - * This struct has been moved to a separate file, so that it is cheap to include - * in distance/distance-ext.cuh, where an instance of raft::distance::distance - * with the rbf_fin_op is instantiated. - * - */ - -#include // raft::exp -#include // HD - -namespace raft::distance::kernels::detail { - -/** @brief: Final op for Gram matrix with RBF kernel. - * - * Calculates output = e^(-gain * in) - * - */ -template -struct rbf_fin_op { - OutT gain; - - explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} - - template - HDI OutT operator()(OutT d_val, Args... unused_args) - { - return raft::exp(-gain * d_val); - } -}; // struct rbf_fin_op - -} // namespace raft::distance::kernels::detail diff --git a/cpp/include/raft/distance/detail/masked_distance_base.cuh b/cpp/include/raft/distance/detail/masked_distance_base.cuh deleted file mode 100644 index 96b778f11f..0000000000 --- a/cpp/include/raft/distance/detail/masked_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include -#include -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief Device class for masked nearest neighbor computations. - * - * @tparam useNorms whether norms are needed - * @tparam DataT input data-type (for x and y matrices) - * @tparam AccT accumulation data-type - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) - * @tparam EpilogueLambda applies an elementwise function to compute final - values. Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @tparam rowEpilogueLambda epilog lambda that executes when a full row has - * been processed. - * - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of x - * @param[in] n number of columns of y - * @param[in] k number of cols of x and y - * @param[in] lda leading dimension of x - * @param[in] ldb leading dimension of y - * @param[in] ldd parameter to keep Contractions_NT happy.. - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `(m / 64) x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups The number of groups in group_idxs. - * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. - * @param core_op the core accumulation operation lambda - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed. - */ -template > -struct MaskedDistances : public BaseClass { - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - const uint64_t* adj; - const IdxT* group_idxs; - IdxT num_groups; - char* smem; - CoreLambda core_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI MaskedDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - const uint64_t* _adj, - const IdxT* _group_idxs, - IdxT _num_groups, - char* _smem, - CoreLambda _core_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - yBase(_y), - adj(_adj), - group_idxs(_group_idxs), - num_groups(_num_groups), - smem(_smem), - core_op(_core_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op) - { - } - - DI void run() - { - const auto grid_stride_m = (P::Mblk * gridDim.y); - const auto grid_offset_m = (P::Mblk * blockIdx.y); - - const auto grid_stride_g = gridDim.x; - const auto grid_offset_g = blockIdx.x; - - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - // Start loop over groups - for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { - const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); - // block_adj is a bitfield that contains a 1 if a row is adjacent to the - // current group. All zero means we can skip this group. - if (block_adj == 0) { continue; } - - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). That is, - // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: - // - // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. - // - // We precompute this information because it is used in various - // locations to skip thread-local computations, specifically: - // - // 1. To skip computations if thread_adj == 0, i.e., none of the values - // of `acc` have to be computed. - // - // 2. In epilog_op, to consider only values of `acc` to be reduced that - // are not masked of. - // - // Note 1: Even when the computation can be skipped for a specific thread, - // the thread still participates in synchronization operations. - // - // Note 2: In theory, it should be possible to skip computations for - // specific rows of `acc`. In practice, however, this does not improve - // performance. - int thread_adj = compute_thread_adjacency(block_adj); - - auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; - const auto group_end_n = group_idxs[idx_g]; - for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { - // We provide group_end_n to limit the number of unnecessary data - // points that are loaded from y. - this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); - - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - if (thread_adj != 0) { accumulate(); } - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - if (thread_adj != 0) { - accumulate(); // last iteration - } - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer - // back so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - if (useNorms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); - if (thread_adj != 0) { - epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); - } - } else { - if (thread_adj != 0) { - epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); - } - } - } // tile_idx_n - } // idx_g - rowEpilog_op(tile_idx_m); - } // tile_idx_m - } - - private: - DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) - { - // A single element of `adj` contains exactly enough bits to indicate which - // rows in the current tile to skip and which to compute. - static_assert(P::Mblk == 8 * sizeof(adj[0]), - "masked_l2_nn only supports a policy with 64 rows per block."); - IdxT block_flag_idx = tile_idx_m / P::Mblk; - // Index into adj at row tile_idx_m / 64 and column idx_group. - return adj[block_flag_idx * this->num_groups + idx_group]; - } - - DI uint32_t compute_thread_adjacency(const uint64_t block_adj) - { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the run() method. - uint32_t thread_adj = 0; -#pragma unroll - for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { - // Index `thread_row_idx` refers to a row of the current threads' register - // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the - // corresponding row of the current block tile in shared memory. - const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; - - // block_row_is_adjacent is true if the current block_row_idx is adjacent - // to the current group. - const uint64_t block_mask = 1ull << block_row_idx; - const bool block_row_is_adjacent = (block_adj & block_mask) != 0; - if (block_row_is_adjacent) { - // If block row is adjacent, write a 1 bit to thread_adj at location - // `thread_row_idx`. - const uint32_t thread_mask = 1 << thread_row_idx; - thread_adj |= thread_mask; - } - } - return thread_adj; - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero(); - } - } - } - - DI void accumulate() - { -#pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } - } - } - } - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - IdxT end_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < end_n ? yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } -}; // struct MaskedDistances - -}; // namespace detail -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh deleted file mode 100644 index 951e030cbd..0000000000 --- a/cpp/include/raft/distance/detail/masked_nn.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -template -__launch_bounds__(P::Nthreads, 2) RAFT_KERNEL masked_l2_nn_kernel(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const uint64_t* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - bool sqrt, - DataT maxVal, - int* mutex, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - CoreLambda core_op, - FinalLambda fin_op) -{ - extern __shared__ char smem[]; - - typedef raft::KeyValuePair KVPair; - KVPair val[P::AccRowsPerTh]; -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - - // epilogue operation lambda for final value calculation - auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( - DataT acc[P::AccRowsPerTh][P::AccColsPerTh], - int thread_adj, - DataT* regxn, - DataT* regyn, - IdxT tile_idx_n, - IdxT tile_idx_m, - IdxT tile_end_n) { - KVPReduceOpT pairRed_op(pairRedOp); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = raft::sqrt(acc[i][j]); - } - } - } - - // intra thread reduce - const auto acccolid = threadIdx.x % P::AccThCols; - const auto accrowid = threadIdx.x / P::AccThCols; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - // thread_adj is a bitfield that contains a 1 at location i iff we must - // compute row i of acc (the accumulator register tile). It is described in - // more detail in the maskedDistances.run() method. - const bool ignore = (thread_adj & (1 << i)) == 0; - if (ignore) { continue; } -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; - if (tile_end_n <= tmpkey) { - // Do not process beyond end of tile. - continue; - } - KVPair tmp = {tmpkey, acc[i][j]}; - if (tmpkey < tile_end_n) { - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - } - }; - - auto rowEpilog_lambda = - [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { - KVPReduceOpT pairRed_op(pairRedOp); - ReduceOpT red_op(redOp); - - const auto accrowid = threadIdx.x / P::AccThCols; - const auto lid = raft::laneId(); - // reduce -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = P::AccThCols / 2; j > 0; j >>= 1) { - auto tmpkey = raft::shfl(val[i].key, lid + j); - auto tmpvalue = raft::shfl(val[i].value, lid + j); - KVPair tmp = {tmpkey, tmpvalue}; - val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); - } - } - - updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); - - // reset the val array. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; - } - }; - - IdxT lda = k, ldb = k, ldd = n; - MaskedDistances - obj(x, - y, - m, - n, - k, - lda, - ldb, - ldd, - xn, - yn, - adj, - group_idxs, - num_groups, - smem, - core_op, - epilog_lambda, - fin_op, - rowEpilog_lambda); - obj.run(); -} - -/** - * @brief Wrapper for masked_l2_nn_kernel - * - * Responsibilities: - * - Allocate (and initialize) workspace memory for: - * - mutexes used in nearest neighbor update step - * - adjacency matrix bitfield - * - Compress adjacency matrix to bitfield - * - Initialize output buffer (conditional on `initOutBuffer`) - * - Specify core and final operations for the L2 norm - * - Determine optimal launch configuration for kernel. - * - Launch kernel and check for errors. - * - * @tparam DataT Input data-type (for x and y matrices). - * @tparam OutT Output data-type (for key-value pairs). - * @tparam IdxT Index data-type. - * @tparam ReduceOpT A struct to perform the final needed reduction - * operation and also to initialize the output array - * elements with the appropriate initial value needed for - * reduction. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * @param handle RAFT handle for managing expensive resources - * @param[out] out Will contain reduced output (nn key-value pairs) - * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) - * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) - * @param[in] xn L2 squared norm of `x`. Length = `m`. - * @param[in] yn L2 squared norm of `y`. Length = `n`. - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[in] num_groups Length of `group_idxs`. - * @param m Rows of `x`. - * @param n Rows of `y`. - * @param k Cols of `x` and `y`. - * @param redOp Reduction operator in the epilogue - * @param pairRedOp Reduction operation on key value pairs - * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm. - * @param initOutBuffer Whether to initialize the output buffer - * - * - */ -template -void masked_l2_nn_impl(raft::resources const& handle, - OutT* out, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - const bool* adj, - const IdxT* group_idxs, - IdxT num_groups, - IdxT m, - IdxT n, - IdxT k, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer) -{ - typedef typename linalg::Policy4x4::Policy P; - - static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block."); - - // Get stream and workspace memory resource - auto stream = resource::get_cuda_stream(handle); - auto ws_mr = resource::get_workspace_resource(handle); - - // Acquire temporary buffers and initialize to zero: - // 1) Adjacency matrix bitfield - // 2) Workspace for fused nearest neighbor operation - size_t m_div_64 = raft::ceildiv(m, IdxT(64)); - rmm::device_uvector ws_adj64{m_div_64 * num_groups, stream, ws_mr}; - rmm::device_uvector ws_fused_nn{size_t(m), stream, ws_mr}; - RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); - - // Compress boolean adjacency matrix to bitfield. - auto adj_view = raft::make_device_matrix_view(adj, m, num_groups); - auto adj64_view = - raft::make_device_matrix_view(ws_adj64.data(), m_div_64, num_groups); - compress_to_bits(handle, adj_view, adj64_view); - - // Initialize output buffer with keyvalue pairs as determined by the reduction - // operator (it will be called with maxVal). - constexpr auto maxVal = std::numeric_limits::max(); - if (initOutBuffer) { - dim3 grid(raft::ceildiv(m, P::Nthreads)); - dim3 block(P::Nthreads); - - initKernel<<>>(out, m, maxVal, redOp); - RAFT_CUDA_TRY(cudaGetLastError()); - } - - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - auto fin_op = raft::identity_op{}; - - auto kernel = masked_l2_nn_kernel; - constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - dim3 block(P::Nthreads); - dim3 grid = launchConfigGenerator

(m, n, smemSize, kernel); - - kernel<<>>(out, - x, - y, - xn, - yn, - ws_adj64.data(), - group_idxs, - num_groups, - m, - n, - k, - sqrt, - maxVal, - ws_fused_nn.data(), - redOp, - pairRedOp, - core_lambda, - fin_op); - - RAFT_CUDA_TRY(cudaGetLastError()); -} - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh deleted file mode 100644 index a8a541bf53..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include // raft::linalg::Contractions_NT -#include // ceildiv -#include // RAFT_CUDA_TRY - -#include // size_t - -namespace raft { -namespace distance { -namespace detail { - -/** - * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam DataT input data-type (for A and B matrices) - * @tparam AccT accumulation data-type - * @tparam OutT output data-type (for C and D matrices) - * @tparam IdxT index data-type - * @tparam Policy struct which tunes the Contraction kernel - * @tparam OpT A distance operation, e.g., cosine_distance_op. - * @tparam EpilogueLambda applies an elementwise function to compute final - values. Its signature is: - template void epilogue_lambda - (AccT acc[][], DataT* regxn, DataT* regyn); - * @tparam FinalLambda the final lambda called on final distance value - * @param[in] x input matrix - * @param[in] y input matrix - * @param[in] m number of rows of A and C/D - * @param[in] n number of columns of B and C/D - * @param[in] k number of cols of A and rows of B - * @param[in] lda leading dimension of A - * @param[in] ldb leading dimension of B - * @param[in] ldd leading dimension of C/D - * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine - * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine - * @param[output] pD output matrix - * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param distance_op the distance operation, e.g. cosine_distance_op - * @param epilog_op the epilog operation lambda - * @param fin_op the final gemm epilogue lambda - * @param rowEpilog_op epilog lambda that executes when a full row has been processed - */ - -template > -struct PairwiseDistances : public BaseClass { - // Get accumulation type from distance_op - using AccT = typename OpT::AccT; - - private: - typedef Policy P; - const DataT* xn; - const DataT* yn; - const DataT* const yBase; - OutT* dOutput; - char* smem; - OpT distance_op; - EpilogueLambda epilog_op; - FinalLambda fin_op; - rowEpilogueLambda rowEpilog_op; - - const IdxT grid_stride_m; - const IdxT grid_stride_n; - const IdxT grid_offset_m; - const IdxT grid_offset_n; - - AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; - - public: - // Constructor - DI PairwiseDistances(const DataT* _x, - const DataT* _y, - IdxT _m, - IdxT _n, - IdxT _k, - IdxT _lda, - IdxT _ldb, - IdxT _ldd, - const DataT* _xn, - const DataT* _yn, - OutT* _dOutput, - char* _smem, - OpT _distance_op, - EpilogueLambda _epilog_op, - FinalLambda _fin_op, - rowEpilogueLambda _rowEpilog_op) - : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), - xn(_xn), - yn(_yn), - yBase(_y), - dOutput(_dOutput), - smem(_smem), - distance_op(_distance_op), - epilog_op(_epilog_op), - fin_op(_fin_op), - rowEpilog_op(_rowEpilog_op), - grid_stride_m(P::Mblk * gridDim.y), - grid_stride_n(P::Nblk * gridDim.x), - grid_offset_m(P::Mblk * blockIdx.y), - grid_offset_n(P::Nblk * blockIdx.x) - { - } - - DI void run() - { - for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { - this->ldgXY(tile_idx_m, grid_offset_n, 0); - for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { - // Prolog: - reset_accumulator(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - - // Main loop: - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(tile_idx_m, tile_idx_n, kidx); - // Process all data in shared memory (previous k-block) and - // accumulate in registers. - accumulate(); - this->stsXY(); - __syncthreads(); - this->switch_write_buffer(); - this->switch_read_buffer(); - } - accumulate(); // last iteration - // The pre-condition for the loop over tile_idx_n is that write_buffer - // and read_buffer point to the same buffer. This flips read_buffer back - // so that it satisfies the pre-condition of this loop. - this->switch_read_buffer(); - - // Epilog: - if (distance_op.use_norms) { - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; - load_norms(tile_idx_m, tile_idx_n, regxn, regyn); - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - // Calculate distance_op epilog. - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); - // And any possible additional epilogs - epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(tile_idx_m, tile_idx_n); - // Calculate distance_op epilog. - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - // And any possible additional epilogs - epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); - } - if (writeOut) { store_output(tile_idx_m, tile_idx_n); } - } - rowEpilog_op(tile_idx_m); - } - } - - private: - DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) - { - // Fetch next grid stride ldg if within range - const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; - const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; - if ((next_tile_tile_idx_n) < this->n) { - this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); - } else if ((next_tile_tile_idx_m) < this->m) { - this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); - } - } - - DI void reset_accumulator() - { - // Reset accumulator registers to zero. -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = BaseClass::Zero(); - } - } - } - - DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], - DataT (®_y)[P::AccColsPerTh][P::Veclen]) - { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); - } - } - } - } - - DI void accumulate() - { - // We have a separate ldsXY and accumulate_reg_tile outside the loop body, - // so that these separated calls can be interspersed with preceding and - // following instructions, thereby hiding latency. - this->ldsXY(0); - - // If expensive inner loop, do not unroll loop. - constexpr int num_iterations = P::Kblk / P::Veclen - 1; - constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; -#pragma unroll unroll_count - for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { - accumulate_reg_tile(this->regx, this->regy); - this->ldsXY(ki); - } - - // Accumulate last loaded tile. - accumulate_reg_tile(this->regx, this->regy); - } - - DI void load_norms(IdxT tile_idx_m, - IdxT tile_idx_n, - DataT (®xn)[P::AccRowsPerTh], - DataT (®yn)[P::AccColsPerTh]) - { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (tile_idx_n == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = tile_idx_m + i; - sxNorm[i] = idx < this->m ? xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = tile_idx_n + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; - } - __syncthreads(); - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } -#pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - } - - DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) - { - IdxT starty = tile_idx_m + this->accrowid; - IdxT startx = tile_idx_n + this->acccolid; - -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } - } - } - } -}; // struct PairwiseDistances - -template -dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) -{ - int devId; - RAFT_CUDA_TRY(cudaGetDevice(&devId)); - int numSMs; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); - - int numBlocksPerSm = 0; - dim3 grid; - - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&numBlocksPerSm, func, P::Nthreads, sMemSize)); - std::size_t minGridSize = numSMs * numBlocksPerSm; - std::size_t yChunks = raft::ceildiv(m, P::Mblk); - std::size_t xChunks = raft::ceildiv(n, P::Nblk); - grid.y = yChunks > minGridSize ? minGridSize : yChunks; - grid.x = (minGridSize - grid.y) <= 0 ? 1 : xChunks; - if (grid.x != 1) { - std::size_t i = 1; - while (grid.y * i < minGridSize) { - i++; - } - grid.x = i >= xChunks ? xChunks : i; - } - - return grid; -} - -}; // namespace detail -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh deleted file mode 100644 index 1cc272f74e..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ /dev/null @@ -1,185 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#pragma GCC diagnostic push -#pragma GCC diagnostic ignored "-Wstrict-aliasing" -#pragma GCC diagnostic ignored "-Wtautological-compare" - -// We define CUTLASS_NAMESPACE in case -// RAFT cmake is not used -#ifndef CUTLASS_NAMESPACE -#define cutlass raft_cutlass -#endif - -#include "pairwise_distance_epilogue_elementwise.h" -#include "pairwise_distance_gemm.h" - -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft { -namespace distance { -namespace detail { - -template -std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) -{ - static_assert(!(std::is_same::value), - "OutType bool is not supported use uint8_t instead"); - - auto dist_op = distance_op.get_cutlass_op(); - using DistanceFn = decltype(dist_op); - using EpilogueOutputOp = - cutlass::epilogue::thread::PairwiseDistanceEpilogueElementwise; - constexpr int batch_count = 1; - - constexpr auto mode = cutlass::gemm::GemmUniversalMode::kGemm; - - typename EpilogueOutputOp::Params epilog_op_param(dist_op, fin_op); - - // Number of pipelines you want to use - constexpr int NumStages = 3; - // Alignment - constexpr int Alignment = VecLen; - - using cutlassDistKernel = - typename cutlass::gemm::kernel::PairwiseDistanceGemm::GemmKernel; - - using cutlassDist = cutlass::gemm::device::GemmUniversalAdapter; - - constexpr uint32_t gridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1); - constexpr uint32_t max_batch_size = gridYZMax * cutlassDistKernel::ThreadblockShape::kN; - IdxT numNbatches = (n - 1 + max_batch_size) / max_batch_size; - - for (IdxT i = 0; i < numNbatches; i++) { - const DataT *a, *b; - IdxT gemm_lda, gemm_ldb; - size_t offsetN = i * max_batch_size; - - if constexpr (isRowMajor) { - gemm_lda = ldb; - gemm_ldb = lda; - a = y + offsetN * gemm_lda; - b = x; - } else { - gemm_lda = lda; - gemm_ldb = ldb; - a = x; - b = y + offsetN; - } - IdxT chunkN = (i + 1) * max_batch_size; - IdxT currentN = (chunkN < n) ? max_batch_size : (n - offsetN); - - // default initialize problem size with row major inputs - auto problem_size = isRowMajor ? cutlass::gemm::GemmCoord(currentN, m, k) - : cutlass::gemm::GemmCoord(m, currentN, k); - - typename cutlassDist::Arguments arguments{ - mode, - problem_size, - batch_count, - epilog_op_param, - a, - b, - xn, // C matrix eq vector param, which here is A norm - nullptr, // tensor_Z, - (DataT*)yn + offsetN, // this is broadcast vec, which is required to be non-const param - dOutput + offsetN, // Output distance matrix - (int64_t)0, // batch stride A - (int64_t)0, // batch stride B - (int64_t)0, // batch stride Norm A - (int64_t)0, - (int64_t)0, // batch stride Norm B - (int64_t)0, // batch stride Output - gemm_lda, // stride A - gemm_ldb, // stride B - 1, // stride A norm - 0, // this is no-op for Z - 0, // This must be zero - ldd // stride Output matrix - }; - - // Using the arguments, query for extra workspace required for matrix multiplication computation - size_t workspace_size = cutlassDist::get_workspace_size(arguments); - // Allocate workspace memory - rmm::device_uvector workspace(workspace_size, stream); - // Instantiate CUTLASS kernel depending on templates - cutlassDist cutlassDist_op; - // Check the problem size is supported or not - RAFT_CUTLASS_TRY(cutlassDist_op.can_implement(arguments)); - - // Initialize CUTLASS kernel with arguments and workspace pointer - RAFT_CUTLASS_TRY(cutlassDist_op.initialize(arguments, workspace.data(), stream)); - - // Launch initialized CUTLASS kernel - RAFT_CUTLASS_TRY(cutlassDist_op(stream)); - } -} - -}; // namespace detail -}; // namespace distance -}; // namespace raft - -#pragma GCC diagnostic pop diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h deleted file mode 100644 index 6ead09ed16..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue.h +++ /dev/null @@ -1,100 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This is adapted from DefaultEpilogueWithBroadcastTensorOp from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/master/include/cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h#L75) - -This epilogue allows us to load norm buffers using PredicatedTileIteratorNormVec -and EpilogueWithBroadcast used for distances L2/cosine as well as applies user-define elementwise -operation. --- A norm load is provided PredicatedTileIteratorNormVec --- B norm load is provided by EpilogueWithBroadcast --- elementwise operation is provided by OutputOp -*/ - -#pragma once - -#include "./predicated_tile_iterator_normvec.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Defines sensible defaults for epilogues for TensorOps. -template -struct PairwiseDistanceEpilogue { - /// Use defaults related to the existing epilogue - using Base = - DefaultEpilogueTensorOp; - - // - // Stores the result z = (y = GEMM(A, B, C), broadcast) - // - using OutputTileIterator = cutlass::epilogue::threadblock:: - PredicatedTileIteratorNormVec; - - // - // Additional tensor tile iterator - stores t = Elementwise(z) - // - using TensorTileIterator = - cutlass::epilogue::threadblock::PredicatedTileIterator; - - /// Define the epilogue - using Epilogue = EpilogueWithBroadcast; -}; - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h b/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h deleted file mode 100644 index 2b2c04b9d3..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_distance_epilogue_elementwise.h +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// -/*! \file - \brief Functor performing distance operations used by epilogues of pairwise distance - * kernels. -* This is adapted from LinearCombinationBiasElementwise from CUTLASS 2.9.0 -* customized for applying elementwise distance formula on accumulated GEMM value -* and applying user-defined final custom operation on the distance value. -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace epilogue { -namespace thread { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// This base class is meant to define the concept required of the -/// EpilogueWithBroadcast::OutputOp -template -class PairwiseDistanceEpilogueElementwise { - public: - using ElementOutput = ElementC_; - using ElementC = ElementC_; - using ElementAccumulator = ElementAccumulator_; - using ElementCompute = ElementCompute_; - using ElementZ = ElementZ_; - using ElementT = ElementT_; - static int const kElementsPerAccess = ElementsPerAccess; - static int const kCount = kElementsPerAccess; - - using DistanceOp = DistanceOp_; - using FinalOp = FinalOp_; - - using FragmentAccumulator = Array; - using FragmentCompute = Array; - using FragmentC = Array; - using FragmentZ = Array; - using FragmentT = Array; - - using FragmentOutput = FragmentZ; - - static bool const kIsHeavy = false; // ElementwiseOp::kIsHeavy; - - /// If true, the 'Z' tensor is stored - static bool const kStoreZ = false; // We don't store anything in Z, - - /// If true, the 'T' tensor is stored - static bool const kStoreT = true; // this is our final output storage. - - /// Host-constructable parameters structure - struct Params { - FinalOp_ final_op_; - DistanceOp_ dist_op_; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(DistanceOp_ dist_op, FinalOp final_op) : final_op_(final_op), dist_op_(dist_op) {} - - CUTLASS_HOST_DEVICE - Params() {} - }; - - private: - // - // Data members - // - FinalOp_ final_op; - DistanceOp_ elementwise_op; - - public: - // - // Methods - // - - /// Constructor from Params - CUTLASS_HOST_DEVICE - PairwiseDistanceEpilogueElementwise(Params const& params) - : final_op(params.final_op_), elementwise_op(params.dist_op_) - { - } - - /// Returns true if source is needed - CUTLASS_HOST_DEVICE - bool is_source_needed() const - { - // we use for making sure C matrix path is used for A mat norm. - return true; - } - - /// Functionally required for serial reduction in the epilogue - CUTLASS_HOST_DEVICE - void set_k_partition(int k_partition, int k_partition_count) {} - - /// Applies the operation when is_source_needed() is true - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentC const& frag_C, - FragmentCompute const& V) const - { - FragmentCompute tmp_Accum = - NumericArrayConverter()(AB); - FragmentCompute tmp_C = - NumericArrayConverter()(frag_C); - FragmentCompute result_Z; - FragmentCompute result_T; - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElementsPerAccess; ++i) { - result_Z[i] = elementwise_op(tmp_C[i], V[i], tmp_Accum[i]); - result_T[i] = final_op(result_Z[i], 0); - } - - NumericArrayConverter convert_t; - frag_T = convert_t(result_T); - } - - /// Applies the operation when is_source_needed() is false - CUTLASS_HOST_DEVICE - void operator()(FragmentZ& frag_Z, - FragmentT& frag_T, - FragmentAccumulator const& AB, - FragmentCompute const& V) const - { - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace thread -} // namespace epilogue -} // namespace cutlass - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h b/cpp/include/raft/distance/detail/pairwise_distance_gemm.h deleted file mode 100644 index aaf2689dab..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_distance_gemm.h +++ /dev/null @@ -1,238 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "./pairwise_distance_epilogue.h" - -#include -#include -#include -#include -#include - -///////////////////////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { -namespace gemm { -namespace kernel { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - /// Element type for A matrix operand - typename ElementA_, - /// Layout type for A matrix operand - int kAlignmentA, - /// Element type for B matrix operand - typename ElementB_, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Element type for final output - // typename ElementOutT, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct PairwiseDistanceGemm { - // This struct is specialized for fp32/3xTF32 - - /// Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = - cutlass::gemm::GemmShape<128, 128, 16>; // <- threadblock tile M = 128, N = 128, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = 64, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = - cutlass::gemm::GemmShape<16, 8, 4>; // <- MMA Op tile M = 16, N = 8, K = 4 - - /// Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAddFastF32; - - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementAccumulator, - typename EpilogueOutputOp::ElementT, - ElementAccumulator, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; -}; - -template < - /// Layout type for A matrix operand - int kAlignmentA, - /// Layout type for B matrix operand - int kAlignmentB, - /// Element type for C and D matrix operands - typename ElementC_, - /// Element type for internal accumulation - typename ElementAccumulator, - /// Epilogue output operator - must satisfy concept of 'EpilogueWithBroadcastOp' - typename EpilogueOutputOp, - /// Number of stages used in the pipelined mainloop - int Stages, - /// data layout row/column major of inputs - bool isRowMajor> -struct PairwiseDistanceGemm { - // using Transform = cutlass::ComplexTransform::kNone; - // Threadblock-level tile size (concept: GemmShape) - using ThreadblockShape = - cutlass::gemm::GemmShape<64, 64, 16>; // <- threadblock tile M = 64, N = 64, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes tile size a warp will compute - using WarpShape = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = 32, N = 32, K = 16 - /// Warp-level tile size (concept: GemmShape) - // This code section describes the size of MMA op - using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; - - // Operation performed by GEMM - using Operator = cutlass::arch::OpMultiplyAdd; - // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU - // SM - using OperatorClass = cutlass::arch::OpClassTensorOp; - - // This code section describes CUDA SM architecture number - using ArchTag = cutlass::arch::Sm80; - - // This code section describes how threadblocks are scheduled on GPU - /// Threadblock-level swizzling operator - using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; - - /// data layout for final output matrix. - // we keep this same layout even for column major inputs - using LayoutOutput = cutlass::layout::RowMajor; - - typedef typename std::conditional::type NormXLayout; - - typedef typename std:: - conditional::type LayoutA_; - - typedef typename std:: - conditional::type LayoutB_; - - using GemmBase = typename DefaultGemmUniversal::GemmKernel; - - // Replace epilogue - using Epilogue = typename cutlass::epilogue::threadblock::PairwiseDistanceEpilogue< - typename GemmBase::Epilogue::Shape, - typename GemmBase::Epilogue::WarpMmaOperator, - GemmBase::Epilogue::kPartitionsK, - ElementC_, - typename EpilogueOutputOp::ElementT, - ElementC_, - EpilogueOutputOp, - NormXLayout, - GemmBase::Epilogue::kElementsPerAccess>::Epilogue; - - // Compose the GEMM kernel - using GemmKernel = GemmWithFusedEpilogue; -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace kernel -} // namespace gemm -} // namespace cutlass \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh deleted file mode 100644 index bced721ec8..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh +++ /dev/null @@ -1,198 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // raft::identity_op -#include // ops::* -#include // ops::has_cutlass_op -#include // rbf_fin_op -#include // pairwise_matrix_params -#include // RAFT_EXPLICIT - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::distance::detail { - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) RAFT_EXPLICIT; - -}; // namespace raft::distance::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ - OpT, DataT, AccT, OutT, FinOpT, IdxT) \ - extern template void raft::distance::detail:: \ - pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ - OpT distance_op, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - const DataT* x, \ - const DataT* y, \ - const DataT* x_norm, \ - const DataT* y_norm, \ - OutT* out, \ - FinOpT fin_op, \ - cudaStream_t stream, \ - bool is_row_major) - -/* - * Hierarchy of instantiations: - * - * This file defines extern template instantiations of the distance kernels. The - * instantiation of the public API is handled in raft/distance/distance-ext.cuh. - * - * After adding an instance here, make sure to also add the instance there. - */ - -// The following two instances are used in the RBF kernel object. Note the use of int64_t for the -// index type. -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_unexp_distance_op, - float, - float, - float, - raft::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - raft::distance::kernels::detail::rbf_fin_op, - int64_t); - -// Rest of instances -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::canberra_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::correlation_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::correlation_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::dice_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::dice_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::hellinger_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::jensen_shannon_distance_op, - float, - float, - float, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::jensen_shannon_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l2_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::lp_unexp_distance_op, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); -instantiate_raft_distance_detail_pairwise_matrix_dispatch( - raft::distance::detail::ops::russel_rao_distance_op, - double, - double, - double, - raft::identity_op, - int); - -#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh deleted file mode 100644 index fd9d444662..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ /dev/null @@ -1,127 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -/* This file has two responsibilities: - * - * 1. Dispatch to the correct implementation of a kernel based on the - * architecture of the device on which the kernel will be launched. For - * instance, the cosine distance has a CUTLASS-based implementation that can - * be used on SM80+ and the normal implementation that is used on older - * architectures. - * - * 2. Provide concise function templates that can be instantiated in - * src/distance/detail/pairwise_matrix/. Previously, - * raft::distance::detail::distance was instantiated. The function - * necessarily required a large set of include files, which slowed down the - * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions - * do not require as large an include files set, which speeds up the build. - */ - -#include // ops::has_cutlass_op -#include // dispatch_sm60 -#include // pairwise_matrix_params -#include // raft::util::arch::SM_* - -// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. -// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). -// Therefore, it is the including file's responsibility to include the correct -// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh -// and src/distance/detail/pairwise_matrix/dispatch_*.cu. - -namespace raft::distance::detail { - -// This forward-declaration ensures that we do not need to include -// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based -// distances, dispatch_sm80.cuh has to be included by the file including this -// file. -template -void pairwise_matrix_sm80_dispatch(OpT, - pairwise_matrix_params, - SM_compat_t, - cudaStream_t); - -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - - // Dispatch rule: - // - execute CUTLASS-based kernel on SM_80 and above - // - execute normal kernel below SM_80 - namespace arch = raft::util::arch; - - constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); - - if constexpr (cutlass_op_unavailable) { - // Always execute legacy kernels when no cutlass op is available - auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); - pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); - } else { - auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); - auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); - - // Get pointer to SM60 kernel to determine the best compute architecture - // out of all for which the kernel was compiled for that matches closely - // to the current device. Other methods to determine the architecture (that do not - // require a pointer) can be error prone. See: - // https://github.com/NVIDIA/cub/issues/545 - auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); - void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = arch::kernel_virtual_arch(kernel_ptr); - - if (cutlass_range.contains(runtime_arch)) { - // If device is SM_80 or later, use CUTLASS-based kernel. - pairwise_matrix_sm80_dispatch(distance_op, params, cutlass_range, stream); - } else { - // Reuse kernel wrapper that we obtained above. This avoids performing the - // dispatch twice. - sm60_wrapper.launch(distance_op, params, stream); - } - } -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh deleted file mode 100644 index 4a52b7ebe7..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "dispatch-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "dispatch-ext.cuh" -#endif diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh deleted file mode 100644 index 95926e5c9a..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ /dev/null @@ -1,117 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // RAFT_EXPECTS -#include // pairwise_matrix_params - -#include // std::min -#include // size_t -#include // std::integral_constant -namespace raft::distance::detail { - -/** - * @brief: Computes minimal common alignment of the rows in a 2D array in bytes - * - * The 2D matrix `x` is assumed to be row-major. This function computes the - * minimal alignment in bytes of the first elements of each row. - * Output can be 16, 8, 4, 2, 1. - * - * @param x Base pointer of row-major input matrix - * @param stride Stride in number of element between consecutive rows. - */ -template -size_t alignment_of_2d_array(const DataT* x, size_t stride) -{ - auto base = reinterpret_cast(x); - size_t stride_bytes = sizeof(DataT) * stride; - - for (int align = 16; align >= 0; align /= 2) { - bool base_aligned = base % align == 0; - bool stride_aligned = stride_bytes % align == 0; - if (base_aligned && stride_aligned) { return align; } - } - return 1; -} - -/** - * @brief: Computes the vec_len parameter kernel policy parameter - * - * @param params Kernel parameters - */ -template -int determine_vec_len(pairwise_matrix_params params) -{ - size_t align_x = alignment_of_2d_array(params.x, params.ldx); - size_t align_y = alignment_of_2d_array(params.y, params.ldy); - size_t byte_alignment = min(align_x, align_y); - - // Since alignment is in bytes, it could be smaller than sizeof(DataT). - // Handle this (unlikely) case here. - RAFT_EXPECTS(sizeof(DataT) <= byte_alignment, - "Input matrix must be aligned to size of elements."); - - // Compute number of elements that can be loaded in one instruction - // without causing misalignent errors. - int vec_len_aligned = (byte_alignment % sizeof(DataT) == 0) ? byte_alignment / sizeof(DataT) : 1; - - // In the future, pairwise_matrix might support `int8_t` input. In that case, - // byte_alignment / sizeof(DataT) might exceed 4. We maximize at 4 here, to - // prevent adding more cases in dispatch_layout below (which are expensive to - // compile). - vec_len_aligned = std::min(vec_len_aligned, 4); - - return vec_len_aligned; -} - -template -using vec_len_constant = std::integral_constant; - -/** - * @brief: Converts run-time arguments to compile-time arguments - * - * Converts run-time arguments row_major and vec_len to compile-time arguments - * and dispatches a lambda f with these compile-time arguments. - * - * This is equivalent to copying and pasting the lambda function `f` in each of - * the switch case statements. - * - * @tparam F Type of lambda f. - * @param row_major Boolean indicating whether input arrays have row-major layout. - * @param vec_len Integer value 1, 2, or 4 specifying the Veclen template parameter of - * the KernelPolicy. - * @param f Lambda that takes two std::integral_constant parameters representing - * row_major and vec_len. - */ -template -auto dispatch_layout(bool row_major, int vec_len, F&& f) -{ - if (row_major) { - switch (vec_len) { - case 4: return f(std::true_type(), vec_len_constant<4>()); - case 2: return f(std::true_type(), vec_len_constant<2>()); - default: return f(std::true_type(), vec_len_constant<1>()); - } - } else { - switch (vec_len) { - case 4: return f(std::false_type(), vec_len_constant<4>()); - case 2: return f(std::false_type(), vec_len_constant<2>()); - default: return f(std::false_type(), vec_len_constant<1>()); - } - } -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh deleted file mode 100644 index fc8f594275..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ /dev/null @@ -1,85 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // dispatch_layout -#include // pairwise_matrix_sm60_wrapper -#include // raft::linalg::Policy4x4 - -#include // std::min - -namespace raft::distance::detail { - -template -pairwise_matrix_sm60_wrapper pairwise_matrix_sm60_get_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and returns - // the corresponding kernel wrapper. The wrapper contains the launch - // parameters of the kernel: a pointer to the kernel function, grid size, - // block size, and shared memory size. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // To keep compile times in check, we only specialize on veclen > 1 when - // the inner loop is relatively cheap (< 5 flops). - constexpr int vec_len_op = distance_op.expensive_inner_loop ? 1 : vec_len_aligned(); - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - - using RowPolicy = typename raft::linalg::Policy4x4::Policy; - using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; - using Policy = typename std::conditional::type; - - auto wrapper = - make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); - - return wrapper; - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. - return dispatch_layout(params.is_row_major, vec_len, f); -} - -template -void pairwise_matrix_sm60_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - auto wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, sm_compat_range); - - wrapper.launch(distance_op, params, stream); -} - -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh deleted file mode 100644 index e5942653d8..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm80.cuh +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // cutlassDistanceKernel -#include // dispatch_layout - -#include // std::min - -namespace raft::distance::detail { - -template -void pairwise_matrix_sm80_dispatch(OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range, - cudaStream_t stream) -{ - int vec_len = determine_vec_len(params); - - // f takes compile-time constants row_major and vec_len aligned and runs the - // corresponding cutlass launch code. - auto f = [&](auto row_major, auto vec_len_aligned) { - // row_major and vec_len are std::integral_constants of type bool and int - // respectively. - - // Prevent double, vec_len=4 combination (this is not supported) - constexpr int vec_len = std::min(vec_len_aligned(), static_cast(16 / sizeof(DataT))); - - using AccT = typename OpT::AccT; - cutlassDistanceKernel(params.x, - params.y, - params.x_norm, - params.y_norm, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.out, - params.fin_op, - distance_op, - stream); - }; - - // Dispatch_layout calls f with appropriate compile time constants based on - // the runtime values of params.is_row_major and vec_len. - dispatch_layout(params.is_row_major, vec_len, f); -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh deleted file mode 100644 index cd9f497a06..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // raft::void_op -#include // PairwiseDistances -#include // pairwise_matrix_params -#include // raft::util::arch::SM_compute_arch - -#include // assert - -namespace raft::distance::detail { - -template -__launch_bounds__(Policy::Nthreads, 2) RAFT_KERNEL - pairwise_matrix_kernel(OpT distance_op, pairwise_matrix_params params) -{ - // Early exit to minimize the size of the kernel when it is not supposed to be compiled. - constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { - assert(false); - return; - } - - extern __shared__ char smem[]; - - // The epilog is already provided by distance_op. Do not provide additional - // epilogs. - auto epilog_op = raft::void_op(); - // No support for row_epilog_op. - auto row_epilog_op = raft::void_op(); - - // Always write output - constexpr bool write_out = true; - constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances - obj(params.x, - params.y, - params.m, - params.n, - params.k, - params.ldx, - params.ldy, - params.ld_out, - params.x_norm, - params.y_norm, - params.out, - smem, - distance_op, - epilog_op, - params.fin_op, - row_epilog_op); - obj.run(); -} - -// The type of a pointer to the pairwise matrix kernel. The following template -// arguments are type-erased: -// -// - The kernel policy -// - row_major -// - SM_compat_t -template -using pairwise_matrix_kernel_t = void (*)(OpT, pairwise_matrix_params); - -// A wrapper for the pairwise matrix kernel launch. Includes kernel launch -// parameters. -template -struct pairwise_matrix_sm60_wrapper { - dim3 grid; - dim3 block; - int smem_size; - pairwise_matrix_kernel_t kernel_ptr; - - void launch(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) - { - kernel_ptr<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); - } -}; - -/** @brief: Create kernel launch wrapper for pairwise matrix kernel - * - * This can be used to type-erase the kernel execution policy, row_major, and SM - * compatibility range. - * - * @tparam Policy: Kernel execution policy - * @tparam row_major: Indicates whether input matrices are row major - * @tparam OpT: Type of distance operation - * @tparam IdxT: Index type - * @tparam DataT: Data type - * @tparam OutT: Output data type - * @tparam FinOpT: Final operation type - * @tparam SM_compat_t: Type of the SM architecture compatibility - * - * @param distance_op: Distance operation - * @param params: Parameters - * @param sm_compat_range: Which SM architectures to compile for. - */ -template -pairwise_matrix_sm60_wrapper make_pairwise_matrix_sm60_wrapper( - OpT distance_op, - pairwise_matrix_params params, - SM_compat_t sm_compat_range) -{ - dim3 block(Policy::Nthreads); - // Use ::template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = OpT::template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - return pairwise_matrix_sm60_wrapper{ - grid, block, smem_size, kernel}; -} - -}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh deleted file mode 100644 index 005b95afe9..0000000000 --- a/cpp/include/raft/distance/detail/pairwise_matrix/params.cuh +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -namespace raft::distance::detail { - -template -struct pairwise_matrix_params { - IdxT m; - IdxT n; - IdxT k; - IdxT ldx; - IdxT ldy; - IdxT ld_out; - const DataT* x; - const DataT* y; - const DataT* x_norm; - const DataT* y_norm; - OutT* out; - FinOpT fin_op; - bool is_row_major; - - /// @brief: Flips the x and y input and corresponding sizes - void flip_x_and_y() - { - // Flip m, n; ldx, ldy; x, y; x_norm, y_norm. - std::swap(m, n); - std::swap(ldx, ldy); - std::swap(x, y); - std::swap(x_norm, y_norm); - } -}; - -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h b/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h deleted file mode 100644 index 951f8a0132..0000000000 --- a/cpp/include/raft/distance/detail/predicated_tile_iterator_normvec.h +++ /dev/null @@ -1,585 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/*! \file - \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. - -This file contains a customized version of PredicatedTileIterator from CUTLASS 2.9.0 -(https://github.com/NVIDIA/cutlass/blob/v2.9.0/include/cutlass/epilogue/threadblock/predicated_tile_iterator.h#L75) - -Changes: -- added `Layout_` template param -- Only the row index is used to load the data in load_with_byte_offset(). - This way the same normalization data is used across all columns in a row. - -*/ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -//////////////////////////////////////////////////////////////////////////////// - -namespace cutlass { - -//////////////////////////////////////////////////////////////////////////////// - -namespace epilogue { -namespace threadblock { - -//////////////////////////////////////////////////////////////////////////////// - -/// Tile iterator used to load and store output tile from global memory in epilogue. -/// -/// Satisfies: ReadableTileIterator | PredicatedTileIterator | ForwardTileIterator -/// -template -class PredicatedTileIteratorNormVec { - public: - using ThreadMap = ThreadMap_; - using Shape = typename ThreadMap::Shape; - - using Element = Element_; - - using Layout = Layout_; - using TensorRef = TensorRef; - using ConstTensorRef = typename TensorRef::ConstTensorRef; - - using Index = typename Layout::Index; - using LongIndex = typename Layout::LongIndex; - using TensorCoord = MatrixCoord; - - static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; - static int const kThreads = ThreadMap::kThreads; - static int const kIterations = ThreadMap::Count::kTile; - - static_assert(ThreadMap::Iterations::kRow > 0, "ThreadMap::Iterations::kRow must be > 0"); - static_assert(ThreadMap::Iterations::kGroup > 0, "ThreadMap::Iterations::kGroup must be > 0"); - static_assert(ThreadMap::Iterations::kCluster > 0, "ThreadMap::Iterations::kCluster must be > 0"); - static_assert(ThreadMap::Iterations::kColumn > 0, "ThreadMap::Iterations::kColumn must be > 0"); - - /// Fragment object - using Fragment = Array; - - /// Memory access size - using AccessType = AlignedArray; - - // - // Parameters struct - // - - /// Uses a non-template class - struct Params : PredicatedTileIteratorParams { - using Base = PredicatedTileIteratorParams; - - CUTLASS_HOST_DEVICE - Params() {} - - CUTLASS_HOST_DEVICE - Params(Layout const& layout) - : PredicatedTileIteratorParams( - layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, - make_OutputTileThreadMapDesc()) - { - } - - CUTLASS_HOST_DEVICE - Params(Base const& base) : Base(base) {} - }; - - /// Mask object - struct Mask { - static int const kCount = ThreadMap::Iterations::kColumn; - - /// Predicate state - bool predicates[kCount]; - - // - // Mask - // - CUTLASS_HOST_DEVICE - Mask() { enable(); } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_HOST_DEVICE void clear() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = false; - } - } - - ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask - CUTLASS_DEVICE void enable() - { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kCount; ++i) { - predicates[i] = true; - } - } - }; - - private: - // - // Data members - // - - /// Parameters structure containing reference and precomputed state. - PredicatedTileIteratorParams params_; - - /// Byte-level pointer - uint8_t* byte_pointer_; - - /// Array of boolean values to contain steady-state predicates - Mask mask_; - - /// Extent of the matrix tile in rows - Index extent_row_; - - /// Extent of the matrix tile in rows - Index extent_column_; - - /// A thread's starting row position (assuming steady-state predicates have been computed) - Index thread_start_row_; - - /// A thread's starting column - Index thread_start_column_; - - /// Internal state counter - int state_[3]; - - /// Scatter indices - int const* indices_; - - // - // Static asserts about internal strides - // - - static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); - static_assert(sizeof(PredicatedTileIteratorParams::stride) == 8, "Expected 64b strides"); - - private: - // - // Methods - // - - public: - // - // Methods - // - - /// Constructor - CUTLASS_DEVICE - PredicatedTileIteratorNormVec(PredicatedTileIteratorParams const& params, - Element* pointer, - TensorCoord extent, - int thread_idx, - TensorCoord threadblock_offset = TensorCoord(), - int const* indices = nullptr) - : params_(params), indices_(indices) - { - TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx) + threadblock_offset; - - extent_row_ = extent.row(); - extent_column_ = extent.column(); - - thread_start_row_ = thread_offset.row(); - thread_start_column_ = thread_offset.column(); - - // Initialize predicates - CUTLASS_PRAGMA_UNROLL - for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { - mask_.predicates[c] = - ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < extent.column()); - } - - // Null pointer performs no accesses - if (!pointer) { mask_.clear(); } - - if (ScatterD && !indices) { mask_.clear(); } - - // Initialize pointer - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.row()) * LongIndex(params_.stride); - - if (ScatterD) { - byte_pointer_ = reinterpret_cast(pointer) + - LongIndex(thread_offset.column()) * sizeof(AccessType) / kElementsPerAccess; - } - - // Initialize internal state counter - state_[0] = state_[1] = state_[2] = 0; - } - - /// Adds a pointer offset in units of Element - CUTLASS_HOST_DEVICE - void add_pointer_offset(LongIndex pointer_offset) - { - byte_pointer_ += pointer_offset * sizeof_bits::value / 8; - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - if (column == 0) { - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[0], - guard); - } else { - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn]; - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void load(Fragment& frag) const { load_with_byte_offset(frag, 0); } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType const* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - if (ScatterD && row_guard) { - assert(indices_); - - memory_pointer = reinterpret_cast( - byte_pointer + byte_offset + - LongIndex(indices_[row_offset + thread_start_row_]) * LongIndex(params_.stride)); - } - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - if (UseCUDAStore) { - if (guard) { - memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess] = - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column]; - } - } else { - cutlass::arch::global_store( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - } - - if (row + 1 < ThreadMap::Iterations::kRow) { - if (!ScatterD) { byte_pointer += params_.increment_row; } - } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Stores a fragment to memory - CUTLASS_DEVICE - void store(Fragment const& frag) const { store_with_byte_offset(frag, 0); } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void downsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - - int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + - (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - /// Loads a fragment from memory - CUTLASS_DEVICE - void upsample_load_with_byte_offset(Fragment& frag, - int64_t byte_offset, - int convolution_P, - int convolution_Q, - int add_P, - int add_Q, - int problem_N) const - { - uint8_t* byte_pointer = byte_pointer_; - AccessType* frag_ptr = reinterpret_cast(&frag); - - CUTLASS_PRAGMA_UNROLL - for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster) { - CUTLASS_PRAGMA_UNROLL - for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { - CUTLASS_PRAGMA_UNROLL - for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { - int frag_row_idx = - (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster)); - - int row_offset = row * ThreadMap::Delta::kRow + group * ThreadMap::Delta::kGroup + - cluster * ThreadMap::Delta::kCluster; - - bool row_guard = ((row_offset + thread_start_row_) < extent_row_); - - int output_row = row_offset + thread_start_row_; - int output_N = output_row / (convolution_P * convolution_Q); - int output_PQ = output_row % (convolution_P * convolution_Q); - int output_P = output_PQ / convolution_Q; - int output_Q = output_PQ % convolution_Q; - int row_add_P = add_P; - int row_add_Q = add_Q; - if (output_P > convolution_P - 2) row_add_P = 0; - if (output_Q > convolution_Q - 2) row_add_Q = 0; - - int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + - ((output_P + row_add_P) / 2) * (convolution_Q / 2) + - (output_Q + row_add_Q) / 2; - - int64_t byte_offset = (input_row - output_row) * problem_N * sizeof(float); - - AccessType* memory_pointer = reinterpret_cast(byte_pointer + byte_offset); - - CUTLASS_PRAGMA_UNROLL - for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column) { - bool guard = row_guard && mask_.predicates[column]; - - cutlass::arch::global_load( - frag_ptr[frag_row_idx * ThreadMap::Iterations::kColumn + column], - (void*)&memory_pointer[column * ThreadMap::Delta::kColumn / kElementsPerAccess], - guard); - } - - if (row + 1 < ThreadMap::Iterations::kRow) { byte_pointer += params_.increment_row; } - } - - if (group + 1 < ThreadMap::Iterations::kGroup) { byte_pointer += params_.increment_group; } - } - - if (cluster + 1 < ThreadMap::Iterations::kCluster) { - byte_pointer += params_.increment_cluster; - } - } - } - - CUTLASS_DEVICE - MatrixCoord thread_start() const { return MatrixCoord(thread_start_row_, thread_start_column_); } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_row() const { return thread_start_row_; } - - /// Need to get the thread start row from the tile iterator - CUTLASS_DEVICE - int32_t thread_start_column() const { return thread_start_column_; } - - /// Extent of the matrix in rows - CUTLASS_DEVICE - Index extent_row() const { return extent_row_; } - - /// Extent of the matrix in columns - CUTLASS_DEVICE - Index extent_column() const { return extent_column_; } - - /// Advances to the next position to load or store - CUTLASS_HOST_DEVICE - PredicatedTileIteratorNormVec& operator++() - { - ++state_[0]; - - if (!ScatterD) { byte_pointer_ += params_.advance_row; } - - thread_start_row_ += ThreadMap::Shape::kRow; - - if (state_[0] == ThreadMap::Count::kRow) { - state_[0] = 0; - ++state_[1]; - byte_pointer_ += params_.advance_group; - - thread_start_row_ += - (ThreadMap::Shape::kGroup - 1) * ThreadMap::Shape::kRow * ThreadMap::Count::kRow; - - if (state_[1] == ThreadMap::Count::kGroup) { - state_[1] = 0; - ++state_[2]; - byte_pointer_ += params_.advance_cluster; - - thread_start_row_ += ThreadMap::Count::kGroup * ThreadMap::Shape::kGroup * - ThreadMap::Count::kRow * ThreadMap::Shape::kRow; - - if (state_[2] == ThreadMap::Count::kCluster) { - state_[2] = 0; - byte_pointer_ += params_.advance_tile; - } - } - } - - return *this; - } - - ///< Efficiently disables all accesses guarded by mask - CUTLASS_DEVICE void clear_mask() { mask_.clear(); } - - ///< Efficiently enables all accesses guarded by mask - CUTLASS_DEVICE void enable_mask() { mask_.enable(); } - - ///< Sets the mask - CUTLASS_DEVICE void get_mask(Mask& mask) const { mask = mask_; } - - ///< Sets the mask - CUTLASS_DEVICE void set_mask(Mask const& mask) { mask_ = mask; } -}; - -/////////////////////////////////////////////////////////////////////////////// - -} // namespace threadblock -} // namespace epilogue -} // namespace cutlass - -//////////////////////////////////////////////////////////////////////////////// diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh deleted file mode 100644 index dcbfbfdbc3..0000000000 --- a/cpp/include/raft/distance/distance-ext.cuh +++ /dev/null @@ -1,1117 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include // raft::device_matrix_view -#include // raft::identity_op -#include // raft::resources -#include // rbf_fin_op -#include // raft::distance::DistanceType -#include // RAFT_EXPLICIT - -#include // rmm::device_uvector - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft { -namespace distance { - -template -[[deprecated("Use cuVS instead")]] void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] size_t getWorkspaceSize( - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; - -template -size_t getWorkspaceSize(raft::device_matrix_view const& x, - raft::device_matrix_view const& y) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - bool isRowMajor = true, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void distance( - raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - DataT metric_arg = 2.0f) RAFT_EXPLICIT; - -template -[[deprecated("Use cuVS instead")]] void pairwise_distance( - raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - raft::distance::DistanceType metric, - Type metric_arg = 2.0f) RAFT_EXPLICIT; - -}; // namespace distance -}; // namespace raft - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -/* - * Hierarchy of instantiations: - * - * This file defines the extern template instantiations for the public API of - * raft::distance. To improve compile times, the extern template instantiation - * of the distance kernels is handled in - * distance/detail/pairwise_matrix/dispatch-ext.cuh. - * - * After adding an instance here, make sure to also add the instance to - * dispatch-ext.cuh and the corresponding .cu files. - */ - -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ - extern template void raft::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - FinalLambda fin_op, \ - bool isRowMajor, \ - DataT metric_arg) - -// The following two instances are used in test/distance/gram.cu. Note the use -// of int64_t for the index type. -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, - float, - float, - float, - raft::distance::kernels::detail::rbf_fin_op, - int64_t); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::distance::kernels::detail::rbf_fin_op, - int64_t); - -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::identity_op, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); - -#undef instantiate_raft_distance_distance - -// Same, but without raft::identity_op -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void raft::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - size_t worksize, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -// Same, but without workspace -#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ - extern template void raft::distance::distance( \ - raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - OutT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ - extern template size_t raft::distance::getWorkspaceSize( \ - const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) - -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::CosineExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::CosineExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::DiceExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::DiceExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::HellingerExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::HellingerExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::InnerProduct, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::InnerProduct, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::JensenShannon, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::JensenShannon, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::KLDivergence, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::KLDivergence, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Unexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Unexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Linf, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Linf, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::LpUnexpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::LpUnexpanded, double, double, double, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ - extern template size_t raft::distance::getWorkspaceSize( \ - raft::device_matrix_view const& x, \ - raft::device_matrix_view const& y) - -// We could consider not taking template parameters for this function. The -// number of instantiations seems a bit excessive.. -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::DiceExpanded, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::DiceExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - int, - raft::layout_f_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, - double, - double, - double, - int, - raft::layout_c_contiguous); -instantiate_raft_distance_getWorkspaceSize( - raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); - -#undef instantiate_raft_distance_getWorkspaceSize - -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ - extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - DataT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - rmm::device_uvector& workspace, \ - raft::distance::DistanceType metric, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Same, but without workspace -#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ - extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ - const DataT* x, \ - const DataT* y, \ - DataT* dist, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - raft::distance::DistanceType metric, \ - bool isRowMajor, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, int); -instantiate_raft_distance_pairwise_distance(double, int); - -#undef instantiate_raft_distance_pairwise_distance - -// Version with mdspan -#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ - extern template void raft::distance::distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - DataT metric_arg) - -// Again, we might want to consider reigning in the number of instantiations... -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::DiceExpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::DiceExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance( - raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); -instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_c_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, - float, - float, - float, - raft::layout_f_contiguous, - int); -instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, - double, - double, - double, - raft::layout_f_contiguous, - int); - -#undef instantiate_raft_distance_distance - -#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ - extern template void raft::distance::pairwise_distance( \ - raft::resources const& handle, \ - raft::device_matrix_view const x, \ - raft::device_matrix_view const y, \ - raft::device_matrix_view dist, \ - raft::distance::DistanceType metric, \ - DataT metric_arg) - -instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); -instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); - -#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh deleted file mode 100644 index 13c9d57efd..0000000000 --- a/cpp/include/raft/distance/distance-inl.cuh +++ /dev/null @@ -1,481 +0,0 @@ -/* - * Copyright (c) 2018-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft { -namespace distance { - -/** - * @defgroup pairwise_distance pointer-based pairwise distance prims - * @{ - */ - -/** - * @brief Evaluate pairwise distances with the user epilogue lamba allowed - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam FinalLambda user-defined epilogue lamba - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param fin_op the final gemm epilogue lambda - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - * - * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccT and returns the output in OutT. It's signature is - * as follows:

OutT fin_op(AccT in, int g_idx);
. If one needs - * any other parameters, feel free to pass them via closure. - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - FinalLambda fin_op, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace needed for computations - * @param worksize number of bytes of the workspace - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - size_t worksize, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - detail::distance( - handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param x first set of points - * @param y second set of points - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * - * @note If the specified DistT doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) -{ - return detail::getWorkspaceSize(x, y, m, n, k); -} - -/** - * @brief Return the exact workspace size to compute the distance - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param x first set of points (size m*k) - * @param y second set of points (size n*k) - * @return number of bytes needed in workspace - * - * @note If the specified DistT doesn't need the workspace at all, it - * returns 0. - */ -template -size_t getWorkspaceSize(raft::device_matrix_view const& x, - raft::device_matrix_view const& y) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - - return getWorkspaceSize( - x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); -} - -/** - * @brief Evaluate pairwise distances for the simple use case - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - const DataT* x, - const DataT* y, - OutT* dist, - IdxT m, - IdxT n, - IdxT k, - bool isRowMajor = true, - DataT metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param workspace temporary workspace buffer which can get resized as per the - * needed workspace size - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - rmm::device_uvector& workspace, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - - auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); - workspace.resize(worksize, stream); - detail::distance( - handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); - }; - - switch (metric) { - case DistanceType::Canberra: - dispatch(std::integral_constant{}); - break; - case DistanceType::CorrelationExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::CosineExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HammingUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::HellingerExpanded: - dispatch(std::integral_constant{}); - break; - case raft::distance::DistanceType::InnerProduct: - dispatch(std::integral_constant{}); - break; - case DistanceType::JensenShannon: - dispatch(std::integral_constant{}); - break; - case DistanceType::KLDivergence: - dispatch(std::integral_constant{}); - break; - case DistanceType::L1: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Expanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2SqrtUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::L2Unexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::Linf: - dispatch(std::integral_constant{}); - break; - case DistanceType::LpUnexpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::RusselRaoExpanded: - dispatch(std::integral_constant{}); - break; - case DistanceType::DiceExpanded: - dispatch(std::integral_constant{}); - break; - default: THROW("Unknown or unsupported distance metric '%d'!", (int)metric); - }; -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first set of points - * @param y second set of points - * @param dist output distance matrix - * @param m number of points in x - * @param n number of points in y - * @param k dimensionality - * @param metric distance metric - * @param isRowMajor whether the matrices are row-major or col-major - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - const Type* x, - const Type* y, - Type* dist, - IdxT m, - IdxT n, - IdxT k, - raft::distance::DistanceType metric, - bool isRowMajor = true, - Type metric_arg = 2.0f) -{ - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - pairwise_distance( - handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); -} - -/** @} */ - -/** - * \defgroup distance_mdspan Pairwise distance functions - * @{ - */ - -/** - * @brief Evaluate pairwise distances for the simple use case. - * - * Note: Only contiguous row- or column-major layouts supported currently. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * #include - * - * raft::raft::resources handle; - * int n_samples = 5000; - * int n_features = 50; - * - * auto input = raft::make_device_matrix(handle, n_samples, n_features); - * auto labels = raft::make_device_vector(handle, n_samples); - * auto output = raft::make_device_matrix(handle, n_samples, n_samples); - * - * raft::random::make_blobs(handle, input.view(), labels.view()); - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); - * @endcode - * - * @tparam DistanceType which distance to evaluate - * @tparam DataT input argument type - * @tparam AccT accumulation type - * @tparam OutT output type - * @tparam IdxT Index type - * @param handle raft handle for managing expensive resources - * @param x first set of points (size n*k) - * @param y second set of points (size m*k) - * @param dist output distance matrix (size n*m) - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - DataT metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - - constexpr auto is_rowmajor = std::is_same_v; - - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); -} - -/** - * @brief Convenience wrapper around 'distance' prim to convert runtime metric - * into compile time for the purpose of dispatch - * @tparam Type input/accumulation/output data-type - * @tparam IdxT indexing type - * @param handle raft handle for managing expensive resources - * @param x first matrix of points (size mxk) - * @param y second matrix of points (size nxk) - * @param dist output distance matrix (size mxn) - * @param metric distance metric - * @param metric_arg metric argument (used for Minkowski distance) - */ -template -void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, - raft::distance::DistanceType metric, - Type metric_arg = 2.0f) -{ - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - RAFT_EXPECTS(dist.extent(0) == x.extent(0), - "Number of rows in output must be equal to " - "number of rows in X"); - RAFT_EXPECTS(dist.extent(1) == y.extent(0), - "Number of columns in output must be equal to " - "number of rows in Y"); - - RAFT_EXPECTS(x.is_exhaustive(), "Input x must be contiguous."); - RAFT_EXPECTS(y.is_exhaustive(), "Input y must be contiguous."); - RAFT_EXPECTS(dist.is_exhaustive(), "Output must be contiguous."); - - constexpr auto rowmajor = std::is_same_v; - - auto stream = raft::resource::get_cuda_stream(handle); - rmm::device_uvector workspace(0, stream); - - pairwise_distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - metric, - rowmajor, - metric_arg); -} - -/** @} */ - -}; // namespace distance -}; // namespace raft diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh deleted file mode 100644 index de70cd4691..0000000000 --- a/cpp/include/raft/distance/distance.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "distance-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "distance-ext.cuh" -#endif diff --git a/cpp/include/raft/distance/distance_types.hpp b/cpp/include/raft/distance/distance_types.hpp index d17ef358ee..8222fd03f9 100644 --- a/cpp/include/raft/distance/distance_types.hpp +++ b/cpp/include/raft/distance/distance_types.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * Copyright (c) 2021-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,4 +105,4 @@ struct KernelParams { } // end namespace kernels }; // namespace distance -}; // end namespace raft +}; // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/distance/fused_distance_nn-ext.cuh b/cpp/include/raft/distance/fused_distance_nn-ext.cuh deleted file mode 100644 index 263bbcea81..0000000000 --- a/cpp/include/raft/distance/fused_distance_nn-ext.cuh +++ /dev/null @@ -1,84 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT - -#include // int64_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft { -namespace distance { - -template -void fusedDistanceNNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - raft::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) RAFT_EXPLICIT; - -} // namespace distance -} // namespace raft - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_fusedDistanceNNMinReduce(DataT, OutT, IdxT) \ - extern template void raft::distance::fusedDistanceNNMinReduce( \ - OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - bool isRowMajor, \ - raft::distance::DistanceType metric, \ - float metric_arg, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedDistanceNNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedDistanceNNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedDistanceNNMinReduce diff --git a/cpp/include/raft/distance/fused_distance_nn-inl.cuh b/cpp/include/raft/distance/fused_distance_nn-inl.cuh deleted file mode 100644 index ffe86a1c04..0000000000 --- a/cpp/include/raft/distance/fused_distance_nn-inl.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __FUSED_DISTANCE_NN_H -#define __FUSED_DISTANCE_NN_H - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft { -namespace distance { - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * @tparam KVPReduceOpT A struct providing functions for key-value pair comparison. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - raft::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) -{ - ASSERT(isRowMajor, "fusedDistanceNN only supports row major inputs"); - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedDistanceNNImpl< - DataT, - OutT, - IdxT, - typename linalg::Policy4x4Skinny::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } else { - if (is_skinny) { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } else { - detail::fusedDistanceNNImpl::Policy, - ReduceOpT>(min, - x, - y, - xn, - yn, - m, - n, - k, - (int*)workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); - } - } -} - -/** - * @brief Wrapper around fusedDistanceNN with minimum reduction operators. - * - * fusedDistanceNN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] isRowMajor whether the input/output is row or column major. - * @param[in] metric Distance metric to be used (supports L2, cosine) - * @param[in] metric_arg power argument for distances like Minkowski (not supported for now) - * @param[in] stream cuda stream - */ -template -void fusedDistanceNNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - bool isRowMajor, - raft::distance::DistanceType metric, - float metric_arg, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedDistanceNN(min, - x, - y, - xn, - yn, - m, - n, - k, - workspace, - redOp, - pairRedOp, - sqrt, - initOutBuffer, - isRowMajor, - metric, - metric_arg, - stream); -} - -/** @} */ - -} // namespace distance -} // namespace raft - -#endif diff --git a/cpp/include/raft/distance/fused_distance_nn.cuh b/cpp/include/raft/distance/fused_distance_nn.cuh deleted file mode 100755 index 04c42e49a1..0000000000 --- a/cpp/include/raft/distance/fused_distance_nn.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "fused_distance_nn-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "fused_distance_nn-ext.cuh" -#endif diff --git a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh b/cpp/include/raft/distance/fused_distance_nn_helpers.cuh deleted file mode 100644 index 3a570c681c..0000000000 --- a/cpp/include/raft/distance/fused_distance_nn_helpers.cuh +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace raft::distance { - -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize(raft::resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize( - min, m, maxVal, redOp, resource::get_cuda_stream(handle)); -} - -} // namespace raft::distance diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh deleted file mode 100644 index d0ac83cd51..0000000000 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::KeyValuePair -#include // raft::resources -#include // include initialize and reduce operations -#include // RAFT_EXPLICIT - -#include // int64_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft { -namespace distance { - -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) RAFT_EXPLICIT; - -} // namespace distance -} // namespace raft - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ - extern template void raft::distance::fusedL2NNMinReduce(OutT * min, \ - const DataT* x, \ - const DataT* y, \ - const DataT* xn, \ - const DataT* yn, \ - IdxT m, \ - IdxT n, \ - IdxT k, \ - void* workspace, \ - bool sqrt, \ - bool initOutBuffer, \ - cudaStream_t stream) - -instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); -instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); -instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); -instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); - -// We can't have comma's in the macro expansion, so we use the COMMA macro: -#define COMMA , - -instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); -instantiate_raft_distance_fusedL2NNMinReduce(double, - raft::KeyValuePair, - int64_t); -instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); -instantiate_raft_distance_fusedL2NNMinReduce(float, - raft::KeyValuePair, - int64_t); - -#undef COMMA - -#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh deleted file mode 100644 index bf9a49d813..0000000000 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ /dev/null @@ -1,209 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __FUSED_L2_NN_H -#define __FUSED_L2_NN_H - -#pragma once - -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include - -namespace raft { -namespace distance { - -/** - * \ingroup fused_l2_nn - * @{ - */ -/** - * @brief Fused L2 distance and 1-nearest-neighbor computation in a single call. - * - * The benefits of such a call are 2-fold: 1) eliminate the need for an - * intermediate buffer to store the output of gemm 2) reduce the memory read - * traffic on this intermediate buffer, otherwise needed during the reduction - * phase for 1-NN. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] redOp reduction operator in the epilogue - * @param[in] pairRedOp reduction operation on key value pairs - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NN(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - ReduceOpT redOp, - KVPReduceOpT pairRedOp, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - // When k is smaller than 32, the Policy4x4 results in redundant calculations - // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead - // that uses tiles with a smaller value of k. - bool is_skinny = k < 32; - - size_t bytes = sizeof(DataT) * k; - auto px = reinterpret_cast(x); - auto py = reinterpret_cast(y); - if (16 % sizeof(DataT) == 0 && bytes % 16 == 0 && px % 16 == 0 && py % 16 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else if (8 % sizeof(DataT) == 0 && bytes % 8 == 0 && px % 8 == 0 && py % 8 == 0) { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } else { - if (is_skinny) { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } else { - detail::fusedL2NNImpl::Policy, - ReduceOpT>( - min, x, y, xn, yn, m, n, k, (int*)workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); - } - } -} - -/** - * @brief Wrapper around fusedL2NN with minimum reduction operators. - * - * fusedL2NN cannot be compiled in the distance library due to the lambda - * operators, so this wrapper covers the most common case (minimum). - * This should be preferred to the more generic API when possible, in order to - * reduce compilation times for users of the shared library. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances (e.g. raft::KeyValuePair) or store only the min - * distances. - * @tparam IdxT indexing arithmetic type - * @param[out] min will contain the reduced output (Length = `m`) - * (on device) - * @param[in] x first matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] xn L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] yn L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] m gemm m - * @param[in] n gemm n - * @param[in] k gemm k - * @param[in] workspace temp workspace. Size = sizeof(int)*m. (on device) - * @param[in] sqrt Whether the output `minDist` should contain L2-sqrt - * @param[in] initOutBuffer whether to initialize the output buffer before the - * main kernel launch - * @param[in] stream cuda stream - */ -template -void fusedL2NNMinReduce(OutT* min, - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream) -{ - MinAndDistanceReduceOp redOp; - KVPMinReduce pairRedOp; - - fusedL2NN( - min, x, y, xn, yn, m, n, k, workspace, redOp, pairRedOp, sqrt, initOutBuffer, stream); -} - -/** @} */ - -} // namespace distance -} // namespace raft - -#endif diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh deleted file mode 100644 index b1a3551323..0000000000 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "fused_l2_nn-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "fused_l2_nn-ext.cuh" -#endif diff --git a/cpp/include/raft/distance/kernels.cuh b/cpp/include/raft/distance/kernels.cuh deleted file mode 100644 index 39a6ef8c6e..0000000000 --- a/cpp/include/raft/distance/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2019-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace raft::distance::kernels { - -// TODO: Need to expose formal APIs for this that are more consistent w/ other APIs in RAFT -using raft::distance::kernels::detail::GramMatrixBase; -using raft::distance::kernels::detail::KernelFactory; - -}; // end namespace raft::distance::kernels diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh deleted file mode 100644 index 5833a2e681..0000000000 --- a/cpp/include/raft/distance/masked_nn.cuh +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef __MASKED_L2_NN_H -#define __MASKED_L2_NN_H - -#pragma once - -#include -#include -#include -#include - -#include - -#include - -namespace raft { -namespace distance { -/** - * \defgroup masked_nn Masked 1-nearest neighbors - * @{ - */ - -/** - * @brief Parameter struct for masked_l2_nn function - * - * @tparam ReduceOpT Type of reduction operator in the epilogue. - * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. - * - * Usage example: - * @code{.cpp} - * #include - * - * using IdxT = int; - * using DataT = float; - * using RedOpT = raft::distance::MinAndDistanceReduceOp; - * using PairRedOpT = raft::distance::KVPMinReduce; - * using ParamT = raft::distance::masked_l2_nn_params; - * - * bool init_out = true; - * bool sqrt = false; - * - * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; - * @endcode - * - * Prescribes how to reduce a distance to an intermediate type (`redOp`), and - * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is - * mapped to an (index, value) pair and (index, value) pair with the lowest - * value (distance) is selected. - * - * In addition, prescribes whether to compute the square root of the distance - * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). - */ -template -struct masked_l2_nn_params { - /** Reduction operator in the epilogue */ - ReduceOpT redOp; - /** Reduction operation on key value pairs */ - KVPReduceOpT pairRedOp; - /** Whether the output `minDist` should contain L2-sqrt */ - bool sqrt; - /** Whether to initialize the output buffer before the main kernel launch */ - bool initOutBuffer; -}; - -/** - * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. - * - * This function enables faster computation of nearest neighbors if the - * computation of distances between certain point pairs can be skipped. - * - * We use an adjacency matrix that describes which distances to calculate. The - * points in `y` are divided into groups, and the adjacency matrix indicates - * whether to compute distances between points in `x` and groups in `y`. In other - * words, if `adj[i,k]` is true then distance between point `x_i`, and points in - * `group_k` will be calculated. - * - * **Performance considerations** - * - * The points in `x` are processed in tiles of `M` points (`M` is currently 64, - * but may change in the future). As a result, the largest compute time - * reduction occurs if all `M` points can skip a group. If only part of the `M` - * points can skip a group, then at most a minor compute time reduction and a - * modest energy use reduction can be expected. - * - * The points in `y` are also grouped into tiles of `N` points (`N` is currently - * 64, but may change in the future). As a result, group sizes should be larger - * than `N` to avoid wasting computational resources. If the group sizes are - * evenly divisible by `N`, then the computation is most efficient, although for - * larger group sizes this effect is minor. - * - * - * **Comparison to SDDM** - * - * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense - * matrix multiplication) is a matrix-matrix multiplication where only part of - * the output is computed. Compared to masked_l2_nn, there are a few differences: - * - * - The output of masked_l2_nn is a single vector (of nearest neighbors) and not - * a sparse matrix. - * - * - The sampling in masked_l2_nn is expressed through intermediate "groups" - rather than a CSR format. - * - * @tparam DataT data type - * @tparam OutT output type to either store 1-NN indices and their minimum - * distances or store only the min distances. Accordingly, one - * has to pass an appropriate `ReduceOpT` - * @tparam IdxT indexing arithmetic type - * @tparam ReduceOpT A struct to perform the final needed reduction operation - * and also to initialize the output array elements with the - * appropriate initial value needed for reduction. - * - * @param handle RAFT handle for managing expensive resources - * @param params Parameter struct specifying the reduction operations. - * @param[in] x First matrix. Row major. Dim = `m x k`. - * (on device). - * @param[in] y Second matrix. Row major. Dim = `n x k`. - * (on device). - * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). - * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) - * @param[in] adj A boolean adjacency matrix indicating for each - * row of `x` and each group in `y` whether to compute the - * distance. Dim = `m x num_groups`. - * @param[in] group_idxs An array containing the *end* indices of each group - * in `y`. The value of group_idxs[j] indicates the - * start of group j + 1, i.e., it is the inclusive - * scan of the group lengths. The first group is - * always assumed to start at index 0 and the last - * group typically ends at index `n`. Length = - * `num_groups`. - * @param[out] out will contain the reduced output (Length = `m`) - * (on device) - */ -template -void masked_l2_nn(raft::resources const& handle, - raft::distance::masked_l2_nn_params params, - raft::device_matrix_view x, - raft::device_matrix_view y, - raft::device_vector_view x_norm, - raft::device_vector_view y_norm, - raft::device_matrix_view adj, - raft::device_vector_view group_idxs, - raft::device_vector_view out) -{ - IdxT m = x.extent(0); - IdxT n = y.extent(0); - IdxT k = x.extent(1); - IdxT num_groups = group_idxs.extent(0); - - // Match k dimension of x, y - RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); - // Match x, x_norm and y, y_norm - RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); - RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y` "); - // Match adj to x and group_idxs - RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); - RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); - // NOTE: We do not check if all indices in group_idxs actually points *inside* y. - - // If there is no work to be done, return immediately. - if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } - - detail::masked_l2_nn_impl(handle, - out.data_handle(), - x.data_handle(), - y.data_handle(), - x_norm.data_handle(), - y_norm.data_handle(), - adj.data_handle(), - group_idxs.data_handle(), - num_groups, - m, - n, - k, - params.redOp, - params.pairRedOp, - params.sqrt, - params.initOutBuffer); -} - -/** @} */ - -} // namespace distance -} // namespace raft - -#endif diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh deleted file mode 100644 index cba059154f..0000000000 --- a/cpp/include/raft/distance/specializations.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh deleted file mode 100644 index cba059154f..0000000000 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2021-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh deleted file mode 100644 index e85b05575f..0000000000 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_HIDE_DEPRECATION_WARNINGS -#pragma message( \ - __FILE__ \ - " is deprecated and will be removed." \ - " Including specializations is not necessary any more." \ - " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") -#endif diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index 711169984b..e062ca6854 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2022-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,10 +15,8 @@ */ #pragma once -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY #include "select_k-inl.cuh" -#endif -#ifdef RAFT_COMPILED -#include "select_k-ext.cuh" -#endif +// #ifdef RAFT_COMPILED +// #include "select_k-ext.cuh" +// #endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh deleted file mode 100644 index 4055c253c8..0000000000 --- a/cpp/include/raft/neighbors/brute_force-ext.cuh +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright (c) 2020-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::device_matrix_view -#include // raft::identity_op -#include // raft::resources -#include // raft::distance::DistanceType -#include -#include // RAFT_EXPLICIT - -#include - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::brute_force { - -template -inline void knn_merge_parts( - raft::resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) RAFT_EXPLICIT; - -template -index build(raft::resources const& res, - mdspan, row_major, Accessor> dataset, - raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - T metric_arg = 0.0) RAFT_EXPLICIT; - -template -index build(raft::resources const& res, - index_params const& params, - mdspan, row_major, Accessor> dataset) RAFT_EXPLICIT; - -template -void search(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - -template -void search(raft::resources const& res, - search_params const& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - -template -void knn(raft::resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = raft::identity_op()) RAFT_EXPLICIT; - -template -void fused_l2_knn(raft::resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - raft::distance::DistanceType metric) RAFT_EXPLICIT; - -} // namespace raft::neighbors::brute_force - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -// No extern template for raft::neighbors::brute_force::knn_merge_parts - -#define instantiate_raft_neighbors_brute_force_knn( \ - idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ - extern template void raft::neighbors::brute_force:: \ - knn( \ - raft::resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - raft::distance::DistanceType metric, \ - std::optional metric_arg, \ - std::optional global_id_offset, \ - epilogue_op distance_epilogue); - -instantiate_raft_neighbors_brute_force_knn( - int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - int, float, int, raft::row_major, raft::row_major, raft::identity_op); -instantiate_raft_neighbors_brute_force_knn( - uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); - -#undef instantiate_raft_neighbors_brute_force_knn - -namespace raft::neighbors::brute_force { - -extern template void search( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template void search( - raft::resources const& res, - search_params const& params, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template void search( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template void search( - raft::resources const& res, - search_params const& params, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - raft::device_matrix_view dataset, - raft::distance::DistanceType metric, - float metric_arg); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - index_params const& params, - raft::device_matrix_view dataset); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - raft::host_matrix_view dataset, - raft::distance::DistanceType metric, - float metric_arg); - -extern template raft::neighbors::brute_force::index build( - raft::resources const& res, - index_params const& params, - raft::host_matrix_view dataset); -} // namespace raft::neighbors::brute_force - -#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ - value_t, idx_t, idx_layout, query_layout) \ - extern template void raft::neighbors::brute_force::fused_l2_knn( \ - raft::resources const& handle, \ - raft::device_matrix_view index, \ - raft::device_matrix_view query, \ - raft::device_matrix_view out_inds, \ - raft::device_matrix_view out_dists, \ - raft::distance::DistanceType metric); - -instantiate_raft_neighbors_brute_force_fused_l2_knn(float, - int64_t, - raft::row_major, - raft::row_major) - -#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/include/raft/neighbors/brute_force-inl.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh deleted file mode 100644 index f955cc8518..0000000000 --- a/cpp/include/raft/neighbors/brute_force-inl.cuh +++ /dev/null @@ -1,453 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::brute_force { - -/** - * @defgroup brute_force_knn Brute-force K-Nearest Neighbors - * @{ - */ - -/** - * @brief Performs a k-select across several (contiguous) row-partitioned index/distance - * matrices formatted like the following: - * - * part1row1: k0, k1, k2, k3 - * part1row2: k0, k1, k2, k3 - * part1row3: k0, k1, k2, k3 - * part2row1: k0, k1, k2, k3 - * part2row2: k0, k1, k2, k3 - * part2row3: k0, k1, k2, k3 - * etc... - * - * The example above shows what an aggregated index/distance matrix - * would look like with two partitions when n_samples=3 and k=4. - * - * When working with extremely large data sets that have been broken - * over multiple indexes, such as when computing over multiple GPUs, - * the ids will often start at 0 for each local knn index but the - * global ids need to be used when merging them together. An optional - * translations vector can be supplied to map the starting id of - * each partition to its global id so that the final merged knn - * is based on the global ids. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * using namespace raft::neighbors; - * - * raft::resources handle; - * ... - * compute multiple knn graphs and aggregate row-wise - * (see detailed description above) - * ... - * brute_force::knn_merge_parts(handle, in_keys, in_values, out_keys, out_values, n_samples); - * @endcode - * - * @tparam idx_t - * @tparam value_t - * - * @param[in] handle - * @param[in] in_keys matrix of input keys (size n_samples * n_parts * k) - * @param[in] in_values matrix of input values (size n_samples * n_parts * k) - * @param[out] out_keys matrix of output keys (size n_samples * k) - * @param[out] out_values matrix of output values (size n_samples * k) - * @param[in] n_samples number of rows in each partition - * @param[in] translations optional vector of starting global id mappings for each local partition - */ -template -inline void knn_merge_parts( - raft::resources const& handle, - raft::device_matrix_view in_keys, - raft::device_matrix_view in_values, - raft::device_matrix_view out_keys, - raft::device_matrix_view out_values, - size_t n_samples, - std::optional> translations = std::nullopt) -{ - RAFT_EXPECTS(in_keys.extent(1) == in_values.extent(1) && in_keys.extent(0) == in_values.extent(0), - "in_keys and in_values must have the same shape."); - RAFT_EXPECTS( - out_keys.extent(0) == out_values.extent(0) && out_keys.extent(0) == n_samples, - "Number of rows in output keys and val matrices must equal number of rows in search matrix."); - RAFT_EXPECTS( - out_keys.extent(1) == out_values.extent(1) && out_keys.extent(1) == in_keys.extent(1), - "Number of columns in output indices and distances matrices must be equal to k"); - - idx_t* translations_ptr = nullptr; - if (translations.has_value()) { translations_ptr = translations.value().data_handle(); } - - auto n_parts = in_keys.extent(0) / n_samples; - detail::knn_merge_parts(in_keys.data_handle(), - in_values.data_handle(), - out_keys.data_handle(), - out_values.data_handle(), - n_samples, - n_parts, - in_keys.extent(1), - resource::get_cuda_stream(handle), - translations_ptr); -} - -/** - * @brief Flat C++ API function to perform a brute force knn on - * a series of input arrays and combine the results into a single - * output array for indexes and distances. Inputs can be either - * row- or column-major but the output matrices will always be in - * row-major format. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, metric); - * @endcode - * - * @param[in] handle: the cuml handle to use - * @param[in] index: vector of device matrices (each size m_i*d) to be used as the knn index - * @param[in] search: matrix (size n*d) to be used for searching the index - * @param[out] indices: matrix (size n*k) to store output knn indices - * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * @param[in] global_id_offset: optional starting global id mapping for the local partition - * (assumes the index contains contiguous ids in the global id space) - * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This - function takes a triple of the (value, rowid, colid) for each - element in the pairwise distances and returns a transformed value - back. - */ -template -void knn(raft::resources const& handle, - std::vector> index, - raft::device_matrix_view search, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt, - epilogue_op distance_epilogue = raft::identity_op()) -{ - RAFT_EXPECTS(index[0].extent(1) == search.extent(1), - "Number of dimensions for both index and search matrices must be equal"); - - RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), - "Number of rows in output indices and distances matrices must equal number of rows " - "in search matrix."); - RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), - "Number of columns in output indices and distances matrices must the same"); - - bool rowMajorIndex = std::is_same_v; - bool rowMajorQuery = std::is_same_v; - - std::vector inputs; - std::vector sizes; - for (std::size_t i = 0; i < index.size(); ++i) { - inputs.push_back(const_cast(index[i].data_handle())); - sizes.push_back(index[i].extent(0)); - } - - std::vector trans; - if (global_id_offset.has_value()) { trans.push_back(global_id_offset.value()); } - - std::vector* trans_arg = global_id_offset.has_value() ? &trans : nullptr; - - raft::neighbors::detail::brute_force_knn_impl(handle, - inputs, - sizes, - index[0].extent(1), - // TODO: This is unfortunate. Need to fix. - const_cast(search.data_handle()), - search.extent(0), - indices.data_handle(), - distances.data_handle(), - indices.extent(1), - rowMajorIndex, - rowMajorQuery, - trans_arg, - metric, - metric_arg.value_or(2.0f), - distance_epilogue); -} - -/** - * @brief Compute the k-nearest neighbors using L2 expanded/unexpanded distance. - * - * This is a specialized function for fusing the k-selection with the distance - * computation when k < 64. The value of k will be inferred from the number - * of columns in the output matrices. - * - * Usage example: - * @code{.cpp} - * #include - * #include - * #include - * using namespace raft::neighbors; - * - * raft::resources handle; - * ... - * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::fused_l2_knn(handle, index, search, indices, distances, metric); - * @endcode - - * @tparam value_t type of values - * @tparam idx_t type of indices - * @tparam idx_layout layout type of index matrix - * @tparam query_layout layout type of query matrix - * @param[in] handle raft handle for sharing expensive resources - * @param[in] index input index array on device (size m * d) - * @param[in] query input query array on device (size n * d) - * @param[out] out_inds output indices array on device (size n * k) - * @param[out] out_dists output dists array on device (size n * k) - * @param[in] metric type of distance computation to perform (must be a variant of L2) - */ -template -void fused_l2_knn(raft::resources const& handle, - raft::device_matrix_view index, - raft::device_matrix_view query, - raft::device_matrix_view out_inds, - raft::device_matrix_view out_dists, - raft::distance::DistanceType metric) -{ - int k = static_cast(out_inds.extent(1)); - - RAFT_EXPECTS(k <= 64, "For fused k-selection, k must be < 64"); - RAFT_EXPECTS(out_inds.extent(1) == out_dists.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(index.extent(1) == query.extent(1), - "Number of columns in input matrices must be the same."); - - RAFT_EXPECTS(metric == distance::DistanceType::L2Expanded || - metric == distance::DistanceType::L2Unexpanded || - metric == distance::DistanceType::L2SqrtUnexpanded || - metric == distance::DistanceType::L2SqrtExpanded, - "Distance metric must be L2"); - - size_t n_index_rows = index.extent(0); - size_t n_query_rows = query.extent(0); - size_t D = index.extent(1); - - RAFT_EXPECTS(raft::is_row_or_column_major(index), "Index must be row or column major layout"); - RAFT_EXPECTS(raft::is_row_or_column_major(query), "Query must be row or column major layout"); - - const bool rowMajorIndex = raft::is_row_major(index); - const bool rowMajorQuery = raft::is_row_major(query); - - raft::spatial::knn::detail::fusedL2Knn(D, - out_inds.data_handle(), - out_dists.data_handle(), - index.data_handle(), - query.data_handle(), - n_index_rows, - n_query_rows, - k, - rowMajorIndex, - rowMajorQuery, - resource::get_cuda_stream(handle), - metric); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * This function builds a brute force index for the given dataset. This lets you re-use - * precalculated norms for the dataset, leading to a speedup over calling - * raft::neighbors::brute_force::knn repeatedly. - * - * Example usage: - * @code{.cpp} - * #include - * #include - * #include - * - * // create a random dataset - * int n_rows = 10000; - * int n_cols = 10000; - * - * raft::device_resources res; - * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); - * auto labels = raft::make_device_vector(res, n_rows); - * - * raft::random::make_blobs(res, dataset.view(), labels.view()); - * - * // create a brute_force knn index from the dataset - * auto index = raft::neighbors::brute_force::build(res, - * raft::make_const_mdspan(dataset.view())); - * - * // Use the constructed index to search for the nearest 128 neighbors - * int k = 128; - * auto search = raft::make_const_mdspan(dataset.view()); - * - * auto indices= raft::make_device_matrix(res, search.extent(0), k); - * auto distances = raft::make_device_matrix(res, search.extent(0), k); - * - * raft::neighbors::brute_force::search(res, - * index, - * search, - * indices.view(), - * distances.view()); - * @endcode - * - * @tparam T data element type - * - * @param[in] res - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[in] metric: distance metric to use. Euclidean (L2) is used by default - * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This - * is ignored if the metric_type is not Minkowski. - * - * @return the constructed brute force index - */ -template -index build(raft::resources const& res, - mdspan, row_major, Accessor> dataset, - raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - T metric_arg = 0.0) -{ - // certain distance metrics can benefit by pre-calculating the norms for the index dataset - // which lets us avoid calculating these at query time - std::optional> norms; - // TODO(wphicks): Replace once mdbuffer is available - auto dataset_storage = std::optional>{}; - auto dataset_view = [&res, &dataset_storage, dataset]() { - if constexpr (std::is_same_v>) { - return dataset; - } else { - dataset_storage = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); - raft::copy(res, dataset_storage->view(), dataset); - return raft::make_const_mdspan(dataset_storage->view()); - } - }(); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded) { - norms = make_device_vector(res, dataset.extent(0)); - // cosine needs the l2norm, where as l2 distances needs the squared norm - if (metric == raft::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(res, - dataset_view, - norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op{}); - } else { - raft::linalg::norm(res, - dataset_view, - norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - } - } - - return index(res, dataset, std::move(norms), metric, metric_arg); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * @tparam T data element type - * - * @param[in] res - * @param[in] params configure the index building - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * - * @return the constructed brute force index - */ -template -index build(raft::resources const& res, - index_params const& params, - mdspan, row_major, Accessor> dataset) -{ - return build(res, dataset, params.metric, float(params.metric_arg)); -} - -/** - * @brief Brute Force search using the constructed index. - * - * See raft::neighbors::brute_force::build for a usage example - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] idx brute force index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); -} - -/** - * @brief Brute Force search using the constructed index. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx brute force index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - search_params const& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - raft::neighbors::detail::brute_force_search(res, idx, queries, neighbors, distances); -} - -/** @} */ // end group brute_force_knn -} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh deleted file mode 100644 index 331ea55540..0000000000 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "brute_force-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "brute_force-ext.cuh" -#endif - -#include - -namespace raft::neighbors::brute_force { -/** - * @brief Make a brute force query over batches of k - * - * This lets you query for batches of k. For example, you can get - * the first 100 neighbors, then the next 100 neighbors etc. - * - * Example usage: - * @code{.cpp} - * #include - * #include - * #include - - * // create a random dataset - * int n_rows = 10000; - * int n_cols = 10000; - - * raft::device_resources res; - * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); - * auto labels = raft::make_device_vector(res, n_rows); - - * raft::random::make_blobs(res, dataset.view(), labels.view()); - * - * // create a brute_force knn index from the dataset - * auto index = raft::neighbors::brute_force::build(res, - * raft::make_const_mdspan(dataset.view())); - * - * // search the index in batches of 128 nearest neighbors - * auto search = raft::make_const_mdspan(dataset.view()); - * auto query = make_batch_k_query(res, index, search, 128); - * for (auto & batch: *query) { - * // batch.indices() and batch.distances() contain the information on the current batch - * } - * - * // we can also support variable sized batches - loaded up a different number - * // of neighbors at each iteration through the ::advance method - * int64_t batch_size = 128; - * query = make_batch_k_query(res, index, search, batch_size); - * for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) { - * // batch.indices() and batch.distances() contain the information on the current batch - * - * batch_size += 16; // load up an extra 16 items in the next batch - * } - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * @param[in] res - * @param[in] index The index to query - * @param[in] query A device matrix view to query for [n_queries, index->dim()] - * @param[in] batch_size The size of each batch - */ - -template -std::shared_ptr> make_batch_k_query( - const raft::resources& res, - const raft::neighbors::brute_force::index& index, - raft::device_matrix_view query, - int64_t batch_size) -{ - return std::shared_ptr>( - new detail::gpu_batch_k_query(res, index, query, batch_size)); -} -} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_serialize.cuh b/cpp/include/raft/neighbors/brute_force_serialize.cuh deleted file mode 100644 index bed3bed9e1..0000000000 --- a/cpp/include/raft/neighbors/brute_force_serialize.cuh +++ /dev/null @@ -1,224 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include - -namespace raft::neighbors::brute_force { - -auto static constexpr serialization_version = 0; - -/** - * \defgroup brute_force_serialize Brute Force Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = brute_force::build(...);` - * raft::neighbors::brute_force::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index brute force index - * @param[in] include_dataset whether to include the dataset in the serialized - * output - * - */ -template -void serialize(raft::resources const& handle, - std::ostream& os, - const index& index, - bool include_dataset = true) -{ - RAFT_LOG_DEBUG( - "Saving brute force index, size %zu, dim %u", static_cast(index.size()), index.dim()); - - auto dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); - dtype_string.resize(4); - os << dtype_string; - - serialize_scalar(handle, os, serialization_version); - serialize_scalar(handle, os, index.size()); - serialize_scalar(handle, os, index.dim()); - serialize_scalar(handle, os, index.metric()); - serialize_scalar(handle, os, index.metric_arg()); - serialize_scalar(handle, os, include_dataset); - if (include_dataset) { serialize_mdspan(handle, os, index.dataset()); } - auto has_norms = index.has_norms(); - serialize_scalar(handle, os, has_norms); - if (has_norms) { serialize_mdspan(handle, os, index.norms()); } - resource::sync_stream(handle); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = brute_force::build(...);` - * raft::neighbors::brute_force::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index brute force index - * @param[in] include_dataset whether to include the dataset in the serialized - * output - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index, - bool include_dataset = true) -{ - auto os = std::ofstream{filename, std::ios::out | std::ios::binary}; - RAFT_EXPECTS(os, "Cannot open file %s", filename.c_str()); - serialize(handle, os, index, include_dataset); -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * auto index = raft::neighbors::brute_force::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return raft::neighbors::brute_force::index - */ -template -auto deserialize(raft::resources const& handle, std::istream& is) -{ - auto dtype_string = std::array{}; - is.read(dtype_string.data(), 4); - - auto ver = deserialize_scalar(handle, is); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto rows = deserialize_scalar(handle, is); - auto dim = deserialize_scalar(handle, is); - auto metric = deserialize_scalar(handle, is); - auto metric_arg = deserialize_scalar(handle, is); - - auto dataset_storage = raft::make_host_matrix(std::int64_t{}, std::int64_t{}); - auto include_dataset = deserialize_scalar(handle, is); - if (include_dataset) { - dataset_storage = raft::make_host_matrix(rows, dim); - deserialize_mdspan(handle, is, dataset_storage.view()); - } - - auto has_norms = deserialize_scalar(handle, is); - auto norms_storage = has_norms ? std::optional{raft::make_host_vector(rows)} - : std::optional>{}; - // TODO(wphicks): Use mdbuffer here when available - auto norms_storage_dev = - has_norms ? std::optional{raft::make_device_vector(handle, rows)} - : std::optional>{}; - if (has_norms) { - deserialize_mdspan(handle, is, norms_storage->view()); - raft::copy(handle, norms_storage_dev->view(), norms_storage->view()); - } - - auto result = index(handle, - raft::make_const_mdspan(dataset_storage.view()), - std::move(norms_storage_dev), - metric, - metric_arg); - resource::sync_stream(handle); - - return result; -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * auto index = raft::neighbors::brute_force::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return raft::neighbors::brute_force::index - */ -template -auto deserialize(raft::resources const& handle, const std::string& filename) -{ - auto is = std::ifstream{filename, std::ios::in | std::ios::binary}; - RAFT_EXPECTS(is, "Cannot open file %s", filename.c_str()); - - return deserialize(handle, is); -} - -/**@}*/ - -} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/brute_force_types.hpp b/cpp/include/raft/neighbors/brute_force_types.hpp deleted file mode 100644 index 4511f8d8ba..0000000000 --- a/cpp/include/raft/neighbors/brute_force_types.hpp +++ /dev/null @@ -1,302 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::brute_force { -/** - * @addtogroup brute_force_knn - * @{ - */ - -using ann::index_params; -using ann::search_params; - -/** - * @brief Brute Force index. - * - * The index stores the dataset and norms for the dataset in device memory. - * - * @tparam T data element type - */ -template -struct index : ann::index { - public: - /** Distance metric used for retrieval */ - [[nodiscard]] constexpr inline raft::distance::DistanceType metric() const noexcept - { - return metric_; - } - - /** Total length of the index (number of vectors). */ - [[nodiscard]] constexpr inline auto size() const noexcept { return dataset_view_.extent(0); } - - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept { return dataset_view_.extent(1); } - - /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept - -> device_matrix_view - { - return dataset_view_; - } - - /** Dataset norms */ - [[nodiscard]] inline auto norms() const -> device_vector_view - { - return norms_view_.value(); - } - - /** Whether or not this index has dataset norms */ - [[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); } - - [[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; } - - // Don't allow copying the index for performance reasons (try avoiding copying data) - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - ~index() = default; - - /** Construct a brute force index from dataset - * - * Constructs a brute force index from a dataset. This lets us precompute norms for - * the dataset, providing a speed benefit over doing this at query time. - - * If the dataset is already in GPU memory, then this class stores a non-owning reference to - * the dataset. If the dataset is in host memory, it will be copied to the device and the - * index will own the device memory. - */ - - template - [[deprecated("Use cuVS instead")]] index( - raft::resources const& res, - mdspan, row_major, data_accessor> dataset, - std::optional>&& norms, - raft::distance::DistanceType metric, - T metric_arg = 0.0) - : ann::index(), - metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), - norms_(std::move(norms)), - metric_arg_(metric_arg) - { - if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); } - update_dataset(res, dataset); - resource::sync_stream(res); - } - - /** Construct a brute force index from dataset - * - * This class stores a non-owning reference to the dataset and norms here. - * Having precomputed norms gives us a performance advantage at query time. - */ - [[deprecated("Use cuVS instead")]] index( - raft::resources const& res, - raft::device_matrix_view dataset_view, - std::optional> norms_view, - raft::distance::DistanceType metric, - T metric_arg = 0.0) - : ann::index(), - metric_(metric), - dataset_(make_device_matrix(res, 0, 0)), - dataset_view_(dataset_view), - norms_view_(norms_view), - metric_arg_(metric_arg) - { - } - - template - [[deprecated("Use cuVS instead")]] index( - raft::resources const& res, - index_params const& params, - mdspan, row_major, data_accessor> dataset, - std::optional>&& norms = std::nullopt) - : ann::index(), - metric_(params.metric), - dataset_(make_device_matrix(res, 0, 0)), - norms_(std::move(norms)), - metric_arg_(params.metric_arg) - { - if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); } - update_dataset(res, dataset); - resource::sync_stream(res); - } - - /** - * Replace the dataset with a new dataset. - */ - void update_dataset(raft::resources const& res, - raft::device_matrix_view dataset) - { - dataset_view_ = dataset; - } - - /** - * Replace the dataset with a new dataset. - * - * We create a copy of the dataset on the device. The index manages the lifetime of this copy. - */ - void update_dataset(raft::resources const& res, - raft::host_matrix_view dataset) - { - dataset_ = make_device_matrix(res, dataset.extent(0), dataset.extent(1)); - raft::copy(res, dataset_.view(), dataset); - dataset_view_ = make_const_mdspan(dataset_.view()); - } - - private: - raft::distance::DistanceType metric_; - raft::device_matrix dataset_; - std::optional> norms_; - std::optional> norms_view_; - raft::device_matrix_view dataset_view_; - T metric_arg_; -}; - -/** - * @brief Interface for performing queries over values of k - * - * This interface lets you iterate over batches of k from a brute_force::index. - * This lets you do things like retrieve the first 100 neighbors for a query, - * apply post processing to remove any unwanted items and then if needed get the - * next 100 closest neighbors for the query. - * - * This query interface exposes C++ iterators through the ::begin and ::end, and - * is compatible with range based for loops. - * - * Note that this class is an abstract class without any cuda dependencies, meaning - * that it doesn't require a cuda compiler to use - but also means it can't be directly - * instantiated. See the raft::neighbors::brute_force::make_batch_k_query - * function for usage examples. - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - */ -template -class batch_k_query { - public: - batch_k_query(const raft::resources& res, - int64_t index_size, - int64_t query_size, - int64_t batch_size) - : res(res), index_size(index_size), query_size(query_size), batch_size(batch_size) - { - } - virtual ~batch_k_query() {} - - using value_type = raft::neighbors::batch; - - class iterator { - public: - using value_type = raft::neighbors::batch; - using reference = const value_type&; - using pointer = const value_type*; - - iterator(const batch_k_query* query, int64_t offset = 0) - : current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset) - { - query->load_batch(offset, query->batch_size, &batches); - query->slice_batch(batches, offset, query->batch_size, ¤t); - } - - reference operator*() const { return current; } - - pointer operator->() const { return ¤t; } - - iterator& operator++() - { - advance(query->batch_size); - return *this; - } - - iterator operator++(int) - { - iterator previous(*this); - operator++(); - return previous; - } - - /** - * @brief Advance the iterator, using a custom size for the next batch - * - * Using operator++ means that we will load up the same batch_size for each - * batch. This method allows us to get around this restriction, and load up - * arbitrary batch sizes on each iteration. - * See raft::neighbors::brute_force::make_batch_k_query for a usage example. - * - * @param[in] next_batch_size: size of the next batch to load up - */ - void advance(int64_t next_batch_size) - { - offset = std::min(offset + current.batch_size(), query->index_size); - if (offset + next_batch_size > batches.batch_size()) { - query->load_batch(offset, next_batch_size, &batches); - } - query->slice_batch(batches, offset, next_batch_size, ¤t); - } - - friend bool operator==(const iterator& lhs, const iterator& rhs) - { - return (lhs.query == rhs.query) && (lhs.offset == rhs.offset); - }; - friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); }; - - protected: - // the current batch of data - value_type current; - - // the currently loaded group of data (containing multiple batches of data that we can iterate - // through) - value_type batches; - - const batch_k_query* query; - int64_t offset, current_batch_size; - }; - - iterator begin() const { return iterator(this); } - iterator end() const { return iterator(this, index_size); } - - protected: - // these two methods need cuda code, and are implemented in the subclass - virtual void load_batch(int64_t offset, - int64_t next_batch_size, - batch* output) const = 0; - virtual void slice_batch(const value_type& input, - int64_t offset, - int64_t batch_size, - value_type* output) const = 0; - - const raft::resources& res; - int64_t index_size, query_size, batch_size; -}; -/** @} */ - -} // namespace raft::neighbors::brute_force diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh deleted file mode 100644 index 5263ef73e7..0000000000 --- a/cpp/include/raft/neighbors/cagra.cuh +++ /dev/null @@ -1,396 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/cagra/cagra_build.cuh" -#include "detail/cagra/cagra_search.cuh" -#include "detail/cagra/graph_core.cuh" - -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::cagra { - -/** - * @defgroup cagra CUDA ANN Graph-based nearest neighbor search - * @{ - */ - -/** - * @brief Build a kNN graph using IVF-PQ. - * - * The kNN graph is the first building block for CAGRA index. - * - * The output is a dense matrix that stores the neighbor indices for each point in the dataset. - * Each point has the same number of neighbors. - * - * See [cagra::build](#cagra::build) for an alternative method. - * - * The following distance metrics are supported: - * - L2Expanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters based on shape of the dataset - * ivf_pq::index_params build_params = ivf_pq::index_params::from_dataset(dataset); - * ivf_pq::search_params search_params; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // create knn graph - * cagra::build_knn_graph(res, dataset, knn_graph.view(), 2, build_params, search_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); - * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT data element type - * @tparam IdxT type of the dataset vector indices - * - * @param[in] res raft resources - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] refine_rate (optional) refinement rate for ivf-pq search - * @param[in] build_params (optional) ivf_pq index building parameters for knn graph - * @param[in] search_params (optional) ivf_pq search parameters - */ -template -void build_knn_graph(raft::resources const& res, - mdspan, row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - std::optional refine_rate = std::nullopt, - std::optional build_params = std::nullopt, - std::optional search_params = std::nullopt) -{ - using internal_IdxT = typename std::make_unsigned::type; - - auto knn_graph_internal = make_host_matrix_view( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - auto dataset_internal = mdspan, row_major, accessor>( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - cagra::detail::build_knn_graph( - res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); -} - -/** - * @brief Build a kNN graph using NN-descent. - * - * The kNN graph is the first building block for CAGRA index. - * - * The output is a dense matrix that stores the neighbor indices for each point in the dataset. - * Each point has the same number of neighbors. - * - * See [cagra::build](#cagra::build) for an alternative method. - * - * The following distance metrics are supported: - * - L2Expanded - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * using namespace raft::neighbors::experimental; - * // use default index parameters - * nn_descent::index_params build_params; - * build_params.graph_degree = 128; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // create knn graph - * cagra::build_knn_graph(res, dataset, knn_graph.view(), build_params); - * auto optimized_gaph = raft::make_host_matrix(dataset.extent(0), 64); - * cagra::optimize(res, dataset, nn_descent_index.graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT data element type - * @tparam IdxT type of the dataset vector indices - * @tparam accessor host or device accessor_type for the dataset - * @param[in] res raft::resources is an object mangaging resources - * @param[in] dataset input raft::host/device_matrix_view that can be located in - * in host or device memory - * @param[out] knn_graph a host matrix view to store the output knn graph [n_rows, graph_degree] - * @param[in] build_params an instance of experimental::nn_descent::index_params that are parameters - * to run the nn-descent algorithm - */ -template , memory_type::device>> -void build_knn_graph(raft::resources const& res, - mdspan, row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - experimental::nn_descent::index_params build_params) -{ - detail::build_knn_graph(res, dataset, knn_graph, build_params); -} - -/** - * @brief Sort a KNN graph index. - * Preprocessing step for `cagra::optimize`: If a KNN graph is not built using - * `cagra::build_knn_graph`, then it is necessary to call this function before calling - * `cagra::optimize`. If the graph is built by `cagra::build_knn_graph`, it is already sorted and - * you do not need to call this function. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * cagra::index_params build_params; - * auto knn_graph = raft::make_host_matrix(dataset.extent(0), 128); - * // build KNN graph not using `cagra::build_knn_graph` - * // build(knn_graph, dataset, ...); - * // sort graph index - * sort_knn_graph(res, dataset.view(), knn_graph.view()); - * // optimize graph - * cagra::optimize(res, dataset, knn_graph.view(), optimized_graph.view()); - * // Construct an index from dataset and optimized knn_graph - * auto index = cagra::index(res, build_params.metric(), dataset, - * optimized_graph.view()); - * @endcode - * - * @tparam DataT type of the data in the source dataset - * @tparam IdxT type of the dataset vector indices - * - * @param[in] res raft resources - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * @param[in,out] knn_graph a matrix view (host or device) of the input knn graph [n_rows, - * knn_graph_degree] - */ -template , memory_type::device>, - typename g_accessor = - host_device_accessor, memory_type::host>> -void sort_knn_graph(raft::resources const& res, - mdspan, row_major, d_accessor> dataset, - mdspan, row_major, g_accessor> knn_graph) -{ - using internal_IdxT = typename std::make_unsigned::type; - - using g_accessor_internal = - host_device_accessor, g_accessor::mem_type>; - auto knn_graph_internal = - mdspan, row_major, g_accessor_internal>( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - - auto dataset_internal = mdspan, row_major, d_accessor>( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - cagra::detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal); -} - -/** - * @brief Prune a KNN graph. - * - * Decrease the number of neighbors for each node. - * - * See [cagra::build_knn_graph](#cagra::build_knn_graph) for usage example - * - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res raft resources - * @param[in] knn_graph a matrix view (host or device) of the input knn graph [n_rows, - * knn_graph_degree] - * @param[out] new_graph a host matrix view of the optimized knn graph [n_rows, graph_degree] - */ -template , memory_type::host>> -void optimize(raft::resources const& res, - mdspan, row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - detail::optimize(res, knn_graph, new_graph); -} - -/** - * @brief Build the index from the dataset for efficient search. - * - * The build consist of two steps: build an intermediate knn-graph, and optimize it to - * create the final graph. The index_params struct controls the node degree of these - * graphs. - * - * It is required that dataset and the optimized graph fit the GPU memory. - * - * To customize the parameters for knn-graph building and pruning, and to reuse the - * intermediate results, you could build the index in two steps using - * [cagra::build_knn_graph](#cagra::build_knn_graph) and [cagra::optimize](#cagra::optimize). - * - * The following distance metrics are supported: - * - L2 - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * // use default search parameters - * cagra::search_params search_params; - * // search K nearest neighbours - * auto neighbors = raft::make_device_matrix(res, n_queries, k); - * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] res - * @param[in] params parameters for building the index - * @param[in] dataset a matrix view (host or device) to a row-major matrix [n_rows, dim] - * - * @return the constructed cagra index - */ -template , memory_type::host>> -index build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset) -{ - return detail::build(res, params, dataset); -} - -/** - * @brief Search ANN using the constructed index with the given sample filter. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * // use default search parameters - * cagra::search_params search_params; - * // create a bitset to filter the search - * auto removed_indices = raft::make_device_vector(res, n_removed_indices); - * raft::core::bitset removed_indices_bitset( - * res, removed_indices.view(), dataset.extent(0)); - * // search K nearest neighbours according to a bitset - * auto neighbors = raft::make_device_matrix(res, n_queries, k); - * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances, - * filtering::bitset_filter(removed_indices_bitset.view())); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam CagraSampleFilterT Device filter function, with the signature - * `(uint32_t query ix, uint32_t sample_ix) -> bool` - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - using internal_IdxT = typename std::make_unsigned::type; - auto queries_internal = raft::make_device_matrix_view( - queries.data_handle(), queries.extent(0), queries.extent(1)); - auto neighbors_internal = raft::make_device_matrix_view( - reinterpret_cast(neighbors.data_handle()), - neighbors.extent(0), - neighbors.extent(1)); - auto distances_internal = raft::make_device_matrix_view( - distances.data_handle(), distances.extent(0), distances.extent(1)); - - return cagra::detail::search_main( - res, params, idx, queries_internal, neighbors_internal, distances_internal, sample_filter); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [cagra::build](#cagra::build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - using none_filter_type = raft::neighbors::filtering::none_cagra_sample_filter; - return cagra::search_with_filtering( - res, params, idx, queries, neighbors, distances, none_filter_type{}); -} - -/** @} */ // end group cagra - -} // namespace raft::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace raft::neighbors::experimental::cagra { -using raft::neighbors::cagra::build; -using raft::neighbors::cagra::build_knn_graph; -using raft::neighbors::cagra::optimize; -using raft::neighbors::cagra::search; -using raft::neighbors::cagra::sort_knn_graph; -} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh deleted file mode 100644 index eae2269662..0000000000 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ /dev/null @@ -1,237 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/cagra/cagra_serialize.cuh" - -namespace raft::neighbors::cagra { - -/** - * \defgroup cagra_serialize CAGRA Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = raft::neighbors::cagra::build(...);` - * raft::neighbors::cagra::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - std::ostream& os, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, os, index, include_dataset); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = raft::neighbors::cagra::build(...);` - * raft::neighbors::cagra::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, filename, index, include_dataset); -} - -/** - * Write the CAGRA built index as a base layer HNSW index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = raft::neighbors::cagra::build(...);` - * raft::neighbors::cagra::serialize_to_hnswlib(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - std::ostream& os, - const raft::neighbors::cagra::index& index) -{ - detail::serialize_to_hnswlib(handle, os, index); -} - -/** - * Save a CAGRA build index in hnswlib base-layer-only serialized format - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = raft::neighbors::cagra::build(...);` - * raft::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * - */ -template -void serialize_to_hnswlib(raft::resources const& handle, - const std::string& filename, - const raft::neighbors::cagra::index& index) -{ - detail::serialize_to_hnswlib(handle, filename, index); -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::neighbors::cagra::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return raft::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, std::istream& is) -{ - return detail::deserialize(handle, is); -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::neighbors::cagra::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return raft::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, const std::string& filename) -{ - return detail::deserialize(handle, filename); -} - -/**@}*/ - -} // namespace raft::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace raft::neighbors::experimental::cagra { -using raft::neighbors::cagra::deserialize; -using raft::neighbors::cagra::serialize; - -} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp deleted file mode 100644 index bc7c380db1..0000000000 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ /dev/null @@ -1,383 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" -#include "dataset.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include -#include - -namespace raft::neighbors::cagra { -/** - * @addtogroup cagra - * @{ - */ - -/** - * @brief ANN algorithm used by CAGRA to build knn graph - * - */ -enum class graph_build_algo { - /* Use IVF-PQ to build all-neighbors knn graph */ - IVF_PQ, - /* Experimental, use NN-Descent to build all-neighbors knn graph */ - NN_DESCENT -}; - -struct index_params : ann::index_params { - /** Degree of input graph for pruning. */ - size_t intermediate_graph_degree = 128; - /** Degree of output graph. */ - size_t graph_degree = 64; - /** ANN algorithm to build knn graph. */ - graph_build_algo build_algo = graph_build_algo::IVF_PQ; - /** Number of Iterations to run if building with NN_DESCENT */ - size_t nn_descent_niter = 20; - /** - * Specify compression params if compression is desired. - * - * NOTE: this is experimental new API, consider it unsafe. - */ - std::optional compression = std::nullopt; -}; - -enum class search_algo { - /** For large batch sizes. */ - SINGLE_CTA, - /** For small batch sizes. */ - MULTI_CTA, - MULTI_KERNEL, - AUTO -}; - -enum class hash_mode { HASH, SMALL, AUTO }; - -struct search_params : ann::search_params { - /** Maximum number of queries to search at the same time (batch size). Auto select when 0.*/ - size_t max_queries = 0; - - /** Number of intermediate search results retained during the search. - * - * This is the main knob to adjust trade off between accuracy and search speed. - * Higher values improve the search accuracy. - */ - size_t itopk_size = 64; - - /** Upper limit of search iterations. Auto select when 0.*/ - size_t max_iterations = 0; - - // In the following we list additional search parameters for fine tuning. - // Reasonable default values are automatically chosen. - - /** Which search implementation to use. */ - search_algo algo = search_algo::AUTO; - - /** Number of threads used to calculate a single distance. 4, 8, 16, or 32. */ - size_t team_size = 0; - - /** Number of graph nodes to select as the starting point for the search in each iteration. aka - * search width?*/ - size_t search_width = 1; - /** Lower limit of search iterations. */ - size_t min_iterations = 0; - - /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ - size_t thread_block_size = 0; - /** Hashmap type. Auto selection when AUTO. */ - hash_mode hashmap_mode = hash_mode::AUTO; - /** Lower limit of hashmap bit length. More than 8. */ - size_t hashmap_min_bitlen = 0; - /** Upper limit of hashmap fill rate. More than 0.1, less than 0.9.*/ - float hashmap_max_fill_rate = 0.5; - - /** Number of iterations of initial random seed node selection. 1 or more. */ - uint32_t num_random_samplings = 1; - /** Bit mask used for initial random seed node selection. */ - uint64_t rand_xor_mask = 0x128394; -}; - -static_assert(std::is_aggregate_v); -static_assert(std::is_aggregate_v); - -/** - * @brief CAGRA index. - * - * The index stores the dataset and a kNN graph in device memory. - * - * @tparam T data element type - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - * - */ -template -struct index : ann::index { - static_assert(!raft::is_narrowing_v, - "IdxT must be able to represent all values of uint32_t"); - - public: - /** Distance metric used for clustering. */ - [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType - { - return metric_; - } - - /** Total length of the index (number of vectors). */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT - { - auto data_rows = dataset_->n_rows(); - return data_rows > 0 ? data_rows : graph_view_.extent(0); - } - - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { return dataset_->dim(); } - /** Graph degree */ - [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t - { - return graph_view_.extent(1); - } - - /** - * DEPRECATED: please use data() instead. - * If you need to query dataset dimensions, use the dim() and size() of the cagra index. - * The data_handle() is not always available: you need to do a dynamic_cast to the expected - * dataset type at runtime. - */ - [[nodiscard]] [[deprecated("Use data()")]] inline auto dataset() const noexcept - -> device_matrix_view - { - auto p = dynamic_cast*>(dataset_.get()); - if (p != nullptr) { return p->view(); } - auto d = dataset_->dim(); - return make_device_strided_matrix_view(nullptr, 0, d, d); - } - - /** Dataset [size, dim] */ - [[nodiscard]] inline auto data() const noexcept -> const neighbors::dataset& - { - return *dataset_; - } - - /** neighborhood graph [size, graph-degree] */ - [[nodiscard]] inline auto graph() const noexcept - -> device_matrix_view - { - return graph_view_; - } - - // Don't allow copying the index for performance reasons (try avoiding copying data) - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - ~index() = default; - - /** Construct an empty index. */ - [[deprecated("Use cuVS instead")]] index( - raft::resources const& res, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) - : ann::index(), - metric_(metric), - graph_(make_device_matrix(res, 0, 0)), - dataset_(new neighbors::empty_dataset(0)) - { - } - - /** Construct an index from dataset and knn_graph arrays - * - * If the dataset and graph is already in GPU memory, then the index is just a thin wrapper around - * these that stores a non-owning a reference to the arrays. - * - * The constructor also accepts host arrays. In that case they are copied to the device, and the - * device arrays will be owned by the index. - * - * In case the dasates rows are not 16 bytes aligned, then we create a padded copy in device - * memory to ensure alignment for vectorized load. - * - * Usage examples: - * - * - Cagra index is normally created by the cagra::build - * @code{.cpp} - * using namespace raft::neighbors::experimental; - * auto dataset = raft::make_host_matrix(n_rows, n_cols); - * load_dataset(dataset.view()); - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * // use default search parameters - * cagra::search_params search_params; - * // search K nearest neighbours - * auto neighbors = raft::make_device_matrix(res, n_queries, k); - * auto distances = raft::make_device_matrix(res, n_queries, k); - * cagra::search(res, search_params, index, queries, neighbors, distances); - * @endcode - * In the above example, we have passed a host dataset to build. The returned index will own a - * device copy of the dataset and the knn_graph. In contrast, if we pass the dataset as a - * device_mdspan to build, then it will only store a reference to it. - * - * - Constructing index using existing knn-graph - * @code{.cpp} - * using namespace raft::neighbors::experimental; - * - * auto dataset = raft::make_device_matrix(res, n_rows, n_cols); - * auto knn_graph = raft::make_device_matrix(res, n_rows, graph_degree); - * - * // custom loading and graph creation - * // load_dataset(dataset.view()); - * // create_knn_graph(knn_graph.view()); - * - * // Wrap the existing device arrays into an index structure - * cagra::index index(res, metric, raft::make_const_mdspan(dataset.view()), - * raft::make_const_mdspan(knn_graph.view())); - * - * // Both knn_graph and dataset objects have to be in scope while the index is used because - * // the index only stores a reference to these. - * cagra::search(res, search_params, index, queries, neighbors, distances); - * @endcode - * - */ - template - [[deprecated("Use cuVS instead")]] index( - raft::resources const& res, - raft::distance::DistanceType metric, - mdspan, row_major, data_accessor> dataset, - mdspan, row_major, graph_accessor> knn_graph) - : ann::index(), - metric_(metric), - graph_(make_device_matrix(res, 0, 0)), - dataset_(make_aligned_dataset(res, dataset, 16)) - { - RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), - "Dataset and knn_graph must have equal number of rows"); - update_graph(res, knn_graph); - resource::sync_stream(res); - } - - /** - * Replace the dataset with a new dataset. - * - * If the new dataset rows are aligned on 16 bytes, then only a reference is stored to the - * dataset. It is the caller's responsibility to ensure that dataset stays alive as long as the - * index. - */ - void update_dataset(raft::resources const& res, - raft::device_matrix_view dataset) - { - dataset_ = make_aligned_dataset(res, dataset, 16); - } - - /** Set the dataset reference explicitly to a device matrix view with padding. */ - void update_dataset(raft::resources const& res, - raft::device_matrix_view dataset) - { - dataset_ = make_aligned_dataset(res, dataset, 16); - } - - /** - * Replace the dataset with a new dataset. - * - * We create a copy of the dataset on the device. The index manages the lifetime of this copy. - */ - void update_dataset(raft::resources const& res, - raft::host_matrix_view dataset) - { - dataset_ = make_aligned_dataset(res, dataset, 16); - } - - /** Replace the dataset with a new dataset. */ - template - auto update_dataset(raft::resources const& res, DatasetT&& dataset) - -> std::enable_if_t, DatasetT>> - { - dataset_ = std::make_unique(std::move(dataset)); - } - - template - auto update_dataset(raft::resources const& res, std::unique_ptr&& dataset) - -> std::enable_if_t, DatasetT>> - { - dataset_ = std::move(dataset); - } - - /** - * Replace the graph with a new graph. - * - * Since the new graph is a device array, we store a reference to that, and it is - * the caller's responsibility to ensure that knn_graph stays alive as long as the index. - */ - void update_graph(raft::resources const& res, - raft::device_matrix_view knn_graph) - { - graph_view_ = knn_graph; - } - - /** - * Replace the graph with a new graph. - * - * We create a copy of the graph on the device. The index manages the lifetime of this copy. - */ - void update_graph(raft::resources const& res, - raft::host_matrix_view knn_graph) - { - RAFT_LOG_DEBUG("Copying CAGRA knn graph from host to device"); - if ((graph_.extent(0) != knn_graph.extent(0)) || (graph_.extent(1) != knn_graph.extent(1))) { - // clear existing memory before allocating to prevent OOM errors on large graphs - if (graph_.size()) { graph_ = make_device_matrix(res, 0, 0); } - graph_ = make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1)); - } - raft::copy(graph_.data_handle(), - knn_graph.data_handle(), - knn_graph.size(), - resource::get_cuda_stream(res)); - graph_view_ = graph_.view(); - } - - private: - raft::distance::DistanceType metric_; - raft::device_matrix graph_; - raft::device_matrix_view graph_view_; - std::unique_ptr> dataset_; -}; - -/** @} */ - -} // namespace raft::neighbors::cagra - -// TODO: Remove deprecated experimental namespace in 23.12 release -namespace raft::neighbors::experimental::cagra { -using raft::neighbors::cagra::graph_build_algo; -using raft::neighbors::cagra::hash_mode; -using raft::neighbors::cagra::index; -using raft::neighbors::cagra::index_params; -using raft::neighbors::cagra::search_algo; -using raft::neighbors::cagra::search_params; -} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/dataset.hpp b/cpp/include/raft/neighbors/dataset.hpp deleted file mode 100644 index a6444775f4..0000000000 --- a/cpp/include/raft/neighbors/dataset.hpp +++ /dev/null @@ -1,330 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include -#include // get_device_for_address -#include // rounding up - -#include -#include -#include - -#ifdef __cpp_lib_bitops -#include -#endif - -namespace raft::neighbors { - -/** Two-dimensional dataset; maybe owning, maybe compressed, maybe strided. */ -template -struct dataset { - using index_type = IdxT; - /** Size of the dataset. */ - [[nodiscard]] virtual auto n_rows() const noexcept -> index_type = 0; - /** Dimensionality of the dataset. */ - [[nodiscard]] virtual auto dim() const noexcept -> uint32_t = 0; - /** Whether the object owns the data. */ - [[nodiscard]] virtual auto is_owning() const noexcept -> bool = 0; - virtual ~dataset() noexcept = default; -}; - -template -struct empty_dataset : public dataset { - using index_type = IdxT; - uint32_t suggested_dim; - explicit empty_dataset(uint32_t dim) noexcept : suggested_dim(dim) {} - [[nodiscard]] auto n_rows() const noexcept -> index_type final { return 0; } - [[nodiscard]] auto dim() const noexcept -> uint32_t final { return suggested_dim; } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } -}; - -template -struct strided_dataset : public dataset { - using index_type = IdxT; - using value_type = DataT; - using view_type = device_matrix_view; - [[nodiscard]] auto n_rows() const noexcept -> index_type final { return view().extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t final - { - return static_cast(view().extent(1)); - } - /** Leading dimension of the dataset. */ - [[nodiscard]] constexpr auto stride() const noexcept -> uint32_t - { - auto v = view(); - return static_cast(v.stride(0) > 0 ? v.stride(0) : v.extent(1)); - } - /** Get the view of the data. */ - [[nodiscard]] virtual auto view() const noexcept -> view_type = 0; -}; - -template -struct non_owning_dataset : public strided_dataset { - using index_type = IdxT; - using value_type = DataT; - using typename strided_dataset::view_type; - view_type data; - explicit non_owning_dataset(view_type v) noexcept : data(v) {} - [[nodiscard]] auto is_owning() const noexcept -> bool final { return false; } - [[nodiscard]] auto view() const noexcept -> view_type final { return data; }; -}; - -template -struct owning_dataset : public strided_dataset { - using index_type = IdxT; - using value_type = DataT; - using typename strided_dataset::view_type; - using storage_type = - mdarray, LayoutPolicy, ContainerPolicy>; - using mapping_type = typename view_type::mapping_type; - storage_type data; - mapping_type view_mapping; - owning_dataset(storage_type&& store, mapping_type view_mapping) noexcept - : data{std::move(store)}, view_mapping{view_mapping} - { - } - - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } - [[nodiscard]] auto view() const noexcept -> view_type final - { - return view_type{data.data_handle(), view_mapping}; - }; -}; - -/** - * @brief Contstruct a strided matrix from any mdarray or mdspan. - * - * This function constructs a non-owning view if the input satisfied two conditions: - * - * 1) The data is accessible from the current device - * 2) The memory layout is the same as expected (row-major matrix with the required stride) - * - * Otherwise, this function constructs an owning device matrix and copies the data. - * When the data is copied, padding elements are filled with zeroes. - * - * @tparam SrcT the source mdarray or mdspan - * - * @param[in] res raft resources handle - * @param[in] src the source mdarray or mdspan - * @param[in] required_stride the leading dimension (in elements) - * @return maybe owning current-device-accessible strided matrix - */ -template -auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t required_stride) - -> std::unique_ptr> -{ - using extents_type = typename SrcT::extents_type; - using value_type = typename SrcT::value_type; - using index_type = typename SrcT::index_type; - using layout_type = typename SrcT::layout_type; - static_assert(extents_type::rank() == 2, "The input must be a matrix."); - static_assert(std::is_same_v || - std::is_same_v> || - std::is_same_v, - "The input must be row-major"); - RAFT_EXPECTS(src.extent(1) <= required_stride, - "The input row length must be not larger than the desired stride."); - cudaPointerAttributes ptr_attrs; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&ptr_attrs, src.data_handle())); - auto* device_ptr = reinterpret_cast(ptr_attrs.devicePointer); - const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); - const bool device_accessible = device_ptr != nullptr; - const bool row_major = src.stride(1) <= 1; - const bool stride_matches = required_stride == src_stride; - - if (device_accessible && row_major && stride_matches) { - // Everything matches: make a non-owning dataset - return std::make_unique>( - make_device_strided_matrix_view( - device_ptr, src.extent(0), src.extent(1), required_stride)); - } - // Something is wrong: have to make a copy and produce an owning dataset - auto out_layout = - make_strided_layout(src.extents(), std::array{required_stride, 1}); - auto out_array = make_device_matrix(res, src.extent(0), required_stride); - - using out_mdarray_type = decltype(out_array); - using out_layout_type = typename out_mdarray_type::layout_type; - using out_container_policy_type = typename out_mdarray_type::container_policy_type; - using out_owning_type = - owning_dataset; - - RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), - 0, - out_array.size() * sizeof(value_type), - resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), - sizeof(value_type) * required_stride, - src.data_handle(), - sizeof(value_type) * src_stride, - sizeof(value_type) * src.extent(1), - src.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - - return std::make_unique(std::move(out_array), out_layout); -} - -/** - * @brief Contstruct a strided matrix from any mdarray or mdspan. - * - * A variant `make_strided_dataset` that allows specifying the byte alignment instead of the - * explicit stride length. - * - * @tparam SrcT the source mdarray or mdspan - * - * @param[in] res raft resources handle - * @param[in] src the source mdarray or mdspan - * @param[in] align_bytes the required byte alignment for the dataset rows. - * @return maybe owning current-device-accessible strided matrix - */ -template -auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) - -> std::unique_ptr> -{ - using value_type = typename SrcT::value_type; - constexpr size_t kSize = sizeof(value_type); - uint32_t required_stride = - raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return make_strided_dataset(res, src, required_stride); -} - -/** Parameters for VPQ compression. */ -struct vpq_params { - /** - * The bit length of the vector element after compression by PQ. - * - * Possible values: [4, 5, 6, 7, 8]. - * - * Hint: the smaller the 'pq_bits', the smaller the index size and the better the search - * performance, but the lower the recall. - */ - uint32_t pq_bits = 8; - /** - * The dimensionality of the vector after compression by PQ. - * When zero, an optimal value is selected using a heuristic. - * - * TODO: at the moment `dim` must be a multiple `pq_dim`. - */ - uint32_t pq_dim = 0; - /** - * Vector Quantization (VQ) codebook size - number of "coarse cluster centers". - * When zero, an optimal value is selected using a heuristic. - */ - uint32_t vq_n_centers = 0; - /** The number of iterations searching for kmeans centers (both VQ & PQ phases). */ - uint32_t kmeans_n_iters = 25; - /** - * The fraction of data to use during iterative kmeans building (VQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double vq_kmeans_trainset_fraction = 0; - /** - * The fraction of data to use during iterative kmeans building (PQ phase). - * When zero, an optimal value is selected using a heuristic. - */ - double pq_kmeans_trainset_fraction = 0; -}; - -/** - * @brief VPQ compressed dataset. - * - * The dataset is compressed using two level quantization - * - * 1. Vector Quantization - * 2. Product Quantization of residuals - * - * @tparam MathT the type of elements in the codebooks - * @tparam IdxT type of the vector indices (represent dataset.extent(0)) - * - */ -template -struct vpq_dataset : public dataset { - /** Vector Quantization codebook - "coarse cluster centers". */ - device_matrix vq_code_book; - /** Product Quantization codebook - "fine cluster centers". */ - device_matrix pq_code_book; - /** Compressed dataset. */ - device_matrix data; - - vpq_dataset(device_matrix&& vq_code_book, - device_matrix&& pq_code_book, - device_matrix&& data) - : vq_code_book{std::move(vq_code_book)}, - pq_code_book{std::move(pq_code_book)}, - data{std::move(data)} - { - } - - [[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); } - [[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); } - [[nodiscard]] auto is_owning() const noexcept -> bool final { return true; } - - /** Row length of the encoded data in bytes. */ - [[nodiscard]] constexpr inline auto encoded_row_length() const noexcept -> uint32_t - { - return data.extent(1); - } - /** The number of "coarse cluster centers" */ - [[nodiscard]] constexpr inline auto vq_n_centers() const noexcept -> uint32_t - { - return vq_code_book.extent(0); - } - /** The bit length of an encoded vector element after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_bits() const noexcept -> uint32_t - { - /* - NOTE: pq_bits and the book size - - Normally, we'd store `pq_bits` as a part of the index. - However, we know there's an invariant `pq_n_centers = 1 << pq_bits`, i.e. the codebook size is - the same as the number of possible code values. Hence, we don't store the pq_bits and derive it - from the array dimensions instead. - */ - auto pq_width = pq_n_centers(); -#ifdef __cpp_lib_bitops - return std::countr_zero(pq_width); -#else - uint32_t pq_bits = 0; - while (pq_width > 1) { - pq_bits++; - pq_width >>= 1; - } - return pq_bits; -#endif - } - /** The dimensionality of an encoded vector after compression by PQ. */ - [[nodiscard]] constexpr inline auto pq_dim() const noexcept -> uint32_t - { - return raft::div_rounding_up_unsafe(dim(), pq_len()); - } - /** Dimensionality of a subspaces, i.e. the number of vector components mapped to a subspace */ - [[nodiscard]] constexpr inline auto pq_len() const noexcept -> uint32_t - { - return pq_code_book.extent(1); - } - /** The number of vectors in a PQ codebook (`1 << pq_bits`). */ - [[nodiscard]] constexpr inline auto pq_n_centers() const noexcept -> uint32_t - { - return pq_code_book.extent(0); - } -}; - -} // namespace raft::neighbors diff --git a/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp b/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp deleted file mode 100644 index 1d4c77af6f..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp +++ /dev/null @@ -1,227 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include - -namespace raft::neighbors::cagra::detail { -namespace bitonic { - -namespace detail { - -template -_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool asc) -{ - if ((k0 != k1) && ((k0 < k1) != asc)) { - const auto tmp_k = k0; - k0 = k1; - k1 = tmp_k; - const auto tmp_v = v0; - v0 = v1; - v1 = tmp_v; - } -} - -template -_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, const unsigned lane_offset, const bool asc) -{ - auto k1 = __shfl_xor_sync(~0u, k0, lane_offset); - auto v1 = __shfl_xor_sync(~0u, v0, lane_offset); - if ((k0 != k1) && ((k0 < k1) != asc)) { - k0 = k1; - v0 = v1; - } -} - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[N], V v[N], const std::uint32_t range, const bool asc) - { - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - for (std::uint32_t b = 2; b <= N; b <<= 1) { - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - std::uint32_t j = i ^ c; - if (i >= j) continue; - const auto line_id = i + (N * lane_id); - const auto p = static_cast(line_id & b) == static_cast(line_id & c); - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - } - } - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - for (std::uint32_t c = N / 2; c >= 1; c >>= 1) { -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - std::uint32_t j = i ^ c; - if (i >= j) continue; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - } - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[6], V v[6], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 6; - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - for (std::uint32_t i = 0; i < N; i += 3) { - const auto p = (i == 0); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - const auto p = ((lane_id & 1) == 0); - for (std::uint32_t i = 0; i < 3; i++) { - std::uint32_t j = i + 3; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - for (std::uint32_t i = 0; i < N; i += 3) { - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - for (std::uint32_t i = 0; i < 3; i++) { - std::uint32_t j = i + 3; - swap_if_needed(k[i], v[i], k[j], v[j], p); - } - for (std::uint32_t i = 0; i < N; i += N / 2) { - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - swap_if_needed(k[1 + i], v[1 + i], k[2 + i], v[2 + i], p); - swap_if_needed(k[0 + i], v[0 + i], k[1 + i], v[1 + i], p); - } - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[3], V v[3], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 3; - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - const auto p = ((lane_id & 1) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - swap_if_needed(k[1], v[1], k[2], v[2], p); - swap_if_needed(k[0], v[0], k[1], v[1], p); - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - swap_if_needed(k[1], v[1], k[2], v[2], p); - swap_if_needed(k[0], v[0], k[1], v[1], p); - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[2], V v[2], const std::uint32_t range, const bool asc) - { - constexpr unsigned N = 2; - const auto lane_id = threadIdx.x % warp_size; - - if (range == 1) { - const auto p = ((lane_id & 1) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - return; - } - - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); -#pragma unroll - for (std::uint32_t i = 0; i < N; i++) { - swap_if_needed(k[i], v[i], c, p); - } - } - const auto p = ((lane_id & b) == 0); - swap_if_needed(k[0], v[0], k[1], v[1], p); - } -}; - -template -struct warp_merge_core { - _RAFT_DEVICE inline void operator()(K k[1], V v[1], const std::uint32_t range, const bool asc) - { - const auto lane_id = threadIdx.x % warp_size; - const std::uint32_t b = range; - for (std::uint32_t c = b / 2; c >= 1; c >>= 1) { - const auto p = static_cast(lane_id & b) == static_cast(lane_id & c); - swap_if_needed(k[0], v[0], c, p); - } - } -}; - -} // namespace detail - -template -__device__ void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true) -{ - detail::warp_merge_core{}(k, v, range, asc); -} - -template -__device__ void warp_sort(K k[N], V v[N], const bool asc = true) -{ - for (std::uint32_t range = 1; range <= warp_size; range <<= 1) { - warp_merge(k, v, range, asc); - } -} - -} // namespace bitonic -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh deleted file mode 100644 index 40dcf68e68..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ /dev/null @@ -1,366 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../../cagra_types.hpp" -#include "../../vpq_dataset.cuh" -#include "graph_core.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::cagra::detail { - -template -void build_knn_graph(raft::resources const& res, - mdspan, row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - std::optional refine_rate = std::nullopt, - std::optional build_params = std::nullopt, - std::optional search_params = std::nullopt) -{ - RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded || - build_params->metric == distance::DistanceType::InnerProduct, - "Currently only L2Expanded or InnerProduct metric are supported"); - - uint32_t node_degree = knn_graph.extent(1); - common::nvtx::range fun_scope("cagra::build_graph(%zu, %zu, %u)", - size_t(dataset.extent(0)), - size_t(dataset.extent(1)), - node_degree); - - if (!build_params) { build_params = ivf_pq::index_params::from_dataset(dataset); } - - // Make model name - const std::string model_name = [&]() { - char model_name[1024]; - sprintf(model_name, - "%s-%lux%lu.cluster_%u.pq_%u.%ubit.itr_%u.metric_%u.pqcenter_%u", - "IVF-PQ", - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1)), - build_params->n_lists, - build_params->pq_dim, - build_params->pq_bits, - build_params->kmeans_n_iters, - build_params->metric, - static_cast(build_params->codebook_kind)); - return std::string(model_name); - }(); - - RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str()); - auto index = ivf_pq::build( - res, *build_params, dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - // - // search top (k + 1) neighbors - // - if (!search_params) { - search_params = ivf_pq::search_params{}; - search_params->n_probes = std::min(dataset.extent(1) * 2, build_params->n_lists); - search_params->lut_dtype = CUDA_R_8U; - search_params->internal_distance_dtype = CUDA_R_32F; - } - const auto top_k = node_degree + 1; - uint32_t gpu_top_k = node_degree * refine_rate.value_or(2.0f); - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); - const auto num_queries = dataset.extent(0); - const auto max_batch_size = 1024; - RAFT_LOG_DEBUG( - "IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u", - node_degree, - top_k, - gpu_top_k, - max_batch_size, - search_params->n_probes); - - auto distances = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto neighbors = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto refined_distances = raft::make_device_matrix(res, max_batch_size, top_k); - auto refined_neighbors = raft::make_device_matrix(res, max_batch_size, top_k); - auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); - auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); - auto refined_neighbors_host = raft::make_host_matrix(max_batch_size, top_k); - auto refined_distances_host = raft::make_host_matrix(max_batch_size, top_k); - - // TODO(tfeher): batched search with multiple GPUs - std::size_t num_self_included = 0; - bool first = true; - const auto start_clock = std::chrono::system_clock::now(); - - rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(res); - - raft::spatial::knn::detail::utils::batch_load_iterator vec_batches( - dataset.data_handle(), - dataset.extent(0), - dataset.extent(1), - max_batch_size, - resource::get_cuda_stream(res), - device_memory); - - size_t next_report_offset = 0; - size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps. - - for (const auto& batch : vec_batches) { - // Map int64_t to uint32_t because ivf_pq requires the latter. - // TODO(tfeher): remove this mapping once ivf_pq accepts mdspan with int64_t index type - auto queries_view = raft::make_device_matrix_view( - batch.data(), batch.size(), batch.row_width()); - auto neighbors_view = make_device_matrix_view( - neighbors.data_handle(), batch.size(), neighbors.extent(1)); - auto distances_view = make_device_matrix_view( - distances.data_handle(), batch.size(), distances.extent(1)); - - ivf_pq::search(res, *search_params, index, queries_view, neighbors_view, distances_view); - if constexpr (is_host_mdspan_v) { - raft::copy(neighbors_host.data_handle(), - neighbors.data_handle(), - neighbors_view.size(), - resource::get_cuda_stream(res)); - raft::copy(queries_host.data_handle(), - batch.data(), - queries_view.size(), - resource::get_cuda_stream(res)); - auto queries_host_view = make_host_matrix_view( - queries_host.data_handle(), batch.size(), batch.row_width()); - auto neighbors_host_view = make_host_matrix_view( - neighbors_host.data_handle(), batch.size(), neighbors.extent(1)); - auto refined_neighbors_host_view = make_host_matrix_view( - refined_neighbors_host.data_handle(), batch.size(), top_k); - auto refined_distances_host_view = make_host_matrix_view( - refined_distances_host.data_handle(), batch.size(), top_k); - resource::sync_stream(res); - - raft::neighbors::detail::refine_host( - dataset, - queries_host_view, - neighbors_host_view, - refined_neighbors_host_view, - refined_distances_host_view, - build_params->metric); - } else { - auto neighbor_candidates_view = make_device_matrix_view( - neighbors.data_handle(), batch.size(), gpu_top_k); - auto refined_neighbors_view = make_device_matrix_view( - refined_neighbors.data_handle(), batch.size(), top_k); - auto refined_distances_view = make_device_matrix_view( - refined_distances.data_handle(), batch.size(), top_k); - - auto dataset_view = make_device_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - raft::neighbors::detail::refine_device( - res, - dataset_view, - queries_view, - neighbor_candidates_view, - refined_neighbors_view, - refined_distances_view, - build_params->metric); - raft::copy(refined_neighbors_host.data_handle(), - refined_neighbors_view.data_handle(), - refined_neighbors_view.size(), - resource::get_cuda_stream(res)); - resource::sync_stream(res); - } - // omit itself & write out - // TODO(tfeher): do this in parallel with GPU processing of next batch - for (std::size_t i = 0; i < batch.size(); i++) { - size_t vec_idx = i + batch.offset(); - for (std::size_t j = 0, num_added = 0; j < top_k && num_added < node_degree; j++) { - const auto v = refined_neighbors_host(i, j); - if (static_cast(v) == vec_idx) { - num_self_included++; - continue; - } - knn_graph(vec_idx, num_added) = v; - num_added++; - } - } - - size_t num_queries_done = batch.offset() + batch.size(); - const auto end_clock = std::chrono::system_clock::now(); - if (batch.offset() > next_report_offset) { - next_report_offset += d_report_offset; - const auto time = - std::chrono::duration_cast(end_clock - start_clock).count() * - 1e-6; - const auto throughput = num_queries_done / time; - - RAFT_LOG_DEBUG( - "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " - "%3.2f %% \r", - num_queries_done, - dataset.extent(0), - num_queries_done / static_cast(dataset.extent(0)) * 100, - throughput, - (num_queries - num_queries_done) / throughput / 60, - static_cast(num_self_included) / num_queries_done * 100.); - } - first = false; - } - - if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); -} - -template -void build_knn_graph(raft::resources const& res, - mdspan, row_major, accessor> dataset, - raft::host_matrix_view knn_graph, - experimental::nn_descent::index_params build_params) -{ - auto nn_descent_idx = experimental::nn_descent::index(res, knn_graph); - experimental::nn_descent::build(res, build_params, dataset, nn_descent_idx); - - using internal_IdxT = typename std::make_unsigned::type; - using g_accessor = typename decltype(nn_descent_idx.graph())::accessor_type; - using g_accessor_internal = - host_device_accessor, g_accessor::mem_type>; - - auto knn_graph_internal = - mdspan, row_major, g_accessor_internal>( - reinterpret_cast(nn_descent_idx.graph().data_handle()), - nn_descent_idx.graph().extent(0), - nn_descent_idx.graph().extent(1)); - - graph::sort_knn_graph(res, dataset, knn_graph_internal); -} - -template , memory_type::host>> -void optimize(raft::resources const& res, - mdspan, row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - using internal_IdxT = typename std::make_unsigned::type; - - auto new_graph_internal = raft::make_host_matrix_view( - reinterpret_cast(new_graph.data_handle()), - new_graph.extent(0), - new_graph.extent(1)); - - using g_accessor_internal = - host_device_accessor, memory_type::host>; - auto knn_graph_internal = - mdspan, row_major, g_accessor_internal>( - reinterpret_cast(knn_graph.data_handle()), - knn_graph.extent(0), - knn_graph.extent(1)); - - cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal); -} - -template , memory_type::host>> -index build( - raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset, - std::optional nn_descent_params = std::nullopt, - std::optional refine_rate = std::nullopt, - std::optional pq_build_params = std::nullopt, - std::optional search_params = std::nullopt, - bool construct_index_with_dataset = true) -{ - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - std::optional> knn_graph( - raft::make_host_matrix(dataset.extent(0), intermediate_degree)); - - if (params.build_algo == graph_build_algo::IVF_PQ) { - build_knn_graph(res, dataset, knn_graph->view(), refine_rate, pq_build_params, search_params); - } else { - RAFT_EXPECTS( - params.metric == raft::distance::DistanceType::L2Expanded, - "L2Expanded is the only distance metrics supported for CAGRA build with nn_descent"); - // Use nn-descent to build CAGRA knn graph - if (!nn_descent_params) { - nn_descent_params = experimental::nn_descent::index_params(); - nn_descent_params->graph_degree = intermediate_degree; - nn_descent_params->intermediate_graph_degree = 1.5 * intermediate_degree; - nn_descent_params->max_iterations = params.nn_descent_niter; - } - build_knn_graph(res, dataset, knn_graph->view(), *nn_descent_params); - } - - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); - - RAFT_LOG_INFO("optimizing graph"); - optimize(res, knn_graph->view(), cagra_graph.view()); - - // free intermediate graph before trying to create the index - knn_graph.reset(); - - RAFT_LOG_INFO("Graph optimized, creating index"); - // Construct an index from dataset and optimized knn graph. - if (construct_index_with_dataset) { - if (params.compression.has_value()) { - RAFT_EXPECTS(params.metric == raft::distance::DistanceType::L2Expanded, - "VPQ compression is only supported with L2Expanded distance mertric"); - index idx(res, params.metric); - idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); - idx.update_dataset( - res, - // TODO: hardcoding codebook math to `half`, we can do runtime dispatching later - neighbors::vpq_build(res, *params.compression, dataset)); - return idx; - } - return index(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view())); - } else { - // We just add the graph. User is expected to update dataset separately. This branch is used - // if user needs special control of memory allocations for the dataset. - index idx(res, params.metric); - idx.update_graph(res, raft::make_const_mdspan(cagra_graph.view())); - return idx; - } -} -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh deleted file mode 100644 index 67fad2e46a..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ /dev/null @@ -1,328 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "compute_distance_vpq.cuh" -#include "factory.cuh" -#include "search_plan.cuh" -#include "search_single_cta.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::cagra::detail { - -template -struct CagraSampleFilterWithQueryIdOffset { - const uint32_t offset; - CagraSampleFilterT filter; - - CagraSampleFilterWithQueryIdOffset(const uint32_t offset, const CagraSampleFilterT filter) - : offset(offset), filter(filter) - { - } - - _RAFT_DEVICE auto operator()(const uint32_t query_id, const uint32_t sample_id) - { - return filter(query_id + offset, sample_id); - } -}; - -template -struct CagraSampleFilterT_Selector { - using type = CagraSampleFilterWithQueryIdOffset; -}; -template <> -struct CagraSampleFilterT_Selector { - using type = raft::neighbors::filtering::none_cagra_sample_filter; -}; - -// A helper function to set a query id offset -template -inline typename CagraSampleFilterT_Selector::type set_offset( - CagraSampleFilterT filter, const uint32_t offset) -{ - typename CagraSampleFilterT_Selector::type new_filter(offset, filter); - return new_filter; -} -template <> -inline - typename CagraSampleFilterT_Selector::type - set_offset( - raft::neighbors::filtering::none_cagra_sample_filter filter, const uint32_t) -{ - return filter; -} - -template -void search_main_core( - raft::resources const& res, - search_params params, - DatasetDescriptorT dataset_desc, - raft::device_matrix_view graph, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT(), - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded) -{ - RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n", - static_cast(dataset_desc.size), - static_cast(dataset_desc.dim)); - RAFT_LOG_DEBUG("# query size = %lu, dim = %lu\n", - static_cast(queries.extent(0)), - static_cast(queries.extent(1))); - RAFT_EXPECTS(queries.extent(1) == dataset_desc.dim, "Queries and index dim must match"); - const uint32_t topk = neighbors.extent(1); - - cudaDeviceProp deviceProp = resource::get_device_properties(res); - if (params.max_queries == 0) { - params.max_queries = std::min(queries.extent(0), deviceProp.maxGridSize[1]); - } - - common::nvtx::range fun_scope( - "cagra::search(max_queries = %u, k = %u, dim = %zu)", - params.max_queries, - topk, - dataset_desc.dim); - - using CagraSampleFilterT_s = typename CagraSampleFilterT_Selector::type; - std::unique_ptr> plan = - factory::create( - res, params, dataset_desc.dim, graph.extent(1), topk, metric); - - plan->check(topk); - - RAFT_LOG_DEBUG("Cagra search"); - const uint32_t max_queries = plan->max_queries; - const uint32_t query_dim = queries.extent(1); - - for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) { - const uint32_t n_queries = std::min(max_queries, queries.extent(0) - qid); - auto _topk_indices_ptr = - reinterpret_cast(neighbors.data_handle()) + - (topk * qid); - auto _topk_distances_ptr = distances.data_handle() + (topk * qid); - // todo(tfeher): one could keep distances optional and pass nullptr - const auto* _query_ptr = queries.data_handle() + (query_dim * qid); - const auto* _seed_ptr = - plan->num_seeds > 0 - ? reinterpret_cast(plan->dev_seed.data()) + - (plan->num_seeds * qid) - : nullptr; - uint32_t* _num_executed_iterations = nullptr; - - (*plan)(res, - dataset_desc, - graph, - _topk_indices_ptr, - _topk_distances_ptr, - _query_ptr, - n_queries, - _seed_ptr, - _num_executed_iterations, - topk, - set_offset(sample_filter, qid)); - } -} - -template -void launch_vpq_search_main_core( - raft::resources const& res, - const vpq_dataset* vpq_dset, - search_params params, - raft::device_matrix_view graph, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter, - const raft::distance::DistanceType metric) -{ - RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now"); - RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4, - "Only pq_len 2 or 4 is supported for now"); - RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0, - "dim must be a multiple of pq_dim at the moment"); - - const float vq_scale = 1.0f; - const float pq_scale = 1.0f; - - if (vpq_dset->pq_bits() == 8) { - if (vpq_dset->pq_len() == 2) { - using dataset_desc_t = cagra_q_dataset_descriptor_t; - dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), - vpq_dset->encoded_row_length(), - vpq_dset->pq_dim(), - vpq_dset->vq_code_book.data_handle(), - vq_scale, - vpq_dset->pq_code_book.data_handle(), - pq_scale, - size_t(vpq_dset->n_rows()), - vpq_dset->dim()); - search_main_core( - res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric); - } else if (vpq_dset->pq_len() == 4) { - using dataset_desc_t = cagra_q_dataset_descriptor_t; - dataset_desc_t dataset_desc(vpq_dset->data.data_handle(), - vpq_dset->encoded_row_length(), - vpq_dset->pq_dim(), - vpq_dset->vq_code_book.data_handle(), - vq_scale, - vpq_dset->pq_code_book.data_handle(), - pq_scale, - size_t(vpq_dset->n_rows()), - vpq_dset->dim()); - search_main_core( - res, params, dataset_desc, graph, queries, neighbors, distances, sample_filter, metric); - } else { - RAFT_FAIL("Subspace dimension must be 2 or 4"); - } - } else { - RAFT_FAIL("Only 8-bit PQ is supported now"); - } -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [build](#build) documentation for a usage example. - * - * @tparam T data element type - * @tparam IdxT type of database vector indices - * @tparam internal_IdxT during search we map IdxT to internal_IdxT, this way we do not need - * separate kernels for int/uint. - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] idx ivf-pq constructed index - * @param[in] queries a device matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries, - * k] - */ -template -void search_main(raft::resources const& res, - search_params params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - CagraSampleFilterT sample_filter = CagraSampleFilterT()) -{ - const auto& graph = index.graph(); - auto graph_internal = raft::make_device_matrix_view( - reinterpret_cast(graph.data_handle()), graph.extent(0), graph.extent(1)); - - // n_rows has the same type as the dataset index (the array extents type) - using ds_idx_type = decltype(index.data().n_rows()); - // Dispatch search parameters based on the dataset kind. - if (auto* strided_dset = dynamic_cast*>(&index.data()); - strided_dset != nullptr) { - // Set TEAM_SIZE and DATASET_BLOCK_SIZE to zero tentatively since these parameters cannot be - // determined here. They are set just before kernel launch. - using dataset_desc_t = standard_dataset_descriptor_t; - // Search using a plain (strided) row-major dataset - const dataset_desc_t dataset_desc(strided_dset->view().data_handle(), - strided_dset->n_rows(), - strided_dset->dim(), - strided_dset->stride()); - search_main_core(res, - params, - dataset_desc, - graph_internal, - queries, - neighbors, - distances, - sample_filter, - index.metric()); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); - vpq_dset != nullptr) { - // Search using a compressed dataset - RAFT_FAIL("FP32 VPQ dataset support is coming soon"); - } else if (auto* vpq_dset = dynamic_cast*>(&index.data()); - vpq_dset != nullptr) { - launch_vpq_search_main_core( - res, - vpq_dset, - params, - graph_internal, - queries, - neighbors, - distances, - sample_filter, - index.metric()); - } else if (auto* empty_dset = dynamic_cast*>(&index.data()); - empty_dset != nullptr) { - // Forgot to add a dataset. - RAFT_FAIL( - "Attempted to search without a dataset. Please call index.update_dataset(...) first."); - } else { - // This is a logic error. - RAFT_FAIL("Unrecognized dataset format"); - } - - static_assert(std::is_same_v, - "only float distances are supported at the moment"); - float* dist_out = distances.data_handle(); - const DistanceT* dist_in = distances.data_handle(); - // We're converting the data from T to DistanceT during distance computation - // and divide the values by kDivisor. Here we restore the original scale. - constexpr float kScale = spatial::knn::detail::utils::config::kDivisor / - spatial::knn::detail::utils::config::kDivisor; - ivf::detail::postprocess_distances(dist_out, - dist_in, - index.metric(), - distances.extent(0), - distances.extent(1), - kScale, - true, - resource::get_cuda_stream(res)); -} -/** @} */ // end group cagra - -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh deleted file mode 100644 index 600c8785e0..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ /dev/null @@ -1,270 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { - -constexpr int serialization_version = 4; - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] res the raft resource handle - * @param[in] filename the file name for saving the index - * @param[in] index_ CAGRA index - * - */ -template -void serialize(raft::resources const& res, - std::ostream& os, - const index& index_, - bool include_dataset) -{ - common::nvtx::range fun_scope("cagra::serialize"); - - RAFT_LOG_DEBUG( - "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - - std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); - dtype_string.resize(4); - os << dtype_string; - - serialize_scalar(res, os, serialization_version); - serialize_scalar(res, os, index_.size()); - serialize_scalar(res, os, index_.dim()); - serialize_scalar(res, os, index_.graph_degree()); - serialize_scalar(res, os, index_.metric()); - serialize_mdspan(res, os, index_.graph()); - - include_dataset &= (index_.data().n_rows() > 0); - - serialize_scalar(res, os, include_dataset); - if (include_dataset) { - RAFT_LOG_INFO("Saving CAGRA index with dataset"); - neighbors::detail::serialize(res, os, index_.data()); - } else { - RAFT_LOG_DEBUG("Saving CAGRA index WITHOUT dataset"); - } -} - -template -void serialize(raft::resources const& res, - const std::string& filename, - const index& index_, - bool include_dataset) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(res, of, index_, include_dataset); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -template -void serialize_to_hnswlib(raft::resources const& res, - std::ostream& os, - const raft::neighbors::cagra::index& index_) -{ - // static_assert(std::is_same_v or std::is_same_v, - // "An hnswlib index can only be trained with int32 or uint32 IdxT"); - common::nvtx::range fun_scope("cagra::serialize"); - RAFT_LOG_DEBUG("Saving CAGRA index to hnswlib format, size %zu, dim %u", - static_cast(index_.size()), - index_.dim()); - - // offset_level_0 - std::size_t offset_level_0 = 0; - os.write(reinterpret_cast(&offset_level_0), sizeof(std::size_t)); - // max_element - std::size_t max_element = index_.size(); - os.write(reinterpret_cast(&max_element), sizeof(std::size_t)); - // curr_element_count - std::size_t curr_element_count = index_.size(); - os.write(reinterpret_cast(&curr_element_count), sizeof(std::size_t)); - // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t, - // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) + - // dim * sizeof(data_t) + sizeof(labeltype) - auto size_data_per_element = static_cast(index_.graph_degree() * sizeof(IdxT) + 4 + - index_.dim() * sizeof(T) + 8); - os.write(reinterpret_cast(&size_data_per_element), sizeof(std::size_t)); - // label_offset - std::size_t label_offset = size_data_per_element - 8; - os.write(reinterpret_cast(&label_offset), sizeof(std::size_t)); - // offset_data - auto offset_data = static_cast(index_.graph_degree() * sizeof(IdxT) + 4); - os.write(reinterpret_cast(&offset_data), sizeof(std::size_t)); - // max_level - int max_level = 1; - os.write(reinterpret_cast(&max_level), sizeof(int)); - // entrypoint_node - auto entrypoint_node = static_cast(index_.size() / 2); - os.write(reinterpret_cast(&entrypoint_node), sizeof(int)); - // max_M - auto max_M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&max_M), sizeof(std::size_t)); - // max_M0 - std::size_t max_M0 = index_.graph_degree(); - os.write(reinterpret_cast(&max_M0), sizeof(std::size_t)); - // M - auto M = static_cast(index_.graph_degree() / 2); - os.write(reinterpret_cast(&M), sizeof(std::size_t)); - // mult, can be anything - double mult = 0.42424242; - os.write(reinterpret_cast(&mult), sizeof(double)); - // efConstruction, can be anything - std::size_t efConstruction = 500; - os.write(reinterpret_cast(&efConstruction), sizeof(std::size_t)); - - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - - auto graph = index_.graph(); - auto host_graph = - raft::make_host_matrix(graph.extent(0), graph.extent(1)); - raft::copy(host_graph.data_handle(), - graph.data_handle(), - graph.size(), - raft::resource::get_cuda_stream(res)); - resource::sync_stream(res); - - // Write one dataset and graph row at a time - for (std::size_t i = 0; i < index_.size(); i++) { - auto graph_degree = static_cast(index_.graph_degree()); - os.write(reinterpret_cast(&graph_degree), sizeof(int)); - - for (std::size_t j = 0; j < index_.graph_degree(); ++j) { - auto graph_elem = host_graph(i, j); - os.write(reinterpret_cast(&graph_elem), sizeof(IdxT)); - } - - auto data_row = host_dataset.data_handle() + (index_.dim() * i); - // if constexpr (std::is_same_v) { - for (std::size_t j = 0; j < index_.dim(); ++j) { - auto data_elem = host_dataset(i, j); - os.write(reinterpret_cast(&data_elem), sizeof(T)); - } - // } else if constexpr (std::is_same_v or std::is_same_v) { - // for (std::size_t j = 0; j < index_.dim(); ++j) { - // auto data_elem = static_cast(host_dataset(i, j)); - // os.write(reinterpret_cast(&data_elem), sizeof(int)); - // } - // } - - os.write(reinterpret_cast(&i), sizeof(std::size_t)); - } - - for (std::size_t i = 0; i < index_.size(); i++) { - // zeroes - auto zero = 0; - os.write(reinterpret_cast(&zero), sizeof(int)); - } -} - -template -void serialize_to_hnswlib(raft::resources const& res, - const std::string& filename, - const raft::neighbors::cagra::index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize_to_hnswlib(res, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] res the raft resource handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ CAGRA index - * - */ -template -auto deserialize(raft::resources const& res, std::istream& is) -> index -{ - common::nvtx::range fun_scope("cagra::deserialize"); - - char dtype_string[4]; - is.read(dtype_string, 4); - - auto ver = deserialize_scalar(res, is); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto graph_degree = deserialize_scalar(res, is); - auto metric = deserialize_scalar(res, is); - - auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, graph.view()); - - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); - bool has_dataset = deserialize_scalar(res, is); - if (has_dataset) { - idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); - } - return idx; -} - -template -auto deserialize(raft::resources const& res, const std::string& filename) -> index -{ - std::ifstream is(filename, std::ios::in | std::ios::binary); - - if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(res, is); - - is.close(); - - return index; -} -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp deleted file mode 100644 index 80ee7a36f1..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ /dev/null @@ -1,316 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "device_common.hpp" -#include "hashmap.hpp" -#include "utils.hpp" - -#include -#include -#include -#include - -#include - -namespace raft::neighbors::cagra::detail { -namespace device { - -// using LOAD_256BIT_T = ulonglong4; -using LOAD_128BIT_T = uint4; -using LOAD_64BIT_T = uint64_t; - -template -_RAFT_DEVICE constexpr unsigned get_vlen() -{ - return utils::size_of() / utils::size_of(); -} - -template -_RAFT_DEVICE void compute_distance_to_random_nodes( - INDEX_T* const result_indices_ptr, // [num_pickup] - DISTANCE_T* const result_distances_ptr, // [num_pickup] - const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer, - const DATASET_DESCRIPTOR_T& dataset_desc, - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const INDEX_T* const seed_ptr, // [num_seeds] - const uint32_t num_seeds, - INDEX_T* const visited_hash_ptr, - const uint32_t hash_bitlen, - const raft::distance::DistanceType metric, - const uint32_t block_id = 0, - const uint32_t num_blocks = 1) -{ - uint32_t max_i = num_pickup; - if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } - - for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) { - const bool valid_i = (i < num_pickup); - - INDEX_T best_index_team_local; - DISTANCE_T best_norm2_team_local = utils::get_max_value(); - for (uint32_t j = 0; j < num_distilation; j++) { - // Select a node randomly and compute the distance to it - INDEX_T seed_index; - if (valid_i) { - // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); - uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); - if (seed_ptr && (gid < num_seeds)) { - seed_index = seed_ptr[gid]; - } else { - seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_desc.size; - } - } - - DISTANCE_T norm2; - switch (metric) { - case raft::distance::L2Expanded: - norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, valid_i); - break; - case raft::distance::InnerProduct: - norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, valid_i); - break; - default: break; - } - - if (valid_i && (norm2 < best_norm2_team_local)) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - if (valid_i && lane_id == 0) { - if (hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { - result_distances_ptr[i] = best_norm2_team_local; - result_indices_ptr[i] = best_index_team_local; - } else { - result_distances_ptr[i] = utils::get_max_value(); - result_indices_ptr[i] = utils::get_max_value(); - } - } - } -} - -template -_RAFT_DEVICE void compute_distance_to_child_nodes( - INDEX_T* const result_child_indices_ptr, - DISTANCE_T* const result_child_distances_ptr, - // query - const typename DATASET_DESCRIPTOR_T::QUERY_T* const query_buffer, - // [dataset_dim, dataset_size] - const DATASET_DESCRIPTOR_T& dataset_desc, - // [knn_k, dataset_size] - const INDEX_T* const knn_graph, - const std::uint32_t knn_k, - // hashmap - INDEX_T* const visited_hashmap_ptr, - const std::uint32_t hash_bitlen, - const INDEX_T* const parent_indices, - const INDEX_T* const internal_topk_list, - const std::uint32_t search_width, - const raft::distance::DistanceType metric) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - // Read child indices of parents from knn graph and check if the distance - // computaiton is necessary. - for (uint32_t i = threadIdx.x; i < knn_k * search_width; i += blockDim.x) { - const INDEX_T smem_parent_id = parent_indices[i / knn_k]; - INDEX_T child_id = invalid_index; - if (smem_parent_id != invalid_index) { - const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask; - child_id = knn_graph[(i % knn_k) + (static_cast(knn_k) * parent_id)]; - } - if (child_id != invalid_index) { - if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { - child_id = invalid_index; - } - } - result_child_indices_ptr[i] = child_id; - } - __syncthreads(); - - // Compute the distance to child nodes - std::uint32_t max_i = knn_k * search_width; - if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } - for (std::uint32_t tid = threadIdx.x; tid < max_i * TEAM_SIZE; tid += blockDim.x) { - const auto i = tid / TEAM_SIZE; - const bool valid_i = (i < (knn_k * search_width)); - INDEX_T child_id = invalid_index; - if (valid_i) { child_id = result_child_indices_ptr[i]; } - - DISTANCE_T norm2; - switch (metric) { - case raft::distance::L2Expanded: - norm2 = - dataset_desc - .template compute_similarity( - query_buffer, child_id, child_id != invalid_index); - break; - case raft::distance::InnerProduct: - norm2 = dataset_desc.template compute_similarity( - query_buffer, child_id, child_id != invalid_index); - break; - default: break; - } - - // Store the distance - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - if (valid_i && lane_id == 0) { - if (child_id != invalid_index) { - result_child_distances_ptr[i] = norm2; - } else { - result_child_distances_ptr[i] = utils::get_max_value(); - } - } - } -} - -} // namespace device - -template -struct dataset_descriptor_base_t { - using INDEX_T = INDEX_T_; - using QUERY_T = QUERY_T_; - using DISTANCE_T = DISTANCE_T_; - - const INDEX_T size; - const std::uint32_t dim; - - dataset_descriptor_base_t(const INDEX_T size, const std::uint32_t dim) : size(size), dim(dim) {} -}; - -template -struct standard_dataset_descriptor_t - : public dataset_descriptor_base_t { - using LOAD_T = device::LOAD_128BIT_T; - using DATA_T = DATA_T_; - using QUERY_T = typename dataset_descriptor_base_t::QUERY_T; - - const DATA_T* const ptr; - const std::size_t ld; - using dataset_descriptor_base_t::size; - using dataset_descriptor_base_t::dim; - - standard_dataset_descriptor_t(const DATA_T* const ptr, - const std::size_t size, - const std::uint32_t dim, - const std::size_t ld) - : dataset_descriptor_base_t(size, dim), ptr(ptr), ld(ld) - { - } - - static const std::uint32_t smem_buffer_size_in_byte = 0; - __device__ void set_smem_ptr(void* const){}; - - template - __device__ void copy_query(const DATA_T* const dmem_query_ptr, - QUERY_T* const smem_query_ptr, - const std::uint32_t query_smem_buffer_length) - { - for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dim) { - smem_query_ptr[j] = spatial::knn::detail::utils::mapping{}(dmem_query_ptr[i]); - } else { - smem_query_ptr[j] = 0.0; - } - } - } - - template - std::enable_if_t __device__ - dist_op(T a, T b) const - { - T diff = a - b; - return diff * diff; - } - - template - std::enable_if_t __device__ - dist_op(T a, T b) const - { - return -a * b; - } - - template - __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, - const INDEX_T dataset_i, - const bool valid) const - { - const auto dataset_ptr = ptr + dataset_i * ld; - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - constexpr unsigned vlen = device::get_vlen(); - // #include (DATASET_BLOCK_DIM, TEAM_SIZE * vlen); - raft::TxN_t dl_buff[reg_nelem]; - - DISTANCE_T norm2 = 0; - if (valid) { - for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) { -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; - if (k >= dim) break; - dl_buff[e].load(dataset_ptr, k); - } -#pragma unroll - for (uint32_t e = 0; e < reg_nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; - if (k >= dim) break; -#pragma unroll - for (uint32_t v = 0; v < vlen; v++) { - const uint32_t kv = k + v; - // Note this loop can go above the dataset_dim for padded arrays. This is not a problem - // because: - // - Above the last element (dataset_dim-1), the query array is filled with zeros. - // - The data buffer has to be also padded with zeros. - DISTANCE_T d = query_ptr[device::swizzling(kv)]; - norm2 += dist_op( - d, spatial::knn::detail::utils::mapping{}(dl_buff[e].val.data[v])); - } - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } - return norm2; - } -}; - -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh deleted file mode 100644 index c922a0d7f4..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "compute_distance.hpp" - -#include -#include - -namespace raft::neighbors::cagra::detail { -template -struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t { - using LOAD_T = device::LOAD_128BIT_T; - using DATA_T = DATA_T_; - using CODE_BOOK_T = CODE_BOOK_T_; - using QUERY_T = typename dataset_descriptor_base_t::QUERY_T; - - static_assert(std::is_same_v, "Only CODE_BOOK_T = `half` is supported now"); - - const std::uint8_t* encoded_dataset_ptr; - const std::uint32_t encoded_dataset_dim; - const std::uint32_t n_subspace; - const CODE_BOOK_T* vq_code_book_ptr; - const float vq_scale; - const CODE_BOOK_T* pq_code_book_ptr; - const float pq_scale; - using dataset_descriptor_base_t::size; - using dataset_descriptor_base_t::dim; - - // Set on device - CODE_BOOK_T* smem_pq_code_book_ptr; - static const std::uint32_t smem_buffer_size_in_byte = - (1 << PQ_BITS) * PQ_LEN * utils::size_of(); - - __device__ void set_smem_ptr(void* const smem_ptr) - { - smem_pq_code_book_ptr = reinterpret_cast(smem_ptr); - - // Copy PQ table - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = pq_code_book_ptr[i]; - buf2.y = pq_code_book_ptr[i + 1]; - - // Change the order of PQ code book array to reduce the - // frequency of bank conflicts. - constexpr auto num_elements_per_bank = 4 / utils::size_of(); - constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; - const auto j = i / num_elements_per_bank; - const auto smem_index = - (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); - reinterpret_cast(smem_pq_code_book_ptr)[smem_index] = buf2; - } - } - - cagra_q_dataset_descriptor_t(const std::uint8_t* encoded_dataset_ptr, - const std::uint32_t encoded_dataset_dim, - const std::uint32_t n_subspace, - const CODE_BOOK_T* const vq_code_book_ptr, - const float vq_scale, - const CODE_BOOK_T* const pq_code_book_ptr, - const float pq_scale, - const std::size_t size, - const std::uint32_t dim) - : dataset_descriptor_base_t(size, dim), - encoded_dataset_ptr(encoded_dataset_ptr), - encoded_dataset_dim(encoded_dataset_dim), - n_subspace(n_subspace), - vq_code_book_ptr(vq_code_book_ptr), - vq_scale(vq_scale), - pq_code_book_ptr(pq_code_book_ptr), - pq_scale(pq_scale) - { - } - - template - __device__ void copy_query(const DATA_T* const dmem_query_ptr, - QUERY_T* const smem_query_ptr, - const std::uint32_t query_smem_buffer_length) - { - constexpr spatial::knn::detail::utils::mapping mapping{}; - for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) { - half2 buf2{0, 0}; - if (i < dim) { buf2.x = mapping(dmem_query_ptr[i]); } - if (i + 1 < dim) { buf2.y = mapping(dmem_query_ptr[i + 1]); } - if ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) { - // Use swizzling in the condition to reduce bank conflicts in shared - // memory, which are likely to occur when pq_code_book_dim is large. - ((half2*)smem_query_ptr)[device::swizzling(i / 2)] = - buf2; - } else { - (reinterpret_cast(smem_query_ptr + i))[0] = buf2; - } - } - } - - template - __device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr, - const INDEX_T node_id, - const bool valid) const - { - float norm = 0; - if (valid) { - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - const uint32_t vq_code = *(reinterpret_cast( - encoded_dataset_ptr + (static_cast(encoded_dataset_dim) * node_id))); - if (PQ_BITS == 8) { - for (uint32_t elem_offset = 0; elem_offset < dim; elem_offset += DATASET_BLOCK_DIM) { - constexpr unsigned vlen = 4; // **** DO NOT CHANGE **** - constexpr unsigned nelem = - raft::div_rounding_up_unsafe(DATASET_BLOCK_DIM / PQ_LEN, TEAM_SIZE * vlen); - // Loading PQ codes - uint32_t pq_codes[nelem]; -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; - if (k >= n_subspace) break; - // Loading 4 x 8-bit PQ-codes using 32-bit load ops (from device memory) - pq_codes[e] = *(reinterpret_cast( - encoded_dataset_ptr + (static_cast(encoded_dataset_dim) * node_id) + - 4 + k)); - } - // - if constexpr (PQ_LEN % 2 == 0) { - // **** Use half2 for distance computation **** - half2 norm2{0, 0}; -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; - if (k >= n_subspace) break; - // Loading VQ code-book - raft::TxN_t vq_vals[PQ_LEN]; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m += 1) { - const uint32_t d = (vlen * m) + (PQ_LEN * k); - if (d >= dim) break; - vq_vals[m].load( - reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); - } - // Compute distance - std::uint32_t pq_code = pq_codes[e]; -#pragma unroll - for (std::uint32_t v = 0; v < vlen; v++) { - if (PQ_LEN * (v + k) >= dim) break; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m += 2) { - const std::uint32_t d1 = m + (PQ_LEN * v); - const std::uint32_t d = d1 + (PQ_LEN * k); - // Loading query vector in smem - half2 diff2 = (reinterpret_cast( - query_ptr))[device::swizzling(d / 2)]; - // Loading PQ code book in smem - diff2 -= *(reinterpret_cast( - smem_pq_code_book_ptr + (1 << PQ_BITS) * 2 * (m / 2) + (2 * (pq_code & 0xff)))); - diff2 -= vq_vals[d1 / vlen].val.data[(d1 % vlen) / 2]; - norm2 += diff2 * diff2; - } - pq_code >>= 8; - } - } - norm += static_cast(norm2.x + norm2.y); - } else { - // **** Use float for distance computation **** -#pragma unroll - for (std::uint32_t e = 0; e < nelem; e++) { - const std::uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset / PQ_LEN; - if (k >= n_subspace) break; - // Loading VQ code-book - raft::TxN_t vq_vals[PQ_LEN]; -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m++) { - const std::uint32_t d = (vlen * m) + (PQ_LEN * k); - if (d >= dim) break; - // Loading 4 x 8/16-bit VQ-values using 32/64-bit load ops (from L2$ or device - // memory) - vq_vals[m].load( - reinterpret_cast(vq_code_book_ptr + d + (dim * vq_code)), 0); - } - // Compute distance - std::uint32_t pq_code = pq_codes[e]; -#pragma unroll - for (std::uint32_t v = 0; v < vlen; v++) { - if (PQ_LEN * (v + k) >= dim) break; - raft::TxN_t pq_vals; - pq_vals.load( - reinterpret_cast(smem_pq_code_book_ptr + PQ_LEN * (pq_code & 0xff)), - 0); // (from L1$ or smem) -#pragma unroll - for (std::uint32_t m = 0; m < PQ_LEN; m++) { - const std::uint32_t d1 = m + (PQ_LEN * v); - const std::uint32_t d = d1 + (PQ_LEN * k); - // if (d >= dataset_dim) break; - DISTANCE_T diff = query_ptr[d]; // (from smem) - diff -= pq_scale * static_cast(pq_vals.data[m]); - diff -= vq_scale * static_cast(vq_vals[d1 / vlen].val.data[d1 % vlen]); - norm += diff * diff; - } - pq_code >>= 8; - } - } - } - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm += __shfl_xor_sync(0xffffffff, norm, offset); - } - return norm; - } -}; - -} // namespace raft::neighbors::cagra::detail \ No newline at end of file diff --git a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp deleted file mode 100644 index d4d69e6a67..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "utils.hpp" - -#include - -#include - -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace device { - -// warpSize for compile time calculation -constexpr unsigned warp_size = 32; - -/** Xorshift rondem number generator. - * - * See https://en.wikipedia.org/wiki/Xorshift#xorshift for reference. - */ -_RAFT_HOST_DEVICE inline uint64_t xorshift64(uint64_t u) -{ - u ^= u >> 12; - u ^= u << 25; - u ^= u >> 27; - return u * 0x2545F4914F6CDD1DULL; -} - -template -_RAFT_DEVICE inline T swizzling(T x) -{ - // Address swizzling reduces bank conflicts in shared memory, but increases - // the amount of operation instead. - // return x; - if constexpr (X_MAX <= 1024) { - return (x) ^ ((x) >> 5); - } else { - return (x) ^ (((x) >> 5) & 0x1f); - } -} - -} // namespace device -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh deleted file mode 100644 index 6d7fc6c966..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ /dev/null @@ -1,95 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "search_multi_cta.cuh" -#include "search_multi_kernel.cuh" -#include "search_plan.cuh" -#include "search_single_cta.cuh" - -#include - -namespace raft::neighbors::cagra::detail { - -template -class factory { - using T = typename DATASET_DESCRIPTOR_T::DATA_T; - using IdxT = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DistanceT = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - public: - /** - * Create a search structure for dataset with dim features. - */ - static std::unique_ptr> create( - raft::resources const& res, - search_params const& params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - const raft::distance::DistanceType metric) - { - search_plan_impl_base plan(params, dim, graph_degree, topk, metric); - switch (plan.dataset_block_dim) { - case 128: - switch (plan.team_size) { - case 8: return dispatch_kernel<128, 8>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - case 256: - switch (plan.team_size) { - case 16: return dispatch_kernel<256, 16>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - case 512: - switch (plan.team_size) { - case 32: return dispatch_kernel<512, 32>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - default: THROW("Incorrect dataset_block_dim (%lu)\n", plan.dataset_block_dim); - } - return std::unique_ptr>(); - } - - private: - template - static std::unique_ptr> - dispatch_kernel(raft::resources const& res, search_plan_impl_base& plan) - { - if (plan.algo == search_algo::SINGLE_CTA) { - return std::unique_ptr>( - new single_cta_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); - } else if (plan.algo == search_algo::MULTI_CTA) { - return std::unique_ptr>( - new multi_cta_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); - } else { - return std::unique_ptr>( - new multi_kernel_search:: - search( - res, plan, plan.dim, plan.graph_degree, plan.topk, plan.metric)); - } - } -}; -}; // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh deleted file mode 100644 index 93faf9dd19..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ /dev/null @@ -1,590 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace graph { - -// unnamed namespace to avoid multiple definition error -namespace { -inline double cur_time(void) -{ - struct timeval tv; - gettimeofday(&tv, NULL); - return ((double)tv.tv_sec + (double)tv.tv_usec * 1e-6); -} - -template -__device__ inline void swap(T& val1, T& val2) -{ - T val0 = val1; - val1 = val2; - val2 = val0; -} - -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) -{ - if (key1 == key2) { return false; } - if ((key1 > key2) == ascending) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -template -RAFT_KERNEL kern_sort(const DATA_T* const dataset, // [dataset_chunk_size, dataset_dim] - const IdxT dataset_size, - const uint32_t dataset_dim, - IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree) -{ - const IdxT srcNode = (blockDim.x * blockIdx.x + threadIdx.x) / raft::WarpSize; - if (srcNode >= graph_size) { return; } - - const uint32_t lane_id = threadIdx.x % raft::WarpSize; - - float my_keys[numElementsPerThread]; - IdxT my_vals[numElementsPerThread]; - - // Compute distance from a src node to its neighbors - for (int k = 0; k < graph_degree; k++) { - const IdxT dstNode = knn_graph[k + static_cast(graph_degree) * srcNode]; - float dist = 0.0; - for (int d = lane_id; d < dataset_dim; d += raft::WarpSize) { - float diff = spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * srcNode]) - - spatial::knn::detail::utils::mapping{}( - dataset[d + static_cast(dataset_dim) * dstNode]); - dist += diff * diff; - } - dist += __shfl_xor_sync(0xffffffff, dist, 1); - dist += __shfl_xor_sync(0xffffffff, dist, 2); - dist += __shfl_xor_sync(0xffffffff, dist, 4); - dist += __shfl_xor_sync(0xffffffff, dist, 8); - dist += __shfl_xor_sync(0xffffffff, dist, 16); - if (lane_id == (k % raft::WarpSize)) { - my_keys[k / raft::WarpSize] = dist; - my_vals[k / raft::WarpSize] = dstNode; - } - } - for (int k = graph_degree; k < raft::WarpSize * numElementsPerThread; k++) { - if (lane_id == k % raft::WarpSize) { - my_keys[k / raft::WarpSize] = utils::get_max_value(); - my_vals[k / raft::WarpSize] = utils::get_max_value(); - } - } - - // Sort by RAFT bitonic sort - raft::util::bitonic(true).sort(my_keys, my_vals); - - // Update knn_graph - for (int i = 0; i < numElementsPerThread; i++) { - const int k = i * raft::WarpSize + lane_id; - if (k < graph_degree) { - knn_graph[k + (static_cast(graph_degree) * srcNode)] = my_vals[i]; - } - } -} - -template -RAFT_KERNEL kern_prune(const IdxT* const knn_graph, // [graph_chunk_size, graph_degree] - const uint32_t graph_size, - const uint32_t graph_degree, - const uint32_t degree, - const uint32_t batch_size, - const uint32_t batch_id, - uint8_t* const detour_count, // [graph_chunk_size, graph_degree] - uint32_t* const num_no_detour_edges, // [graph_size] - uint64_t* const stats) -{ - __shared__ uint32_t smem_num_detour[MAX_DEGREE]; - uint64_t* const num_retain = stats; - uint64_t* const num_full = stats + 1; - - const uint64_t nid = blockIdx.x + (batch_size * batch_id); - if (nid >= graph_size) { return; } - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - smem_num_detour[k] = 0; - } - __syncthreads(); - - const uint64_t iA = nid; - if (iA >= graph_size) { return; } - - // count number of detours (A->D->B) - for (uint32_t kAD = 0; kAD < graph_degree - 1; kAD++) { - const uint64_t iD = knn_graph[kAD + (graph_degree * iA)]; - for (uint32_t kDB = threadIdx.x; kDB < graph_degree; kDB += blockDim.x) { - const uint64_t iB_candidate = knn_graph[kDB + ((uint64_t)graph_degree * iD)]; - for (uint32_t kAB = kAD + 1; kAB < graph_degree; kAB++) { - // if ( kDB < kAB ) - { - const uint64_t iB = knn_graph[kAB + (graph_degree * iA)]; - if (iB == iB_candidate) { - atomicAdd(smem_num_detour + kAB, 1); - break; - } - } - } - } - __syncthreads(); - } - - uint32_t num_edges_no_detour = 0; - for (uint32_t k = threadIdx.x; k < graph_degree; k += blockDim.x) { - detour_count[k + (graph_degree * iA)] = min(smem_num_detour[k], (uint32_t)255); - if (smem_num_detour[k] == 0) { num_edges_no_detour++; } - } - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 1); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 2); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 4); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 8); - num_edges_no_detour += __shfl_xor_sync(0xffffffff, num_edges_no_detour, 16); - num_edges_no_detour = min(num_edges_no_detour, degree); - - if (threadIdx.x == 0) { - num_no_detour_edges[iA] = num_edges_no_detour; - atomicAdd((unsigned long long int*)num_retain, (unsigned long long int)num_edges_no_detour); - if (num_edges_no_detour >= degree) { atomicAdd((unsigned long long int*)num_full, 1); } - } -} - -template -RAFT_KERNEL kern_make_rev_graph(const IdxT* const dest_nodes, // [graph_size] - IdxT* const rev_graph, // [size, degree] - uint32_t* const rev_graph_count, // [graph_size] - const uint32_t graph_size, - const uint32_t degree) -{ - const uint32_t tid = threadIdx.x + (blockDim.x * blockIdx.x); - const uint32_t tnum = blockDim.x * gridDim.x; - - for (uint32_t src_id = tid; src_id < graph_size; src_id += tnum) { - const IdxT dest_id = dest_nodes[src_id]; - if (dest_id >= graph_size) continue; - - const uint32_t pos = atomicAdd(rev_graph_count + dest_id, 1); - if (pos < degree) { rev_graph[pos + ((uint64_t)degree * dest_id)] = src_id; } - } -} - -template -uint64_t pos_in_array(T val, const T* array, uint64_t num) -{ - for (uint64_t i = 0; i < num; i++) { - if (val == array[i]) { return i; } - } - return num; -} - -template -void shift_array(T* array, uint64_t num) -{ - for (uint64_t i = num; i > 0; i--) { - array[i] = array[i - 1]; - } -} -} // namespace - -template , memory_type::device>, - typename g_accessor = - host_device_accessor, memory_type::host>> -void sort_knn_graph(raft::resources const& res, - mdspan, row_major, d_accessor> dataset, - mdspan, row_major, g_accessor> knn_graph) -{ - RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), - "dataset size is expected to have the same number of graph index size"); - const uint32_t dataset_size = dataset.extent(0); - const uint32_t dataset_dim = dataset.extent(1); - const DataT* dataset_ptr = dataset.data_handle(); - - const IdxT graph_size = dataset_size; - const uint32_t input_graph_degree = knn_graph.extent(1); - IdxT* const input_graph_ptr = knn_graph.data_handle(); - - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); - - // - // Sorting kNN graph - // - const double time_sort_start = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - - auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); - raft::copy(d_dataset.data_handle(), - dataset_ptr, - dataset_size * dataset_dim, - resource::get_cuda_stream(res)); - - raft::copy(d_input_graph.data_handle(), - input_graph_ptr, - graph_size * input_graph_degree, - resource::get_cuda_stream(res)); - - void (*kernel_sort)( - const DataT* const, const IdxT, const uint32_t, IdxT* const, const uint32_t, const uint32_t); - if (input_graph_degree <= 32) { - constexpr int numElementsPerThread = 1; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 64) { - constexpr int numElementsPerThread = 2; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 128) { - constexpr int numElementsPerThread = 4; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 256) { - constexpr int numElementsPerThread = 8; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 512) { - constexpr int numElementsPerThread = 16; - kernel_sort = kern_sort; - } else if (input_graph_degree <= 1024) { - constexpr int numElementsPerThread = 32; - kernel_sort = kern_sort; - } else { - RAFT_FAIL( - "The degree of input knn graph is too large (%u). " - "It must be equal to or smaller than %d.", - input_graph_degree, - 1024); - } - const auto block_size = 256; - const auto num_warps_per_block = block_size / raft::WarpSize; - const auto grid_size = (graph_size + num_warps_per_block - 1) / num_warps_per_block; - - RAFT_LOG_DEBUG("."); - kernel_sort<<>>( - d_dataset.data_handle(), - dataset_size, - dataset_dim, - d_input_graph.data_handle(), - graph_size, - input_graph_degree); - resource::sync_stream(res); - RAFT_LOG_DEBUG("."); - raft::copy(input_graph_ptr, - d_input_graph.data_handle(), - graph_size * input_graph_degree, - resource::get_cuda_stream(res)); - RAFT_LOG_DEBUG("\n"); - - const double time_sort_end = cur_time(); - RAFT_LOG_DEBUG("# Sorting kNN graph time: %.1lf sec\n", time_sort_end - time_sort_start); -} - -template , memory_type::host>> -void optimize(raft::resources const& res, - mdspan, row_major, g_accessor> knn_graph, - raft::host_matrix_view new_graph) -{ - RAFT_LOG_DEBUG( - "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); - - RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), - "Each input array is expected to have the same number of rows"); - RAFT_EXPECTS(new_graph.extent(1) <= knn_graph.extent(1), - "output graph cannot have more columns than input graph"); - const uint32_t input_graph_degree = knn_graph.extent(1); - const uint32_t output_graph_degree = new_graph.extent(1); - auto input_graph_ptr = knn_graph.data_handle(); - auto output_graph_ptr = new_graph.data_handle(); - const IdxT graph_size = new_graph.extent(0); - - { - // - // Prune kNN graph - // - auto d_detour_count = - raft::make_device_matrix(res, graph_size, input_graph_degree); - - RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), - 0xff, - graph_size * input_graph_degree * sizeof(uint8_t), - resource::get_cuda_stream(res))); - - auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - resource::get_cuda_stream(res))); - - auto dev_stats = raft::make_device_vector(res, 2); - auto host_stats = raft::make_host_vector(2); - - // - // Prune unimportant edges. - // - // The edge to be retained is determined without explicitly considering - // distance or angle. Suppose the edge is the k-th edge of some node-A to - // node-B (A->B). Among the edges originating at node-A, there are k-1 edges - // shorter than the edge A->B. Each of these k-1 edges are connected to a - // different k-1 nodes. Among these k-1 nodes, count the number of nodes with - // edges to node-B, which is the number of 2-hop detours for the edge A->B. - // Once the number of 2-hop detours has been counted for all edges, the - // specified number of edges are picked up for each node, starting with the - // edge with the lowest number of 2-hop detours. - // - const double time_prune_start = cur_time(); - RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r"); - - // Copy input_graph_ptr over to device if necessary - device_matrix_view_from_host d_input_graph( - res, - raft::make_host_matrix_view(input_graph_ptr, graph_size, input_graph_degree)); - - constexpr int MAX_DEGREE = 1024; - if (input_graph_degree > MAX_DEGREE) { - RAFT_FAIL( - "The degree of input knn graph is too large (%u). " - "It must be equal to or smaller than %d.", - input_graph_degree, - 1024); - } - const uint32_t batch_size = - std::min(static_cast(graph_size), static_cast(256 * 1024)); - const uint32_t num_batch = (graph_size + batch_size - 1) / batch_size; - const dim3 threads_prune(32, 1, 1); - const dim3 blocks_prune(batch_size, 1, 1); - - RAFT_CUDA_TRY(cudaMemsetAsync( - dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, resource::get_cuda_stream(res))); - - for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kern_prune - <<>>( - d_input_graph.data_handle(), - graph_size, - input_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); - resource::sync_stream(res); - RAFT_LOG_DEBUG( - "# Pruning kNN Graph on GPUs (%.1lf %%)\r", - (double)std::min((i_batch + 1) * batch_size, graph_size) / graph_size * 100); - } - resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - host_matrix_view_from_device detour_count(res, d_detour_count.view()); - - raft::copy( - host_stats.data_handle(), dev_stats.data_handle(), 2, resource::get_cuda_stream(res)); - const auto num_keep = host_stats.data_handle()[0]; - const auto num_full = host_stats.data_handle()[1]; - - // Create pruned kNN graph -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - // Find the `output_graph_degree` smallest detourable count nodes by checking the detourable - // count of the neighbors while increasing the target detourable count from zero. - uint64_t pk = 0; - uint32_t num_detour = 0; - while (pk < output_graph_degree) { - uint32_t next_num_detour = std::numeric_limits::max(); - for (uint64_t k = 0; k < input_graph_degree; k++) { - const auto num_detour_k = detour_count.data_handle()[k + (input_graph_degree * i)]; - // Find the detourable count to check in the next iteration - if (num_detour_k > num_detour) { - next_num_detour = std::min(static_cast(num_detour_k), next_num_detour); - } - - // Store the neighbor index if its detourable count is equal to `num_detour`. - if (num_detour_k != num_detour) { continue; } - output_graph_ptr[pk + (output_graph_degree * i)] = - input_graph_ptr[k + (input_graph_degree * i)]; - pk += 1; - if (pk >= output_graph_degree) break; - } - if (pk >= output_graph_degree) break; - - assert(next_num_detour != std::numeric_limits::max()); - num_detour = next_num_detour; - } - RAFT_EXPECTS(pk == output_graph_degree, - "Couldn't find the output_graph_degree (%u) smallest detourable count nodes for " - "node %lu in the rank-based node reranking process", - output_graph_degree, - static_cast(i)); - } - - const double time_prune_end = cur_time(); - RAFT_LOG_DEBUG( - "# Pruning time: %.1lf sec, " - "avg_no_detour_edges_per_node: %.2lf/%u, " - "nodes_with_no_detour_at_all_edges: %.1lf%%\n", - time_prune_end - time_prune_start, - (double)num_keep / graph_size, - output_graph_degree, - (double)num_full / graph_size * 100); - } - - auto rev_graph = raft::make_host_matrix(graph_size, output_graph_degree); - auto rev_graph_count = raft::make_host_vector(graph_size); - - { - // - // Make reverse graph - // - const double time_make_start = cur_time(); - - device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), - 0xff, - graph_size * output_graph_degree * sizeof(IdxT), - resource::get_cuda_stream(res))); - - auto d_rev_graph_count = raft::make_device_vector(res, graph_size); - RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), - 0x00, - graph_size * sizeof(uint32_t), - resource::get_cuda_stream(res))); - - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_vector(res, graph_size); - - for (uint64_t k = 0; k < output_graph_degree; k++) { -#pragma omp parallel for - for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; - } - resource::sync_stream(res); - - raft::copy(d_dest_nodes.data_handle(), - dest_nodes.data_handle(), - graph_size, - resource::get_cuda_stream(res)); - - dim3 threads(256, 1, 1); - dim3 blocks(1024, 1, 1); - kern_make_rev_graph<<>>( - d_dest_nodes.data_handle(), - d_rev_graph.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - output_graph_degree); - RAFT_LOG_DEBUG("# Making reverse graph on GPUs: %lu / %u \r", k, output_graph_degree); - } - - resource::sync_stream(res); - RAFT_LOG_DEBUG("\n"); - - if (d_rev_graph.allocated_memory()) { - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - resource::get_cuda_stream(res)); - } - raft::copy(rev_graph_count.data_handle(), - d_rev_graph_count.data_handle(), - graph_size, - resource::get_cuda_stream(res)); - - const double time_make_end = cur_time(); - RAFT_LOG_DEBUG("# Making reverse graph time: %.1lf sec", time_make_end - time_make_start); - } - - { - // - // Replace some edges with reverse edges - // - const double time_replace_start = cur_time(); - - const uint64_t num_protected_edges = output_graph_degree / 2; - RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); - - constexpr int _omp_chunk = 1024; -#pragma omp parallel for schedule(dynamic, _omp_chunk) - for (uint64_t j = 0; j < graph_size; j++) { - uint64_t k = std::min(rev_graph_count.data_handle()[j], output_graph_degree); - while (k) { - k--; - uint64_t i = rev_graph.data_handle()[k + (output_graph_degree * j)]; - - uint64_t pos = - pos_in_array(i, output_graph_ptr + (output_graph_degree * j), output_graph_degree); - if (pos < num_protected_edges) { continue; } - uint64_t num_shift = pos - num_protected_edges; - if (pos == output_graph_degree) { - num_shift = output_graph_degree - num_protected_edges - 1; - } - shift_array(output_graph_ptr + num_protected_edges + (output_graph_degree * j), - num_shift); - output_graph_ptr[num_protected_edges + (output_graph_degree * j)] = i; - } - if ((omp_get_thread_num() == 0) && ((j % _omp_chunk) == 0)) { - RAFT_LOG_DEBUG("# Replacing reverse edges: %lu / %lu ", j, graph_size); - } - } - RAFT_LOG_DEBUG("\n"); - - const double time_replace_end = cur_time(); - RAFT_LOG_DEBUG("# Replacing edges time: %.1lf sec", time_replace_end - time_replace_start); - - /* stats */ - uint64_t num_replaced_edges = 0; -#pragma omp parallel for reduction(+ : num_replaced_edges) - for (uint64_t i = 0; i < graph_size; i++) { - for (uint64_t k = 0; k < output_graph_degree; k++) { - const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)]; - const uint64_t pos = - pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree); - if (pos == output_graph_degree) { num_replaced_edges += 1; } - } - } - RAFT_LOG_DEBUG("# Average number of replaced edges per node: %.2f", - (double)num_replaced_edges / graph_size); - } -} - -} // namespace graph -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp deleted file mode 100644 index 034bca6178..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "utils.hpp" - -#include -#include - -#include - -// #pragma GCC diagnostic push -// #pragma GCC diagnostic ignored -// #pragma GCC diagnostic pop -namespace raft::neighbors::cagra::detail { -namespace hashmap { - -_RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } - -template -_RAFT_DEVICE inline void init(IdxT* const table, const unsigned bitlen, unsigned FIRST_TID = 0) -{ - if (threadIdx.x < FIRST_TID) return; - for (unsigned i = threadIdx.x - FIRST_TID; i < get_size(bitlen); i += blockDim.x - FIRST_TID) { - table[i] = utils::get_max_value(); - } -} - -template -_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) -{ - // Open addressing is used for collision resolution - const uint32_t size = get_size(bitlen); - const uint32_t bit_mask = size - 1; -#if 1 - // Linear probing - IdxT index = (key ^ (key >> bitlen)) & bit_mask; - constexpr uint32_t stride = 1; -#else - // Double hashing - uint32_t index = key & bit_mask; - const uint32_t stride = (key >> bitlen) * 2 + 1; -#endif - for (unsigned i = 0; i < size; i++) { - const IdxT old = atomicCAS(&table[index], ~static_cast(0), key); - if (old == ~static_cast(0)) { - return 1; - } else if (old == key) { - return 0; - } - index = (index + stride) & bit_mask; - } - return 0; -} - -template -_RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, const IdxT key) -{ - IdxT ret = 0; - if (threadIdx.x % TEAM_SIZE == 0) { ret = insert(table, bitlen, key); } - for (unsigned offset = 1; offset < TEAM_SIZE; offset *= 2) { - ret |= __shfl_xor_sync(0xffffffff, ret, offset); - } - return ret; -} - -} // namespace hashmap -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh deleted file mode 100644 index 4b979bcae8..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ /dev/null @@ -1,268 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_multi_cta_kernel.cuh" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace multi_cta_search { - -template - -struct search : public search_plan_impl { - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - uint32_t num_cta_per_query; - rmm::device_uvector intermediate_indices; - rmm::device_uvector intermediate_distances; - size_t topk_workspace_size; - rmm::device_uvector topk_workspace; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_plan_impl( - res, params, dim, graph_degree, topk, metric), - intermediate_indices(0, resource::get_cuda_stream(res)), - intermediate_distances(0, resource::get_cuda_stream(res)), - topk_workspace(0, resource::get_cuda_stream(res)) - - { - set_params(res, params); - } - - void set_params(raft::resources const& res, const search_params& params) - { - constexpr unsigned muti_cta_itopk_size = 32; - this->itopk_size = muti_cta_itopk_size; - search_width = 1; - num_cta_per_query = - max(params.search_width, raft::ceildiv(params.itopk_size, (size_t)muti_cta_itopk_size)); - result_buffer_size = itopk_size + search_width * graph_degree; - typedef raft::Pow2<32> AlignBytes; - unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); - // constexpr unsigned max_result_buffer_size = 256; - RAFT_EXPECTS(result_buffer_size_32 <= 256, "Result buffer size cannot exceed 256"); - - const auto query_smem_buffer_length = - raft::ceildiv(dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - - smem_size = sizeof(float) * query_smem_buffer_length + - (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(uint32_t) * search_width + sizeof(uint32_t) + - DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; - RAFT_LOG_DEBUG("# smem_size: %u", smem_size); - - // - // Determine the thread block size - // - constexpr unsigned min_block_size = 64; - constexpr unsigned max_block_size = 1024; - uint32_t block_size = thread_block_size; - if (block_size == 0) { - block_size = min_block_size; - - // Increase block size according to shared memory requirements. - // If block size is 32, upper limit of shared memory size per - // thread block is set to 4096. This is GPU generation dependent. - constexpr unsigned ulimit_smem_size_cta32 = 4096; - while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { - block_size *= 2; - } - - // Increase block size to improve GPU occupancy when total number of - // CTAs (= num_cta_per_query * max_queries) is small. - cudaDeviceProp deviceProp = resource::get_device_properties(res); - RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); - while ((block_size < max_block_size) && - (graph_degree * search_width * team_size >= block_size * 2) && - (num_cta_per_query * max_queries <= - (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { - block_size *= 2; - } - } - RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); - RAFT_EXPECTS(block_size >= min_block_size, - "block_size cannot be smaller than min_block size, %u", - min_block_size); - RAFT_EXPECTS(block_size <= max_block_size, - "block_size cannot be larger than max_block size %u", - max_block_size); - thread_block_size = block_size; - - // - // Allocate memory for intermediate buffer and workspace. - // - uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - intermediate_indices.resize(num_intermediate_results * max_queries, - resource::get_cuda_stream(res)); - intermediate_distances.resize(num_intermediate_results * max_queries, - resource::get_cuda_stream(res)); - - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - - topk_workspace_size = _cuann_find_topk_bufferSize( - topk, max_queries, num_intermediate_results, utils::get_cuda_data_type()); - RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); - } - - void check(const uint32_t topk) override - { - RAFT_EXPECTS(num_cta_per_query * 32 >= topk, - "`num_cta_per_query` (%u) * 32 must be equal to or greater than " - "`topk` (%u) when 'search_mode' is \"multi-cta\". " - "(`num_cta_per_query`=max(`search_width`, ceildiv(`itopk_size`, 32)))", - num_cta_per_query, - topk); - } - - ~search() {} - - void operator()( - raft::resources const& res, - // raft::device_matrix_view dataset, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view - graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - cudaStream_t stream = resource::get_cuda_stream(res); - - select_and_run( - dataset_desc, - graph, - intermediate_indices.data(), - intermediate_distances.data(), - queries_ptr, - num_queries, - dev_seed_ptr, - num_executed_iterations, - topk, - thread_block_size, - result_buffer_size, - smem_size, - hash_bitlen, - hashmap.data(), - num_cta_per_query, - num_random_samplings, - rand_xor_mask, - num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, - sample_filter, - this->metric, - stream); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - // Select the top-k results from the intermediate results - const uint32_t num_intermediate_results = num_cta_per_query * itopk_size; - _cuann_find_topk(topk, - num_queries, - num_intermediate_results, - intermediate_distances.data(), - num_intermediate_results, - intermediate_indices.data(), - num_intermediate_results, - topk_distances_ptr, - topk, - topk_indices_ptr, - topk, - topk_workspace.data(), - true, - NULL, - stream); - } -}; - -} // namespace multi_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh deleted file mode 100644 index 35f4f0e1c9..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ /dev/null @@ -1,418 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // none_cagra_sample_filter -#include // RAFT_EXPLICIT - -#include - -namespace raft::neighbors::cagra::detail { -namespace multi_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run( - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, - uint32_t* const num_executed_iterations, - uint32_t topk, - uint32_t block_size, - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) RAFT_EXPLICIT; -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_kernel_selection( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void select_and_run< \ - TEAM_SIZE, \ - MAX_DATASET_DIM, \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ - dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_kernel_selection( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_kernel_selection( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_kernel_selection - -#define instantiate_q_kernel_selection(TEAM_SIZE, \ - MAX_DATASET_DIM, \ - CODE_BOOK_T, \ - PQ_BITS, \ - PQ_CODE_BOOK_DIM, \ - DATA_T, \ - INDEX_T, \ - DISTANCE_T, \ - SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset_desc, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t block_size, \ - uint32_t result_buffer_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - uint32_t num_cta_per_query, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_q_kernel_selection( - 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection( - 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection( - 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(8, - 128, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(16, - 256, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 512, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_kernel_selection(32, - 1024, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_q_kernel_selection -} // namespace multi_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh deleted file mode 100644 index 16bb555aa4..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ /dev/null @@ -1,520 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk if possible -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace multi_cta_search { - -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [search_width] - const uint32_t search_width, - INDEX_T* const itopk_indices, // [num_itopk] - const size_t num_itopk, - uint32_t* const terminate_flag) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - for (uint32_t i = lane_id; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - uint32_t max_itopk = num_itopk; - if (max_itopk % 32) { max_itopk += 32 - (max_itopk % 32); } - uint32_t num_new_parents = 0; - for (uint32_t j = lane_id; j < max_itopk; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < num_itopk) { - index = itopk_indices[j]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << lane_id) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = j; - itopk_indices[j] |= index_msb_1_mask; // set most significant bit as used node - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -__device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] - INDEX_T* indices, // [num_elements] - const uint32_t num_elements, - const uint32_t num_itopk // num_itopk <= num_elements -) -{ - const unsigned warp_id = threadIdx.x / 32; - if (warp_id > 0) { return; } - const unsigned lane_id = threadIdx.x % 32; - constexpr unsigned N = (MAX_ELEMENTS + 31) / 32; - float key[N]; - INDEX_T val[N]; - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_elements) { - key[i] = distances[j]; - val[i] = indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store itopk sorted results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - distances[j] = key[i]; - indices[j] = val[i]; - } - } -} - -// -// multiple CTAs per single query -// -template -__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( - typename DATASET_DESCRIPTOR_T::INDEX_T* const - result_indices_ptr, // [num_queries, num_cta_per_query, itopk_size] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const - result_distances_ptr, // [num_queries, num_cta_per_query, itopk_size] - DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const uint32_t hash_bitlen, - const uint32_t itopk_size, - const uint32_t search_width, - const uint32_t min_iteration, - const uint32_t max_iteration, - uint32_t* const num_executed_iterations, /* stats */ - SAMPLE_FILTER_T sample_filter, - const raft::distance::DistanceType metric) -{ - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - using QUERY_T = typename DATASET_DESCRIPTOR_T::QUERY_T; - - const auto num_queries = gridDim.y; - const auto query_id = blockIdx.y; - const auto num_cta_per_query = gridDim.x; - const auto cta_id = blockIdx.x; // local CTA ID - -#ifdef _CLK_BREAKDOWN - uint64_t clk_init = 0; - uint64_t clk_compute_1st_distance = 0; - uint64_t clk_topk = 0; - uint64_t clk_pickup_parents = 0; - uint64_t clk_compute_distance = 0; - uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ uint32_t smem[]; - - // Layout of result_buffer - // +----------------+------------------------------+---------+ - // | internal_top_k | neighbors of parent nodes | padding | - // | | | upto 32 | - // +----------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - uint32_t result_buffer_size = itopk_size + (search_width * graph_degree); - uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - assert(result_buffer_size_32 <= MAX_ELEMENTS); - - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto parent_indices_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto distance_work_buffer_ptr = - reinterpret_cast(parent_indices_buffer + search_width); - auto terminate_flag = reinterpret_cast(distance_work_buffer_ptr + - DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte); - - // Set smem working buffer for the distance calculation - dataset_desc.set_smem_ptr(distance_work_buffer_ptr); - -#if 0 - /* debug */ - for (unsigned i = threadIdx.x; i < result_buffer_size_32; i += blockDim.x) { - result_indices_buffer[i] = utils::get_max_value(); - result_distances_buffer[i] = utils::get_max_value(); - } -#endif - const DATA_T* const query_ptr = queries_ptr + (dataset_desc.dim * query_id); - dataset_desc.template copy_query( - query_ptr, query_buffer, query_smem_buffer_length); - - if (threadIdx.x == 0) { terminate_flag[0] = 0; } - INDEX_T* const local_visited_hashmap_ptr = - visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - uint32_t block_id = cta_id + (num_cta_per_query * query_id); - uint32_t num_blocks = num_cta_per_query * num_queries; - - device::compute_distance_to_random_nodes(result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_desc, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - metric, - block_id, - num_blocks); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - uint32_t iter = 0; - while (1) { - // topk with bitonic sort - _CLK_START(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - _CLK_REC(clk_topk); - - if (iter + 1 == max_iteration) { - __syncthreads(); - break; - } - - // pick up next parents - _CLK_START(); - pickup_next_parents( - parent_indices_buffer, search_width, result_indices_buffer, itopk_size, terminate_flag); - _CLK_REC(clk_pickup_parents); - - __syncthreads(); - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - // constexpr unsigned max_n_frags = 16; - constexpr unsigned max_n_frags = 0; - device::compute_distance_to_child_nodes( - result_indices_buffer + itopk_size, - result_distances_buffer + itopk_size, - query_buffer, - dataset_desc, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_indices_buffer, - result_indices_buffer, - search_width, - metric); - _CLK_REC(clk_compute_distance); - __syncthreads(); - - // Filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { - if (parent_indices_buffer[p] != invalid_index) { - const auto parent_id = - result_indices_buffer[parent_indices_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, parent_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_indices_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_indices_buffer[p]] = invalid_index; - } - } - } - __syncthreads(); - } - - iter++; - } - - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < itopk_size + search_width * graph_degree; i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - itopk_size + (search_width * graph_degree), - itopk_size); - __syncthreads(); - } - - for (uint32_t i = threadIdx.x; i < itopk_size; i += blockDim.x) { - uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id))); - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit - } - - if (threadIdx.x == 0 && cta_id == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } - -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && (blockIdx.x == 0) && - ((query_id * 3) % gridDim.y < 3)) { - printf( - "%s:%d " - "query, %d, thread, %d" - ", init, %lu" - ", 1st_distance, %lu" - ", topk, %lu" - ", pickup_parents, %lu" - ", distance, %lu" - "\n", - __FILE__, - __LINE__, - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_pickup_parents, - clk_compute_distance); - } -#endif -} - -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - -template -struct search_kernel_config { - // Search kernel function type. Note that the actual values for the template value - // parameters do not matter, because they are not part of the function signature. The - // second to fourth value parameters will be selected by the choose_* functions below. - using kernel_t = decltype(&search_kernel); - - static auto choose_buffer_size(unsigned result_buffer_size, unsigned block_size) -> kernel_t - { - if (result_buffer_size <= 64) { - return search_kernel; - } else if (result_buffer_size <= 128) { - return search_kernel; - } else if (result_buffer_size <= 256) { - return search_kernel; - } - THROW("Result buffer size %u larger than max buffer size %u", result_buffer_size, 256); - } -}; - -template -void select_and_run( - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - // multi_cta_search (params struct) - uint32_t block_size, // - uint32_t result_buffer_size, - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - uint32_t num_cta_per_query, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - auto kernel = - search_kernel_config:: - choose_buffer_size(result_buffer_size, block_size); - - RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); - // Initialize hash table - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch(hashmap_ptr, - hash_size, - utils::get_max_value(), - hash_size, - num_queries, - stream); - - dim3 block_dims(block_size, 1, 1); - dim3 grid_dims(num_cta_per_query, num_queries, 1); - RAFT_LOG_DEBUG("Launching kernel with %u threads, (%u, %u) blocks %u smem", - block_size, - num_cta_per_query, - num_queries, - smem_size); - - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - dataset_desc, - queries_ptr, - graph.data_handle(), - graph.extent(1), - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - hash_bitlen, - itopk_size, - search_width, - min_iterations, - max_iterations, - num_executed_iterations, - sample_filter, - metric); -} - -} // namespace multi_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh deleted file mode 100644 index e003907292..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "search_multi_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_multi_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh deleted file mode 100644 index 31c4bc5dca..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ /dev/null @@ -1,1088 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "compute_distance.hpp" -#include "compute_distance_vpq.cuh" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include -#include - -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace multi_kernel_search { - -template -RAFT_KERNEL set_value_kernel(T* const dev_ptr, const T val) -{ - *dev_ptr = val; -} - -template -RAFT_KERNEL set_value_kernel(T* const dev_ptr, const T val, const std::size_t count) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count) { return; } - dev_ptr[tid] = val; -} - -template -void set_value(T* const dev_ptr, const T val, cudaStream_t cuda_stream) -{ - set_value_kernel<<<1, 1, 0, cuda_stream>>>(dev_ptr, val); -} - -template -void set_value(T* const dev_ptr, const T val, const std::size_t count, cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count + block_size - 1) / block_size; - set_value_kernel<<>>(dev_ptr, val, count); -} - -template -RAFT_KERNEL get_value_kernel(T* const host_ptr, const T* const dev_ptr) -{ - *host_ptr = *dev_ptr; -} - -template -void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stream) -{ - get_value_kernel<<<1, 1, 0, cuda_stream>>>(host_ptr, dev_ptr); -} - -// MAX_DATASET_DIM : must equal to or greater than dataset_dim -template -RAFT_KERNEL random_pickup_kernel( - const DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldr] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::uint32_t ldr, // (*) ldr >= num_pickup - typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen, - const raft::distance::DistanceType metric) -{ - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - const auto ldb = hashmap::get_size(hash_bitlen); - const auto global_team_index = (blockIdx.x * blockDim.x + threadIdx.x) / TEAM_SIZE; - const uint32_t query_id = blockIdx.y; - if (global_team_index >= num_pickup) { return; } - // Load a query - extern __shared__ float query_buffer[]; - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_desc.dim) { - query_buffer[j] = spatial::knn::detail::utils::mapping{}( - (queries_ptr + query_id * dataset_desc.dim)[i]); - } else { - query_buffer[j] = 0.0; - } - } - __syncthreads(); - - INDEX_T best_index_team_local; - DISTANCE_T best_norm2_team_local = utils::get_max_value(); - for (unsigned i = 0; i < num_distilation; i++) { - INDEX_T seed_index; - if (seed_ptr && (global_team_index < num_seeds)) { - seed_index = seed_ptr[global_team_index + (num_seeds * query_id)]; - } else { - // Chose a seed node randomly - seed_index = - device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_desc.size; - } - - DISTANCE_T norm2; - switch (metric) { - case distance::DistanceType::L2Expanded: - norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, true); - break; - case distance::DistanceType::InnerProduct: - norm2 = dataset_desc.template compute_similarity( - query_buffer, seed_index, true); - break; - default: break; - } - - if (norm2 < best_norm2_team_local) { - best_norm2_team_local = norm2; - best_index_team_local = seed_index; - } - } - - const auto store_gmem_index = global_team_index + (ldr * query_id); - if (threadIdx.x % TEAM_SIZE == 0) { - if (hashmap::insert( - visited_hashmap_ptr + (ldb * query_id), hash_bitlen, best_index_team_local)) { - result_distances_ptr[store_gmem_index] = best_norm2_team_local; - result_indices_ptr[store_gmem_index] = best_index_team_local; - } else { - result_distances_ptr[store_gmem_index] = utils::get_max_value(); - result_indices_ptr[store_gmem_index] = utils::get_max_value(); - } - } -} - -// MAX_DATASET_DIM : must be equal to or greater than dataset_dim -template -void random_pickup( - const DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::size_t num_queries, - const std::size_t num_pickup, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldr] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldr] - const std::size_t ldr, // (*) ldr >= num_pickup - typename DATASET_DESCRIPTOR_T::INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << bitlen] - const std::uint32_t hash_bitlen, - const raft::distance::DistanceType metric, - cudaStream_t const cuda_stream = 0) -{ - const auto block_size = 256u; - const auto num_teams_per_threadblock = block_size / TEAM_SIZE; - const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, - num_queries); - - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - const auto smem_size = query_smem_buffer_length * sizeof(float); - - random_pickup_kernel - <<>>(dataset_desc, - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen, - metric); -} - -template -RAFT_KERNEL pickup_next_parents_kernel( - INDEX_T* const parent_candidates_ptr, // [num_queries, lds] - const std::size_t lds, // (*) lds >= parent_candidates_size - const std::uint32_t parent_candidates_size, // - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - const std::size_t ldb = hashmap::get_size(hash_bitlen); - const uint32_t query_id = blockIdx.x; - if (threadIdx.x < 32) { - // pickup next parents with single warp - for (std::uint32_t i = threadIdx.x; i < parent_list_size; i += 32) { - parent_list_ptr[i + (ldd * query_id)] = utils::get_max_value(); - } - std::uint32_t parent_candidates_size_max = parent_candidates_size; - if (parent_candidates_size % 32) { - parent_candidates_size_max += 32 - (parent_candidates_size % 32); - } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < parent_candidates_size_max; j += 32) { - INDEX_T index; - int new_parent = 0; - if (j < parent_candidates_size) { - index = parent_candidates_ptr[j + (lds * query_id)]; - if ((index & index_msb_1_mask) == 0) { // check most significant bit - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < parent_list_size) { - parent_list_ptr[i + (ldd * query_id)] = j; - parent_candidates_ptr[j + (lds * query_id)] |= - index_msb_1_mask; // set most significant bit as used node - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= parent_list_size) { break; } - } - if ((num_new_parents > 0) && (threadIdx.x == 0)) { *terminate_flag = 0; } - } else if (small_hash_bitlen) { - // reset small-hash - hashmap::init(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, 32); - } - - if (small_hash_bitlen) { - __syncthreads(); - // insert internal-topk indices into small-hash - for (unsigned i = threadIdx.x; i < parent_candidates_size; i += blockDim.x) { - auto key = parent_candidates_ptr[i + (lds * query_id)] & - ~index_msb_1_mask; // clear most significant bit - hashmap::insert(visited_hashmap_ptr + (ldb * query_id), hash_bitlen, key); - } - } -} - -template -void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, lds] - const std::size_t lds, // (*) lds >= parent_candidates_size - const std::size_t parent_candidates_size, // - const std::size_t num_queries, - INDEX_T* const visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::size_t hash_bitlen, - const std::size_t small_hash_bitlen, - INDEX_T* const parent_list_ptr, // [num_queries, ldd] - const std::size_t ldd, // (*) ldd >= parent_list_size - const std::size_t parent_list_size, // - std::uint32_t* const terminate_flag, - cudaStream_t cuda_stream = 0) -{ - std::uint32_t block_size = 32; - if (small_hash_bitlen) { - block_size = 128; - while (parent_candidates_size > block_size) { - block_size *= 2; - } - block_size = min(block_size, (uint32_t)512); - } - pickup_next_parents_kernel - <<>>(parent_candidates_ptr, - lds, - parent_candidates_size, - visited_hashmap_ptr, - hash_bitlen, - small_hash_bitlen, - parent_list_ptr, - ldd, - parent_list_size, - terminate_flag); -} - -template -RAFT_KERNEL compute_distance_to_child_nodes_kernel( - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_node_list, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_candidates_ptr, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const - parent_distance_ptr, // [num_queries, search_width] - const std::size_t lds, - const std::uint32_t search_width, - const DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - neighbor_graph_ptr, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const typename DATASET_DESCRIPTOR_T::DATA_T* query_ptr, // [num_queries, data_dim] - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldd] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter, - const raft::distance::DistanceType metric) -{ - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - const uint32_t ldb = hashmap::get_size(hash_bitlen); - const auto tid = threadIdx.x + blockDim.x * blockIdx.x; - const auto global_team_id = tid / TEAM_SIZE; - const auto query_id = blockIdx.y; - - extern __shared__ float query_buffer[]; - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { - unsigned j = device::swizzling(i); - if (i < dataset_desc.dim) { - query_buffer[j] = - spatial::knn::detail::utils::mapping{}((query_ptr + query_id * dataset_desc.dim)[i]); - } else { - query_buffer[j] = 0.0; - } - } - __syncthreads(); - if (global_team_id >= search_width * graph_degree) { return; } - - const std::size_t parent_list_index = - parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; - - if (parent_list_index == utils::get_max_value()) { return; } - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; - - if (raw_parent_index == utils::get_max_value()) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - return; - } - const auto parent_index = raw_parent_index & ~index_msb_1_mask; - - const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); - - const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - - const auto compute_distance_flag = hashmap::insert( - visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); - - DISTANCE_T norm2; - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - norm2 = dataset_desc.template compute_similarity( - query_buffer, child_id, compute_distance_flag); - break; - case raft::distance::DistanceType::InnerProduct: - norm2 = dataset_desc.template compute_similarity( - query_buffer, child_id, compute_distance_flag); - break; - default: break; - } - - if (compute_distance_flag) { - if (threadIdx.x % TEAM_SIZE == 0) { - result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; - result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; - } - } else { - if (threadIdx.x % TEAM_SIZE == 0) { - result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); - } - } - - if constexpr (!std::is_same::value) { - if (!sample_filter(query_id, parent_index)) { - parent_candidates_ptr[parent_list_index + (lds * query_id)] = utils::get_max_value(); - parent_distance_ptr[parent_list_index + (lds * query_id)] = - utils::get_max_value(); - } - } -} - -template -void compute_distance_to_child_nodes( - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_node_list, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::INDEX_T* const - parent_candidates_ptr, // [num_queries, search_width] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const - parent_distance_ptr, // [num_queries, search_width] - const std::size_t lds, - const uint32_t search_width, - const DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::INDEX_T* const - neighbor_graph_ptr, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const typename DATASET_DESCRIPTOR_T::DATA_T* query_ptr, // [num_queries, data_dim] - const std::uint32_t num_queries, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, ldd] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, ldd] - const std::uint32_t ldd, // (*) ldd >= search_width * graph_degree - SAMPLE_FILTER_T sample_filter, - const raft::distance::DistanceType metric, - cudaStream_t cuda_stream = 0) -{ - const auto block_size = 128; - const dim3 grid_size( - (search_width * graph_degree + (block_size / TEAM_SIZE) - 1) / (block_size / TEAM_SIZE), - num_queries); - - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - - const auto smem_size = - query_smem_buffer_length * sizeof(float) + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; - - compute_distance_to_child_nodes_kernel - <<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - lds, - search_width, - dataset_desc, - neighbor_graph_ptr, - graph_degree, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter, - metric); -} - -template -RAFT_KERNEL remove_parent_bit_kernel(const std::uint32_t num_queries, - const std::uint32_t num_topk, - INDEX_T* const topk_indices_ptr, // [ld, num_queries] - const std::uint32_t ld) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - uint32_t i_query = blockIdx.x; - if (i_query >= num_queries) return; - - for (unsigned i = threadIdx.x; i < num_topk; i += blockDim.x) { - topk_indices_ptr[i + (ld * i_query)] &= ~index_msb_1_mask; // clear most significant bit - } -} - -template -void remove_parent_bit(const std::uint32_t num_queries, - const std::uint32_t num_topk, - INDEX_T* const topk_indices_ptr, // [ld, num_queries] - const std::uint32_t ld, - cudaStream_t cuda_stream = 0) -{ - const std::size_t grid_size = num_queries; - const std::size_t block_size = 256; - remove_parent_bit_kernel<<>>( - num_queries, num_topk, topk_indices_ptr, ld); -} - -// This function called after the `remove_parent_bit` function -template -RAFT_KERNEL apply_filter_kernel(INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= result_buffer_size * num_queries) { return; } - const auto i = tid % result_buffer_size; - const auto j = tid / result_buffer_size; - const auto index = i + j * lds; - - if (result_indices_ptr[index] != ~index_msb_1_mask && - !sample_filter(query_id_offset + j, result_indices_ptr[index])) { - result_indices_ptr[index] = utils::get_max_value(); - result_distances_ptr[index] = utils::get_max_value(); - } -} - -template -void apply_filter(INDEX_T* const result_indices_ptr, - DISTANCE_T* const result_distances_ptr, - const std::size_t lds, - const std::uint32_t result_buffer_size, - const std::uint32_t num_queries, - const INDEX_T query_id_offset, - SAMPLE_FILTER_T sample_filter, - cudaStream_t cuda_stream) -{ - const std::uint32_t block_size = 256; - const std::uint32_t grid_size = ceildiv(num_queries * result_buffer_size, block_size); - - apply_filter_kernel<<>>(result_indices_ptr, - result_distances_ptr, - lds, - result_buffer_size, - num_queries, - query_id_offset, - sample_filter); -} - -template -RAFT_KERNEL batched_memcpy_kernel(T* const dst, // [batch_size, ld_dst] - const uint64_t ld_dst, - const T* const src, // [batch_size, ld_src] - const uint64_t ld_src, - const uint64_t count, - const uint64_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto i = tid % count; - const auto j = tid / count; - dst[i + (ld_dst * j)] = src[i + (ld_src * j)]; -} - -template -void batched_memcpy(T* const dst, // [batch_size, ld_dst] - const uint64_t ld_dst, - const T* const src, // [batch_size, ld_src] - const uint64_t ld_src, - const uint64_t count, - const uint64_t batch_size, - cudaStream_t cuda_stream) -{ - assert(ld_dst >= count); - assert(ld_src >= count); - constexpr uint32_t block_size = 256; - const auto grid_size = (batch_size * count + block_size - 1) / block_size; - batched_memcpy_kernel - <<>>(dst, ld_dst, src, ld_src, count, batch_size); -} - -template -RAFT_KERNEL set_value_batch_kernel(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size) -{ - const auto tid = threadIdx.x + blockIdx.x * blockDim.x; - if (tid >= count * batch_size) { return; } - const auto batch_id = tid / count; - const auto elem_id = tid % count; - dev_ptr[elem_id + ld * batch_id] = val; -} - -template -void set_value_batch(T* const dev_ptr, - const std::size_t ld, - const T val, - const std::size_t count, - const std::size_t batch_size, - cudaStream_t cuda_stream) -{ - constexpr std::uint32_t block_size = 256; - const auto grid_size = (count * batch_size + block_size - 1) / block_size; - set_value_batch_kernel - <<>>(dev_ptr, ld, val, count, batch_size); -} - -// result_buffer (work buffer) for "multi-kernel" -// +--------------------+------------------------------+-------------------+ -// | internal_top_k (A) | neighbors of internal_top_k | internal_topk (B) | -// | | | | -// +--------------------+------------------------------+-------------------+ -// |<--- result_buffer_allocation_size --->| -// |<--- result_buffer_size --->| // Double buffer (A) -// |<--- result_buffer_size --->| // Double buffer (B) -template -struct search : search_plan_impl { - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - static_assert(std::is_same_v, "Only float is supported as resulting distance"); - - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - size_t result_buffer_allocation_size; - rmm::device_uvector result_indices; // results_indices_buffer - rmm::device_uvector result_distances; // result_distances_buffer - rmm::device_uvector parent_node_list; - rmm::device_uvector topk_hint; - rmm::device_scalar terminate_flag; // dev_terminate_flag, host_terminate_flag.; - rmm::device_uvector topk_workspace; - - // temporary storage for _find_topk - rmm::device_uvector input_keys_storage; - rmm::device_uvector output_keys_storage; - rmm::device_uvector input_values_storage; - rmm::device_uvector output_values_storage; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_plan_impl( - res, params, dim, graph_degree, topk, metric), - result_indices(0, resource::get_cuda_stream(res)), - result_distances(0, resource::get_cuda_stream(res)), - parent_node_list(0, resource::get_cuda_stream(res)), - topk_hint(0, resource::get_cuda_stream(res)), - topk_workspace(0, resource::get_cuda_stream(res)), - terminate_flag(resource::get_cuda_stream(res)), - input_keys_storage(0, resource::get_cuda_stream(res)), - output_keys_storage(0, resource::get_cuda_stream(res)), - input_values_storage(0, resource::get_cuda_stream(res)), - output_values_storage(0, resource::get_cuda_stream(res)) - { - set_params(res); - } - - void set_params(raft::resources const& res) - { - // - // Allocate memory for intermediate buffer and workspace. - // - result_buffer_size = itopk_size + (search_width * graph_degree); - result_buffer_allocation_size = result_buffer_size + itopk_size; - result_indices.resize(result_buffer_allocation_size * max_queries, - resource::get_cuda_stream(res)); - result_distances.resize(result_buffer_allocation_size * max_queries, - resource::get_cuda_stream(res)); - - parent_node_list.resize(max_queries * search_width, resource::get_cuda_stream(res)); - topk_hint.resize(max_queries, resource::get_cuda_stream(res)); - - size_t topk_workspace_size = _cuann_find_topk_bufferSize( - itopk_size, max_queries, result_buffer_size, utils::get_cuda_data_type()); - RAFT_LOG_DEBUG("# topk_workspace_size: %lu", topk_workspace_size); - topk_workspace.resize(topk_workspace_size, resource::get_cuda_stream(res)); - - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - } - - ~search() {} - - inline void _find_topk(raft::resources const& handle, - uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const INDEX_T* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - INDEX_T* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort, - uint32_t* hints) - { - auto stream = resource::get_cuda_stream(handle); - - // _cuann_find_topk right now is limited to a max-k of 1024. - // RAFT has a matrix::select_k function - which handles arbitrary sized values of k, - // but doesn't accept strided inputs unlike _cuann_find_topk - // The multi-kernel search path requires strided access - since its cleverly allocating memory - // (layout described in the search_plan_impl function below), such that both the - // neighbors and the internal_topk are adjacent - in a double buffered format. - // Since this layout doesn't work with the matrix::select_k code - we have to copy - // over to a contiguous (non-strided) access to handle topk larger than 1024, and - // potentially also copy back to a strided layout afterwards - if (topK <= 1024) { - return _cuann_find_topk(topK, - sizeBatch, - numElements, - inputKeys, - ldIK, - inputVals, - ldIV, - outputKeys, - ldOK, - outputVals, - ldOV, - workspace, - sort, - hints, - stream); - } - - if (ldIK > numElements) { - if (input_keys_storage.size() != sizeBatch * numElements) { - input_keys_storage.resize(sizeBatch * numElements, stream); - } - batched_memcpy( - input_keys_storage.data(), numElements, inputKeys, ldIK, numElements, sizeBatch, stream); - inputKeys = input_keys_storage.data(); - } - - if (ldIV > numElements) { - if (input_values_storage.size() != sizeBatch * numElements) { - input_values_storage.resize(sizeBatch * numElements, stream); - } - - batched_memcpy( - input_values_storage.data(), numElements, inputVals, ldIV, numElements, sizeBatch, stream); - inputVals = input_values_storage.data(); - } - - if ((ldOK > topK) && (output_keys_storage.size() != sizeBatch * topK)) { - output_keys_storage.resize(sizeBatch * topK, stream); - } - - if ((ldOV > topK) && (output_values_storage.size() != sizeBatch * topK)) { - output_values_storage.resize(sizeBatch * topK, stream); - } - - raft::matrix::select_k( - handle, - raft::make_device_matrix_view(inputKeys, sizeBatch, numElements), - raft::make_device_matrix_view(inputVals, sizeBatch, numElements), - raft::make_device_matrix_view( - ldOK > topK ? output_keys_storage.data() : outputKeys, sizeBatch, topK), - raft::make_device_matrix_view( - ldOV > topK ? output_values_storage.data() : outputVals, sizeBatch, topK), - true, // select_min - sort); - - if (ldOK > topK) { - batched_memcpy(outputKeys, ldOK, output_keys_storage.data(), topK, topK, sizeBatch, stream); - } - - if (ldOV > topK) { - batched_memcpy(outputVals, ldOV, output_values_storage.data(), topK, topK, sizeBatch, stream); - } - } - - void operator()(raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - // Init hashmap - cudaStream_t stream = resource::get_cuda_stream(res); - const uint32_t hash_size = hashmap::get_size(hash_bitlen); - set_value_batch( - hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); - - // Topk hint can not be used when applying a filter - uint32_t* const top_hint_ptr = - std::is_same::value - ? topk_hint.data() - : nullptr; - // Init topk_hint - if (top_hint_ptr != nullptr && topk_hint.size() > 0) { - set_value(top_hint_ptr, 0xffffffffu, num_queries, stream); - } - - // Choose initial entry point candidates at random - random_pickup(dataset_desc, - queries_ptr, - num_queries, - result_buffer_size, - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - result_indices.data(), - result_distances.data(), - result_buffer_allocation_size, - hashmap.data(), - hash_bitlen, - this->metric, - stream); - - unsigned iter = 0; - while (1) { - // Make an index list of internal top-k nodes - _find_topk(res, - itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr); - - // termination (1) - if ((iter + 1 == max_iterations)) { - iter++; - break; - } - - if (iter + 1 >= min_iterations) { set_value(terminate_flag.data(), 1, stream); } - - // pickup parent nodes - uint32_t _small_hash_bitlen = 0; - if ((iter + 1) % small_hash_reset_interval == 0) { _small_hash_bitlen = small_hash_bitlen; } - pickup_next_parents(result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - itopk_size, - num_queries, - hashmap.data(), - hash_bitlen, - _small_hash_bitlen, - parent_node_list.data(), - search_width, - search_width, - terminate_flag.data(), - stream); - - // termination (2) - if (iter + 1 >= min_iterations && terminate_flag.value(stream)) { - iter++; - break; - } - - // Compute distance to child nodes that are adjacent to the parent node - compute_distance_to_child_nodes( - parent_node_list.data(), - result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, - result_buffer_allocation_size, - search_width, - dataset_desc, - graph.data_handle(), - graph.extent(1), - queries_ptr, - num_queries, - hashmap.data(), - hash_bitlen, - result_indices.data() + itopk_size, - result_distances.data() + itopk_size, - result_buffer_allocation_size, - sample_filter, - this->metric, - stream); - - iter++; - } // while ( 1 ) - auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size; - auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size; - - if constexpr (!std::is_same::value) { - // Remove parent bit in search results - remove_parent_bit(num_queries, - result_buffer_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - stream); - - apply_filter( - result_indices.data() + (iter & 0x1) * itopk_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_buffer_size, - num_queries, - 0, - sample_filter, - stream); - - result_indices_ptr = result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size; - result_distances_ptr = result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size; - _find_topk(res, - itopk_size, - num_queries, - result_buffer_size, - result_distances.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_indices.data() + (iter & 0x1) * itopk_size, - result_buffer_allocation_size, - result_distances_ptr, - result_buffer_allocation_size, - result_indices_ptr, - result_buffer_allocation_size, - topk_workspace.data(), - true, - top_hint_ptr); - } else { - // Remove parent bit in search results - remove_parent_bit( - num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream); - } - - // Copy results from working buffer to final buffer - batched_memcpy(topk_indices_ptr, - topk, - result_indices_ptr, - result_buffer_allocation_size, - topk, - num_queries, - stream); - if (topk_distances_ptr) { - batched_memcpy(topk_distances_ptr, - topk, - result_distances_ptr, - result_buffer_allocation_size, - topk, - num_queries, - stream); - } - - if (num_executed_iterations) { - for (std::uint32_t i = 0; i < num_queries; i++) { - num_executed_iterations[i] = iter; - } - } - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } -}; - -template -struct search, - SAMPLE_FILTER_T> - : public search_plan_impl, - SAMPLE_FILTER_T> { - using DATASET_DESCRIPTOR_T = cagra_q_dataset_descriptor_t; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_plan_impl( - res, params, dim, graph_degree, topk, metric) - { - THROW("The multi-kernel mode does not support VPQ"); - } - - void set_params(raft::resources const& res) {} - - void operator()(raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - INDEX_T* const topk_indices_ptr, // [num_queries, topk] - DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - } -}; - -} // namespace multi_kernel_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh deleted file mode 100644 index b35d96e9f5..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ /dev/null @@ -1,347 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "hashmap.hpp" - -#include -#include -// #include "search_single_cta.cuh" -// #include "topk_for_cagra/topk_core.cuh" - -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { - -struct search_plan_impl_base : public search_params { - int64_t dataset_block_dim; - int64_t dim; - int64_t graph_degree; - uint32_t topk; - raft::distance::DistanceType metric; - search_plan_impl_base(search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk), metric(metric) - { - set_dataset_block_and_team_size(dim); - if (algo == search_algo::AUTO) { - const size_t num_sm = raft::getMultiProcessorCount(); - if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { - algo = search_algo::SINGLE_CTA; - RAFT_LOG_DEBUG("Auto strategy: selecting single-cta"); - } else if (topk <= 1024) { - algo = search_algo::MULTI_CTA; - RAFT_LOG_DEBUG("Auto strategy: selecting multi-cta"); - } else { - algo = search_algo::MULTI_KERNEL; - RAFT_LOG_DEBUG("Auto strategy: selecting multi kernel"); - } - } - } - - void set_dataset_block_and_team_size(int64_t dim) - { - constexpr int64_t max_dataset_block_dim = 512; - dataset_block_dim = 128; - while (dataset_block_dim < dim && dataset_block_dim < max_dataset_block_dim) { - dataset_block_dim *= 2; - } - // To keep binary size in check we limit only one team size specialization for each max_dim. - // TODO(tfeher): revise this decision. - switch (dataset_block_dim) { - case 128: team_size = 8; break; - case 256: team_size = 16; break; - default: team_size = 32; break; - } - } -}; - -template -struct search_plan_impl : public search_plan_impl_base { - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - - int64_t hash_bitlen; - - size_t small_hash_bitlen; - size_t small_hash_reset_interval; - size_t hashmap_size; - uint32_t dataset_size; - uint32_t result_buffer_size; - - uint32_t smem_size; - uint32_t topk; - uint32_t num_seeds; - - rmm::device_uvector hashmap; - rmm::device_uvector num_executed_iterations; // device or managed? - rmm::device_uvector dev_seed; - - search_plan_impl(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_plan_impl_base(params, dim, graph_degree, topk, metric), - hashmap(0, resource::get_cuda_stream(res)), - num_executed_iterations(0, resource::get_cuda_stream(res)), - dev_seed(0, resource::get_cuda_stream(res)), - num_seeds(0) - { - adjust_search_params(); - check_params(); - calc_hashmap_params(res); - set_dataset_block_and_team_size(dim); - num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res)); - RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); - } - - virtual ~search_plan_impl() {} - - virtual void operator()(raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - INDEX_T* const result_indices_ptr, // [num_queries, topk] - DISTANCE_T* const result_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk, - SAMPLE_FILTER_T sample_filter){}; - - void adjust_search_params() - { - uint32_t _max_iterations = max_iterations; - if (max_iterations == 0) { - if (algo == search_algo::MULTI_CTA) { - _max_iterations = 1 + std::min(32 * 1.1, 32 + 10.0); // TODO(anaruse) - } else { - _max_iterations = - 1 + std::min((itopk_size / search_width) * 1.1, (itopk_size / search_width) + 10.0); - } - } - if (max_iterations < min_iterations) { _max_iterations = min_iterations; } - if (max_iterations < _max_iterations) { - RAFT_LOG_DEBUG( - "# max_iterations is increased from %lu to %u.", max_iterations, _max_iterations); - max_iterations = _max_iterations; - } - if (itopk_size % 32) { - uint32_t itopk32 = itopk_size; - itopk32 += 32 - (itopk_size % 32); - RAFT_LOG_DEBUG("# internal_topk is increased from %lu to %u, as it must be multiple of 32.", - itopk_size, - itopk32); - itopk_size = itopk32; - } - } - - // defines hash_bitlen, small_hash_bitlen, small_hash_reset interval, hash_size - inline void calc_hashmap_params(raft::resources const& res) - { - // for multiple CTA search - uint32_t mc_num_cta_per_query = 0; - uint32_t mc_search_width = 0; - uint32_t mc_itopk_size = 0; - if (algo == search_algo::MULTI_CTA) { - mc_itopk_size = 32; - mc_search_width = 1; - mc_num_cta_per_query = max(search_width, raft::ceildiv(itopk_size, (size_t)32)); - RAFT_LOG_DEBUG("# mc_itopk_size: %u", mc_itopk_size); - RAFT_LOG_DEBUG("# mc_search_width: %u", mc_search_width); - RAFT_LOG_DEBUG("# mc_num_cta_per_query: %u", mc_num_cta_per_query); - } - - // Determine hash size (bit length) - hashmap_size = 0; - hash_bitlen = 0; - small_hash_bitlen = 0; - small_hash_reset_interval = 1024 * 1024; - float max_fill_rate = hashmap_max_fill_rate; - while (hashmap_mode == hash_mode::AUTO || hashmap_mode == hash_mode::SMALL) { - // - // The small-hash reduces hash table size by initializing the hash table - // for each iteraton and re-registering only the nodes that should not be - // re-visited in that iteration. Therefore, the size of small-hash should - // be determined based on the internal topk size and the number of nodes - // visited per iteration. - // - const auto max_visited_nodes = itopk_size + (search_width * graph_degree * 1); - unsigned min_bitlen = 8; // 256 - unsigned max_bitlen = 13; // 8K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - if (hash_bitlen > max_bitlen) { - // Switch to normal hash if hashmap_mode is AUTO, otherwise exit. - if (hashmap_mode == hash_mode::AUTO) { - hash_bitlen = 0; - break; - } else { - RAFT_FAIL( - "small-hash cannot be used because the required hash size exceeds the limit (%u)", - hashmap::get_size(max_bitlen)); - } - } - small_hash_bitlen = hash_bitlen; - // - // Sincc the hash table size is limited to a power of 2, the requirement, - // the maximum fill rate, may be satisfied even if the frequency of hash - // table reset is reduced to once every 2 or more iterations without - // changing the hash table size. In that case, reduce the reset frequency. - // - small_hash_reset_interval = 1; - while (1) { - const auto max_visited_nodes = - itopk_size + (search_width * graph_degree * (small_hash_reset_interval + 1)); - if (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { break; } - small_hash_reset_interval += 1; - } - break; - } - if (hash_bitlen == 0) { - // - // The size of hash table is determined based on the maximum number of - // nodes that may be visited before the search is completed and the - // maximum fill rate of the hash table. - // - uint32_t max_visited_nodes = itopk_size + (search_width * graph_degree * max_iterations); - if (algo == search_algo::MULTI_CTA) { - max_visited_nodes = mc_itopk_size + (mc_search_width * graph_degree * max_iterations); - max_visited_nodes *= mc_num_cta_per_query; - } - unsigned min_bitlen = 11; // 2K - if (min_bitlen < hashmap_min_bitlen) { min_bitlen = hashmap_min_bitlen; } - hash_bitlen = min_bitlen; - while (max_visited_nodes > hashmap::get_size(hash_bitlen) * max_fill_rate) { - hash_bitlen += 1; - } - RAFT_EXPECTS(hash_bitlen <= 20, "hash_bitlen cannot be largen than 20 (1M)"); - } - - RAFT_LOG_DEBUG("# internal topK = %lu", itopk_size); - RAFT_LOG_DEBUG("# parent size = %lu", search_width); - RAFT_LOG_DEBUG("# min_iterations = %lu", min_iterations); - RAFT_LOG_DEBUG("# max_iterations = %lu", max_iterations); - RAFT_LOG_DEBUG("# max_queries = %lu", max_queries); - RAFT_LOG_DEBUG("# hashmap mode = %s%s-%u", - (small_hash_bitlen > 0 ? "small-" : ""), - "hash", - hashmap::get_size(hash_bitlen)); - if (small_hash_bitlen > 0) { - RAFT_LOG_DEBUG("# small_hash_reset_interval = %lu", small_hash_reset_interval); - } - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); - RAFT_LOG_DEBUG("# hashmap size: %lu", hashmap_size); - if (hashmap_size >= 1024 * 1024 * 1024) { - RAFT_LOG_DEBUG(" (%.2f GiB)", (double)hashmap_size / (1024 * 1024 * 1024)); - } else if (hashmap_size >= 1024 * 1024) { - RAFT_LOG_DEBUG(" (%.2f MiB)", (double)hashmap_size / (1024 * 1024)); - } else if (hashmap_size >= 1024) { - RAFT_LOG_DEBUG(" (%.2f KiB)", (double)hashmap_size / (1024)); - } - } - - virtual void check(const uint32_t topk) - { - // For single-CTA and multi kernel - RAFT_EXPECTS( - topk <= itopk_size, "topk = %u must be smaller than itopk_size = %lu", topk, itopk_size); - } - - inline void check_params() - { - std::string error_message = ""; - - if (itopk_size > 1024) { - if ((algo == search_algo::MULTI_CTA) || (algo == search_algo::MULTI_KERNEL)) { - } else { - error_message += std::string("- `internal_topk` (" + std::to_string(itopk_size) + - ") must be smaller or equal to 1024"); - } - } - if (algo != search_algo::SINGLE_CTA && algo != search_algo::MULTI_CTA && - algo != search_algo::MULTI_KERNEL) { - error_message += "An invalid kernel mode has been given: " + std::to_string((int)algo) + ""; - } - if (team_size != 0 && team_size != 4 && team_size != 8 && team_size != 16 && team_size != 32) { - error_message += - "`team_size` must be 0, 4, 8, 16 or 32. " + std::to_string(team_size) + " has been given."; - } - if (thread_block_size != 0 && thread_block_size != 64 && thread_block_size != 128 && - thread_block_size != 256 && thread_block_size != 512 && thread_block_size != 1024) { - error_message += "`thread_block_size` must be 0, 64, 128, 256 or 512. " + - std::to_string(thread_block_size) + " has been given."; - } - if (hashmap_min_bitlen > 20) { - error_message += "`hashmap_min_bitlen` must be equal to or smaller than 20. " + - std::to_string(hashmap_min_bitlen) + " has been given."; - } - if (hashmap_max_fill_rate < 0.1 || hashmap_max_fill_rate >= 0.9) { - error_message += - "`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. " + - std::to_string(hashmap_max_fill_rate) + " has been given."; - } - if constexpr (!std::is_same::value) { - if (hashmap_mode == hash_mode::SMALL) { - error_message += "`SMALL` hash is not available when filtering"; - } else { - hashmap_mode = hash_mode::HASH; - } - } - if (algo == search_algo::MULTI_CTA) { - if (hashmap_mode == hash_mode::SMALL) { - error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\""; - } else { - hashmap_mode = hash_mode::HASH; - } - } - - if (error_message.length() != 0) { THROW("[CAGRA Error] %s", error_message.c_str()); } - } -}; - -// template -// struct search_plan { -// search_plan(raft::resources const& res, -// search_params param, -// int64_t dim, -// int64_t graph_degree) -// : plan(res, param, dim, graph_degree) -// { -// } -// void check(uint32_t topk) { plan.check(topk); } - -// // private: -// detail::search_plan_impl plan; -// }; -/** @} */ // end group cagra - -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh deleted file mode 100644 index 0771652787..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "search_single_cta_kernel.cuh" -#include "topk_by_radix.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include - -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace single_cta_search { - -template -struct search : search_plan_impl { - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - - using search_plan_impl::max_queries; - using search_plan_impl::itopk_size; - using search_plan_impl::algo; - using search_plan_impl::team_size; - using search_plan_impl::search_width; - using search_plan_impl::min_iterations; - using search_plan_impl::max_iterations; - using search_plan_impl::thread_block_size; - using search_plan_impl::hashmap_mode; - using search_plan_impl::hashmap_min_bitlen; - using search_plan_impl::hashmap_max_fill_rate; - using search_plan_impl::num_random_samplings; - using search_plan_impl::rand_xor_mask; - - using search_plan_impl::dim; - using search_plan_impl::graph_degree; - using search_plan_impl::topk; - - using search_plan_impl::hash_bitlen; - - using search_plan_impl::small_hash_bitlen; - using search_plan_impl::small_hash_reset_interval; - using search_plan_impl::hashmap_size; - using search_plan_impl::dataset_size; - using search_plan_impl::result_buffer_size; - - using search_plan_impl::smem_size; - - using search_plan_impl::hashmap; - using search_plan_impl::num_executed_iterations; - using search_plan_impl::dev_seed; - using search_plan_impl::num_seeds; - - uint32_t num_itopk_candidates; - - search(raft::resources const& res, - search_params params, - int64_t dim, - int64_t graph_degree, - uint32_t topk, - raft::distance::DistanceType metric) - : search_plan_impl( - res, params, dim, graph_degree, topk, metric) - { - set_params(res); - } - - ~search() {} - - inline void set_params(raft::resources const& res) - { - num_itopk_candidates = search_width * graph_degree; - result_buffer_size = itopk_size + num_itopk_candidates; - - typedef raft::Pow2<32> AlignBytes; - unsigned result_buffer_size_32 = AlignBytes::roundUp(result_buffer_size); - - constexpr unsigned max_itopk = 512; - RAFT_EXPECTS(itopk_size <= max_itopk, "itopk_size cannot be larger than %u", max_itopk); - - RAFT_LOG_DEBUG("# num_itopk_candidates: %u", num_itopk_candidates); - RAFT_LOG_DEBUG("# num_itopk: %lu", itopk_size); - // - // Determine the thread block size - // - constexpr unsigned min_block_size = 64; // 32 or 64 - constexpr unsigned min_block_size_radix = 256; - constexpr unsigned max_block_size = 1024; - // - const std::uint32_t topk_ws_size = 3; - const auto query_smem_buffer_length = - raft::ceildiv(dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - const std::uint32_t base_smem_size = - sizeof(float) * query_smem_buffer_length + - (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + - sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width + - sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t) + - DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte; - smem_size = base_smem_size; - if (num_itopk_candidates > 256) { - // Tentatively calculate the required share memory size when radix - // sort based topk is used, assuming the block size is the maximum. - if (itopk_size <= 256) { - smem_size += topk_by_radix_sort<256, INDEX_T>::smem_size * sizeof(std::uint32_t); - } else { - smem_size += topk_by_radix_sort<512, INDEX_T>::smem_size * sizeof(std::uint32_t); - } - } - - uint32_t block_size = thread_block_size; - if (block_size == 0) { - block_size = min_block_size; - - if (num_itopk_candidates > 256) { - // radix-based topk is used. - block_size = min_block_size_radix; - - // Internal topk values per thread must be equlal to or less than 4 - // when radix-sort block_topk is used. - while ((block_size < max_block_size) && (max_itopk / block_size > 4)) { - block_size *= 2; - } - } - - // Increase block size according to shared memory requirements. - // If block size is 32, upper limit of shared memory size per - // thread block is set to 4096. This is GPU generation dependent. - constexpr unsigned ulimit_smem_size_cta32 = 4096; - while (smem_size > ulimit_smem_size_cta32 / 32 * block_size) { - block_size *= 2; - } - - // Increase block size to improve GPU occupancy when batch size - // is small, that is, number of queries is low. - cudaDeviceProp deviceProp = resource::get_device_properties(res); - RAFT_LOG_DEBUG("# multiProcessorCount: %d", deviceProp.multiProcessorCount); - while ((block_size < max_block_size) && - (graph_degree * search_width * team_size >= block_size * 2) && - (max_queries <= (1024 / (block_size * 2)) * deviceProp.multiProcessorCount)) { - block_size *= 2; - } - } - RAFT_LOG_DEBUG("# thread_block_size: %u", block_size); - RAFT_EXPECTS(block_size >= min_block_size, - "block_size cannot be smaller than min_block size, %u", - min_block_size); - RAFT_EXPECTS(block_size <= max_block_size, - "block_size cannot be larger than max_block size %u", - max_block_size); - thread_block_size = block_size; - - if (num_itopk_candidates <= 256) { - RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used"); - } else { - RAFT_LOG_DEBUG("# radix-sort based topk routine is used"); - smem_size = base_smem_size; - if (itopk_size <= 256) { - constexpr unsigned MAX_ITOPK = 256; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } else { - constexpr unsigned MAX_ITOPK = 512; - smem_size += topk_by_radix_sort::smem_size * sizeof(std::uint32_t); - } - } - RAFT_LOG_DEBUG("# smem_size: %u", smem_size); - hashmap_size = 0; - if (small_hash_bitlen == 0) { - hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen); - hashmap.resize(hashmap_size, resource::get_cuda_stream(res)); - } - RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size); - } - - void operator()(raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - INDEX_T* const result_indices_ptr, // [num_queries, topk] - DISTANCE_T* const result_distances_ptr, // [num_queries, topk] - const DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const std::uint32_t num_queries, - const INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - std::uint32_t* const num_executed_iterations, // [num_queries] - uint32_t topk, - SAMPLE_FILTER_T sample_filter) - { - cudaStream_t stream = resource::get_cuda_stream(res); - select_and_run( - dataset_desc, - graph, - result_indices_ptr, - result_distances_ptr, - queries_ptr, - num_queries, - dev_seed_ptr, - num_executed_iterations, - topk, - num_itopk_candidates, - static_cast(thread_block_size), - smem_size, - hash_bitlen, - hashmap.data(), - small_hash_bitlen, - small_hash_reset_interval, - num_random_samplings, - rand_xor_mask, - num_seeds, - itopk_size, - search_width, - min_iterations, - max_iterations, - sample_filter, - this->metric, - stream); - } -}; - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh deleted file mode 100644 index 510219ab5d..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ /dev/null @@ -1,602 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include // RAFT_EXPLICIT - -#include - -namespace raft::neighbors::cagra::detail { -namespace single_cta_search { - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -template -void select_and_run( // raft::resources const& res, - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) RAFT_EXPLICIT; - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_single_cta_select_and_run( \ - TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \ - extern template void select_and_run< \ - TEAM_SIZE, \ - MAX_DATASET_DIM, \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::standard_dataset_descriptor_t \ - dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_single_cta_select_and_run( - 32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_single_cta_select_and_run( - 32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_single_cta_select_and_run - -#define instantiate_q_single_cta_select_and_run(TEAM_SIZE, \ - MAX_DATASET_DIM, \ - CODE_BOOK_T, \ - PQ_BITS, \ - PQ_CODE_BOOK_DIM, \ - DATA_T, \ - INDEX_T, \ - DISTANCE_T, \ - SAMPLE_FILTER_T) \ - extern template void \ - select_and_run, \ - SAMPLE_FILTER_T>( \ - raft::neighbors::cagra::detail::cagra_q_dataset_descriptor_t dataset, \ - raft::device_matrix_view graph, \ - INDEX_T* const topk_indices_ptr, \ - DISTANCE_T* const topk_distances_ptr, \ - const DATA_T* const queries_ptr, \ - const uint32_t num_queries, \ - const INDEX_T* dev_seed_ptr, \ - uint32_t* const num_executed_iterations, \ - uint32_t topk, \ - uint32_t num_itopk_candidates, \ - uint32_t block_size, \ - uint32_t smem_size, \ - int64_t hash_bitlen, \ - INDEX_T* hashmap_ptr, \ - size_t small_hash_bitlen, \ - size_t small_hash_reset_interval, \ - uint32_t num_random_samplings, \ - uint64_t rand_xor_mask, \ - uint32_t num_seeds, \ - size_t itopk_size, \ - size_t search_width, \ - size_t min_iterations, \ - size_t max_iterations, \ - SAMPLE_FILTER_T sample_filter, \ - raft::distance::DistanceType metric, \ - cudaStream_t stream); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - half, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - float, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 1024, half, 8, 2, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 1024, half, 8, 4, half, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 2, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - float, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 16, 256, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 32, 512, half, 8, 4, float, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - float, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - uint8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - int8_t, - uint32_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(8, - 128, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - uint8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 2, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 2, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run( - 8, 128, half, 8, 4, int8_t, int64_t, float, raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(16, - 256, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 512, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); -instantiate_q_single_cta_select_and_run(32, - 1024, - half, - 8, - 4, - int8_t, - int64_t, - float, - raft::neighbors::filtering::none_cagra_sample_filter); - -#undef instantiate_q_single_cta_select_and_run - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh deleted file mode 100644 index 232dcb782a..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ /dev/null @@ -1,971 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "bitonic.hpp" -#include "compute_distance.hpp" -#include "device_common.hpp" -#include "hashmap.hpp" -#include "search_plan.cuh" -#include "topk_by_radix.cuh" -#include "topk_for_cagra/topk_core.cuh" // TODO replace with raft topk -#include "utils.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp - -#include - -#include -#include -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace single_cta_search { - -// #define _CLK_BREAKDOWN - -template -__device__ void pickup_next_parents(std::uint32_t* const terminate_flag, - INDEX_T* const next_parent_indices, - INDEX_T* const internal_topk_indices, - const std::size_t internal_topk_size, - const std::size_t dataset_size, - const std::uint32_t search_width) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - // if (threadIdx.x >= 32) return; - - for (std::uint32_t i = threadIdx.x; i < search_width; i += 32) { - next_parent_indices[i] = utils::get_max_value(); - } - std::uint32_t itopk_max = internal_topk_size; - if (itopk_max % 32) { itopk_max += 32 - (itopk_max % 32); } - std::uint32_t num_new_parents = 0; - for (std::uint32_t j = threadIdx.x; j < itopk_max; j += 32) { - std::uint32_t jj = j; - if (TOPK_BY_BITONIC_SORT) { jj = device::swizzling(j); } - INDEX_T index; - int new_parent = 0; - if (j < internal_topk_size) { - index = internal_topk_indices[jj]; - if ((index & index_msb_1_mask) == 0) { // check if most significant bit is set - new_parent = 1; - } - } - const std::uint32_t ballot_mask = __ballot_sync(0xffffffff, new_parent); - if (new_parent) { - const auto i = __popc(ballot_mask & ((1 << threadIdx.x) - 1)) + num_new_parents; - if (i < search_width) { - next_parent_indices[i] = jj; - // set most significant bit as used node - internal_topk_indices[jj] |= index_msb_1_mask; - } - } - num_new_parents += __popc(ballot_mask); - if (num_new_parents >= search_width) { break; } - } - if (threadIdx.x == 0 && (num_new_parents == 0)) { *terminate_flag = 1; } -} - -template -__device__ inline void topk_by_bitonic_sort_1st(float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - const std::uint32_t num_itopk, - unsigned MULTI_WARPS = 0) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_CANDIDATES + 31) / 32; - float key[N]; - IdxT val[N]; - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) - constexpr unsigned max_candidates_per_warp = (MAX_CANDIDATES + 1) / 2; - constexpr unsigned N = (max_candidates_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (warp_id < 2) { - /* Candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = lane_id + (32 * i); - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates) { - key[i] = candidate_distances[j]; - val[i] = candidate_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Sort */ - bitonic::warp_sort(key, val); - /* Reg -> Temp_candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && jl < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - __syncthreads(); - - unsigned num_warps_used = (num_itopk + max_candidates_per_warp - 1) / max_candidates_per_warp; - if (warp_id < num_warps_used) { - /* Temp_candidates -> Reg */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned kl = max_candidates_per_warp - 1 - jl; - unsigned j = jl + (max_candidates_per_warp * warp_id); - unsigned k = MAX_CANDIDATES - 1 - j; - if (j >= num_candidates || k >= num_candidates || kl >= num_itopk) continue; - float temp_key = candidate_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - if (warp_id < num_warps_used) { - /* Merge */ - bitonic::warp_merge(key, val, 32); - /* Reg -> Temp_itopk */ - for (unsigned i = 0; i < N; i++) { - unsigned jl = (N * lane_id) + i; - unsigned j = jl + (max_candidates_per_warp * warp_id); - if (j < num_candidates && j < num_itopk) { - candidate_distances[device::swizzling(j)] = key[i]; - candidate_indices[device::swizzling(j)] = val[i]; - } - } - } - if (num_warps_used > 1) { __syncthreads(); } - } -} - -template -__device__ inline void topk_by_bitonic_sort_2nd(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first, - unsigned MULTI_WARPS = 0) -{ - const unsigned lane_id = threadIdx.x % 32; - const unsigned warp_id = threadIdx.x / 32; - if (MULTI_WARPS == 0) { - if (warp_id > 0) { return; } - constexpr unsigned N = (MAX_ITOPK + 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - } else { - /* Load itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - key[i] = itopk_distances[device::swizzling(j)]; - val[i] = itopk_indices[device::swizzling(j)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - } - /* Merge candidates */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; // [0:MAX_ITOPK-1] - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk || k >= num_candidates) continue; - float candidate_key = candidate_distances[device::swizzling(k)]; - if (key[i] > candidate_key) { - key[i] = candidate_key; - val[i] = candidate_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * lane_id) + i; - if (j < num_itopk) { - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } else { - // Use two warps (64 threads) or more - constexpr unsigned max_itopk_per_warp = (MAX_ITOPK + 1) / 2; - constexpr unsigned N = (max_itopk_per_warp + 31) / 32; - float key[N]; - IdxT val[N]; - if (first) { - /* Load itop results (not sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = lane_id + (32 * i) + (max_itopk_per_warp * warp_id); - if (j < num_itopk) { - key[i] = itopk_distances[j]; - val[i] = itopk_indices[j]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Sort */ - bitonic::warp_sort(key, val); - /* Store intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - __syncthreads(); - if (warp_id < 2) { - /* Load intermedidate results */ - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - unsigned k = MAX_ITOPK - 1 - j; - if (k >= num_itopk) continue; - float temp_key = itopk_distances[device::swizzling(k)]; - if (key[i] == temp_key) continue; - if ((warp_id == 0) == (key[i] > temp_key)) { - key[i] = temp_key; - val[i] = itopk_indices[device::swizzling(k)]; - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - } - __syncthreads(); - /* Store itopk results (sorted) */ - if (warp_id < 2) { - for (unsigned i = 0; i < N; i++) { - unsigned j = (N * threadIdx.x) + i; - if (j >= num_itopk) continue; - itopk_distances[device::swizzling(j)] = key[i]; - itopk_indices[device::swizzling(j)] = val[i]; - } - } - } - const uint32_t num_itopk_div2 = num_itopk / 2; - if (threadIdx.x < 3) { - // work_buf is used to obtain turning points in 1st and 2nd half of itopk afer merge. - work_buf[threadIdx.x] = num_itopk_div2; - } - __syncthreads(); - - // Merge candidates (using whole threads) - for (unsigned k = threadIdx.x; k < min(num_candidates, num_itopk); k += blockDim.x) { - const unsigned j = num_itopk - 1 - k; - const float itopk_key = itopk_distances[device::swizzling(j)]; - const float candidate_key = candidate_distances[device::swizzling(k)]; - if (itopk_key > candidate_key) { - itopk_distances[device::swizzling(j)] = candidate_key; - itopk_indices[device::swizzling(j)] = candidate_indices[device::swizzling(k)]; - if (j < num_itopk_div2) { - atomicMin(work_buf + 2, j); - } else { - atomicMin(work_buf + 1, j - num_itopk_div2); - } - } - } - __syncthreads(); - - // Merge 1st and 2nd half of itopk (using whole threads) - for (unsigned j = threadIdx.x; j < num_itopk_div2; j += blockDim.x) { - const unsigned k = j + num_itopk_div2; - float key_0 = itopk_distances[device::swizzling(j)]; - float key_1 = itopk_distances[device::swizzling(k)]; - if (key_0 > key_1) { - itopk_distances[device::swizzling(j)] = key_1; - itopk_distances[device::swizzling(k)] = key_0; - IdxT val_0 = itopk_indices[device::swizzling(j)]; - IdxT val_1 = itopk_indices[device::swizzling(k)]; - itopk_indices[device::swizzling(j)] = val_1; - itopk_indices[device::swizzling(k)] = val_0; - atomicMin(work_buf + 0, j); - } - } - if (threadIdx.x == blockDim.x - 1) { - if (work_buf[2] < num_itopk_div2) { work_buf[1] = work_buf[2]; } - } - __syncthreads(); - // if ((blockIdx.x == 0) && (threadIdx.x == 0)) { - // RAFT_LOG_DEBUG( "work_buf: %u, %u, %u\n", work_buf[0], work_buf[1], work_buf[2] ); - // } - - // Warp-0 merges 1st half of itopk, warp-1 does 2nd half. - if (warp_id < 2) { - // Load intermedidate itopk results - const uint32_t turning_point = work_buf[warp_id]; // turning_point <= num_itopk_div2 - for (unsigned i = 0; i < N; i++) { - unsigned k = num_itopk; - unsigned j = (N * lane_id) + i; - if (j < turning_point) { - k = j + (num_itopk_div2 * warp_id); - } else if (j >= (MAX_ITOPK / 2 - num_itopk_div2)) { - j -= (MAX_ITOPK / 2 - num_itopk_div2); - if ((turning_point <= j) && (j < num_itopk_div2)) { k = j + (num_itopk_div2 * warp_id); } - } - if (k < num_itopk) { - key[i] = itopk_distances[device::swizzling(k)]; - val[i] = itopk_indices[device::swizzling(k)]; - } else { - key[i] = utils::get_max_value(); - val[i] = utils::get_max_value(); - } - } - /* Warp Merge */ - bitonic::warp_merge(key, val, 32); - /* Store new itopk results */ - for (unsigned i = 0; i < N; i++) { - const unsigned j = (N * lane_id) + i; - if (j < num_itopk_div2) { - unsigned k = j + (num_itopk_div2 * warp_id); - itopk_distances[device::swizzling(k)] = key[i]; - itopk_indices[device::swizzling(k)] = val[i]; - } - } - } - } -} - -template -__device__ void topk_by_bitonic_sort(float* itopk_distances, // [num_itopk] - IdxT* itopk_indices, // [num_itopk] - const std::uint32_t num_itopk, - float* candidate_distances, // [num_candidates] - IdxT* candidate_indices, // [num_candidates] - const std::uint32_t num_candidates, - std::uint32_t* work_buf, - const bool first, - const unsigned MULTI_WARPS_1, - const unsigned MULTI_WARPS_2) -{ - // The results in candidate_distances/indices are sorted by bitonic sort. - topk_by_bitonic_sort_1st( - candidate_distances, candidate_indices, num_candidates, num_itopk, MULTI_WARPS_1); - - // The results sorted above are merged with the internal intermediate top-k - // results so far using bitonic merge. - topk_by_bitonic_sort_2nd(itopk_distances, - itopk_indices, - num_itopk, - candidate_distances, - candidate_indices, - num_candidates, - work_buf, - first, - MULTI_WARPS_2); -} - -template -__device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr, - const size_t hashmap_bitlen, - const INDEX_T* itopk_indices, - const uint32_t itopk_size, - const uint32_t first_tid = 0) -{ - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - if (threadIdx.x < first_tid) return; - for (unsigned i = threadIdx.x - first_tid; i < itopk_size; i += blockDim.x - first_tid) { - auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit - hashmap::insert(hashmap_ptr, hashmap_bitlen, key); - } -} - -// One query one thread block -template -__launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( - typename DATASET_DESCRIPTOR_T::INDEX_T* const result_indices_ptr, // [num_queries, top_k] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const result_distances_ptr, // [num_queries, top_k] - const std::uint32_t top_k, - DATASET_DESCRIPTOR_T dataset_desc, - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const typename DATASET_DESCRIPTOR_T::INDEX_T* const knn_graph, // [dataset_size, graph_degree] - const std::uint32_t graph_degree, - const unsigned num_distilation, - const uint64_t rand_xor_mask, - const typename DATASET_DESCRIPTOR_T::INDEX_T* seed_ptr, // [num_queries, num_seeds] - const uint32_t num_seeds, - typename DATASET_DESCRIPTOR_T::INDEX_T* const - visited_hashmap_ptr, // [num_queries, 1 << hash_bitlen] - const std::uint32_t internal_topk, - const std::uint32_t search_width, - const std::uint32_t min_iteration, - const std::uint32_t max_iteration, - std::uint32_t* const num_executed_iterations, // [num_queries] - const std::uint32_t hash_bitlen, - const std::uint32_t small_hash_bitlen, - const std::uint32_t small_hash_reset_interval, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric) -{ - using LOAD_T = device::LOAD_128BIT_T; - - using DATA_T = typename DATASET_DESCRIPTOR_T::DATA_T; - using INDEX_T = typename DATASET_DESCRIPTOR_T::INDEX_T; - using DISTANCE_T = typename DATASET_DESCRIPTOR_T::DISTANCE_T; - using QUERY_T = typename DATASET_DESCRIPTOR_T::QUERY_T; - - const auto query_id = blockIdx.y; - -#ifdef _CLK_BREAKDOWN - std::uint64_t clk_init = 0; - std::uint64_t clk_compute_1st_distance = 0; - std::uint64_t clk_topk = 0; - std::uint64_t clk_reset_hash = 0; - std::uint64_t clk_pickup_parents = 0; - std::uint64_t clk_restore_hash = 0; - std::uint64_t clk_compute_distance = 0; - std::uint64_t clk_start; -#define _CLK_START() clk_start = clock64() -#define _CLK_REC(V) V += clock64() - clk_start; -#else -#define _CLK_START() -#define _CLK_REC(V) -#endif - _CLK_START(); - - extern __shared__ std::uint32_t smem[]; - - // Layout of result_buffer - // +----------------------+------------------------------+---------+ - // | internal_top_k | neighbors of internal_top_k | padding | - // | | | upto 32 | - // +----------------------+------------------------------+---------+ - // |<--- result_buffer_size --->| - std::uint32_t result_buffer_size = internal_topk + (search_width * graph_degree); - std::uint32_t result_buffer_size_32 = result_buffer_size; - if (result_buffer_size % 32) { result_buffer_size_32 += 32 - (result_buffer_size % 32); } - const auto small_hash_size = hashmap::get_size(small_hash_bitlen); - - const auto query_smem_buffer_length = - raft::ceildiv(dataset_desc.dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; - auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); - auto result_distances_buffer = - reinterpret_cast(result_indices_buffer + result_buffer_size_32); - auto visited_hash_buffer = - reinterpret_cast(result_distances_buffer + result_buffer_size_32); - auto parent_list_buffer = reinterpret_cast(visited_hash_buffer + small_hash_size); - auto distance_work_buffer_ptr = - reinterpret_cast(parent_list_buffer + search_width); - auto topk_ws = reinterpret_cast(distance_work_buffer_ptr + - DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte); - auto terminate_flag = reinterpret_cast(topk_ws + 3); - auto smem_work_ptr = reinterpret_cast(terminate_flag + 1); - - // Set smem working buffer for the distance calculation - dataset_desc.set_smem_ptr(distance_work_buffer_ptr); - - // A flag for filtering. - auto filter_flag = terminate_flag; - - const DATA_T* const query_ptr = queries_ptr + query_id * dataset_desc.dim; - dataset_desc.template copy_query( - query_ptr, query_buffer, query_smem_buffer_length); - - if (threadIdx.x == 0) { - terminate_flag[0] = 0; - topk_ws[0] = ~0u; - } - - // Init hashmap - INDEX_T* local_visited_hashmap_ptr; - if (small_hash_bitlen) { - local_visited_hashmap_ptr = visited_hash_buffer; - } else { - local_visited_hashmap_ptr = visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id); - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, 0); - __syncthreads(); - _CLK_REC(clk_init); - - // compute distance to randomly selecting nodes - _CLK_START(); - const INDEX_T* const local_seed_ptr = seed_ptr ? seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes(result_indices_buffer, - result_distances_buffer, - query_buffer, - dataset_desc, - result_buffer_size, - num_distilation, - rand_xor_mask, - local_seed_ptr, - num_seeds, - local_visited_hashmap_ptr, - hash_bitlen, - metric); - __syncthreads(); - _CLK_REC(clk_compute_1st_distance); - - std::uint32_t iter = 0; - while (1) { - // sort - if constexpr (TOPK_BY_BITONIC_SORT) { - // [Notice] - // It is good to use multiple warps in topk_by_bitonic_sort() when - // batch size is small (short-latency), but it might not be always good - // when batch size is large (high-throughput). - // topk_by_bitonic_sort() consists of two operations: - // if MAX_CANDIDATES is greater than 128, the first operation uses two warps; - // if MAX_ITOPK is greater than 256, the second operation used two warps. - const unsigned multi_warps_1 = ((blockDim.x >= 64) && (MAX_CANDIDATES > 128)) ? 1 : 0; - const unsigned multi_warps_2 = ((blockDim.x >= 64) && (MAX_ITOPK > 256)) ? 1 : 0; - - // reset small-hash table. - if ((iter + 1) % small_hash_reset_interval == 0) { - // Depending on the block size and the number of warps used in - // topk_by_bitonic_sort(), determine which warps are used to reset - // the small hash and whether they are performed in overlap with - // topk_by_bitonic_sort(). - _CLK_START(); - unsigned hash_start_tid; - if (blockDim.x == 32) { - hash_start_tid = 0; - } else if (blockDim.x == 64) { - if (multi_warps_1 || multi_warps_2) { - hash_start_tid = 0; - } else { - hash_start_tid = 32; - } - } else { - if (multi_warps_1 || multi_warps_2) { - hash_start_tid = 64; - } else { - hash_start_tid = 32; - } - } - hashmap::init(local_visited_hashmap_ptr, hash_bitlen, hash_start_tid); - _CLK_REC(clk_reset_hash); - } - - // topk with bitonic sort - _CLK_START(); - if (std::is_same::value || - *filter_flag == 0) { - topk_by_bitonic_sort(result_distances_buffer, - result_indices_buffer, - internal_topk, - result_distances_buffer + internal_topk, - result_indices_buffer + internal_topk, - search_width * graph_degree, - topk_ws, - (iter == 0), - multi_warps_1, - multi_warps_2); - __syncthreads(); - } else { - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - internal_topk, - false); - if (threadIdx.x == 0) { *terminate_flag = 0; } - } - _CLK_REC(clk_topk); - } else { - _CLK_START(); - // topk with radix block sort - topk_by_radix_sort{}( - internal_topk, - gridDim.x, - result_buffer_size, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - reinterpret_cast(result_distances_buffer), - result_indices_buffer, - nullptr, - topk_ws, - true, - reinterpret_cast(smem_work_ptr)); - _CLK_REC(clk_topk); - - // reset small-hash table - if ((iter + 1) % small_hash_reset_interval == 0) { - _CLK_START(); - hashmap::init(local_visited_hashmap_ptr, hash_bitlen); - _CLK_REC(clk_reset_hash); - } - } - __syncthreads(); - - if (iter + 1 == max_iteration) { break; } - - // pick up next parents - if (threadIdx.x < 32) { - _CLK_START(); - pickup_next_parents(terminate_flag, - parent_list_buffer, - result_indices_buffer, - internal_topk, - dataset_desc.size, - search_width); - _CLK_REC(clk_pickup_parents); - } - - // restore small-hash table by putting internal-topk indices in it - if ((iter + 1) % small_hash_reset_interval == 0) { - const unsigned first_tid = ((blockDim.x <= 32) ? 0 : 32); - _CLK_START(); - hashmap_restore( - local_visited_hashmap_ptr, hash_bitlen, result_indices_buffer, internal_topk, first_tid); - _CLK_REC(clk_restore_hash); - } - __syncthreads(); - - if (*terminate_flag && iter >= min_iteration) { break; } - - // compute the norms between child nodes and query node - _CLK_START(); - constexpr unsigned max_n_frags = 8; - device::compute_distance_to_child_nodes( - result_indices_buffer + internal_topk, - result_distances_buffer + internal_topk, - query_buffer, - dataset_desc, - knn_graph, - graph_degree, - local_visited_hashmap_ptr, - hash_bitlen, - parent_list_buffer, - result_indices_buffer, - search_width, - metric); - __syncthreads(); - _CLK_REC(clk_compute_distance); - - // Filtering - if constexpr (!std::is_same::value) { - if (threadIdx.x == 0) { *filter_flag = 0; } - __syncthreads(); - - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned p = threadIdx.x; p < search_width; p += blockDim.x) { - if (parent_list_buffer[p] != invalid_index) { - const auto parent_id = result_indices_buffer[parent_list_buffer[p]] & ~index_msb_1_mask; - if (!sample_filter(query_id, parent_id)) { - // If the parent must not be in the resulting top-k list, remove from the parent list - result_distances_buffer[parent_list_buffer[p]] = utils::get_max_value(); - result_indices_buffer[parent_list_buffer[p]] = invalid_index; - *filter_flag = 1; - } - } - } - __syncthreads(); - } - - iter++; - } - - // Post process for filtering - if constexpr (!std::is_same::value) { - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const INDEX_T invalid_index = utils::get_max_value(); - - for (unsigned i = threadIdx.x; i < internal_topk + search_width * graph_degree; - i += blockDim.x) { - const auto node_id = result_indices_buffer[i] & ~index_msb_1_mask; - if (node_id != (invalid_index & ~index_msb_1_mask) && !sample_filter(query_id, node_id)) { - result_distances_buffer[i] = utils::get_max_value(); - result_indices_buffer[i] = invalid_index; - } - } - - __syncthreads(); - topk_by_bitonic_sort_1st( - result_distances_buffer, - result_indices_buffer, - internal_topk + search_width * graph_degree, - top_k, - false); - __syncthreads(); - } - - for (std::uint32_t i = threadIdx.x; i < top_k; i += blockDim.x) { - unsigned j = i + (top_k * query_id); - unsigned ii = i; - if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); } - if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; } - constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - - result_indices_ptr[j] = - result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit - } - if (threadIdx.x == 0 && num_executed_iterations != nullptr) { - num_executed_iterations[query_id] = iter + 1; - } -#ifdef _CLK_BREAKDOWN - if ((threadIdx.x == 0 || threadIdx.x == blockDim.x - 1) && ((query_id * 3) % gridDim.y < 3)) { - printf( - "%s:%d " - "query, %d, thread, %d" - ", init, %lu" - ", 1st_distance, %lu" - ", topk, %lu" - ", reset_hash, %lu" - ", pickup_parents, %lu" - ", restore_hash, %lu" - ", distance, %lu" - "\n", - __FILE__, - __LINE__, - query_id, - threadIdx.x, - clk_init, - clk_compute_1st_distance, - clk_topk, - clk_reset_hash, - clk_pickup_parents, - clk_restore_hash, - clk_compute_distance); - } -#endif -} - -template -struct search_kernel_config { - using kernel_t = decltype(&search_kernel); - - template - static auto choose_search_kernel(unsigned itopk_size) -> kernel_t - { - if (itopk_size <= 64) { - return search_kernel; - } else if (itopk_size <= 128) { - return search_kernel; - } else if (itopk_size <= 256) { - return search_kernel; - } else if (itopk_size <= 512) { - return search_kernel; - } - THROW("No kernel for parametels itopk_size %u, max_candidates %u", itopk_size, MAX_CANDIDATES); - } - - static auto choose_itopk_and_mx_candidates(unsigned itopk_size, - unsigned num_itopk_candidates, - unsigned block_size) -> kernel_t - { - if (num_itopk_candidates <= 64) { - // use bitonic sort based topk - return choose_search_kernel<64, 1>(itopk_size); - } else if (num_itopk_candidates <= 128) { - return choose_search_kernel<128, 1>(itopk_size); - } else if (num_itopk_candidates <= 256) { - return choose_search_kernel<256, 1>(itopk_size); - } else { - // Radix-based topk is used - constexpr unsigned max_candidates = 32; // to avoid build failure - if (itopk_size <= 256) { - return search_kernel; - } else if (itopk_size <= 512) { - return search_kernel; - } - } - THROW("No kernel for parametels itopk_size %u, num_itopk_candidates %u", - itopk_size, - num_itopk_candidates); - } -}; - -template -void select_and_run( - DATASET_DESCRIPTOR_T dataset_desc, - raft::device_matrix_view graph, - typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk] - typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] - const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim] - const uint32_t num_queries, - const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds] - uint32_t* const num_executed_iterations, // [num_queries,] - uint32_t topk, - uint32_t num_itopk_candidates, - uint32_t block_size, // - uint32_t smem_size, - int64_t hash_bitlen, - typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr, - size_t small_hash_bitlen, - size_t small_hash_reset_interval, - uint32_t num_random_samplings, - uint64_t rand_xor_mask, - uint32_t num_seeds, - size_t itopk_size, - size_t search_width, - size_t min_iterations, - size_t max_iterations, - SAMPLE_FILTER_T sample_filter, - raft::distance::DistanceType metric, - cudaStream_t stream) -{ - auto kernel = - search_kernel_config:: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); - RAFT_CUDA_TRY(cudaFuncSetAttribute(kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size + DATASET_DESCRIPTOR_T::smem_buffer_size_in_byte)); - dim3 thread_dims(block_size, 1, 1); - dim3 block_dims(1, num_queries, 1); - RAFT_LOG_DEBUG( - "Launching kernel with %u threads, %u block %u smem", block_size, num_queries, smem_size); - kernel<<>>(topk_indices_ptr, - topk_distances_ptr, - topk, - dataset_desc, - queries_ptr, - graph.data_handle(), - graph.extent(1), - num_random_samplings, - rand_xor_mask, - dev_seed_ptr, - num_seeds, - hashmap_ptr, - itopk_size, - search_width, - min_iterations, - max_iterations, - num_executed_iterations, - hash_bitlen, - small_hash_bitlen, - small_hash_reset_interval, - sample_filter, - metric); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh deleted file mode 100644 index 1d8fd8e30a..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "search_single_cta_kernel-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "search_single_cta_kernel-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh deleted file mode 100644 index 6a6a3cddf4..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh +++ /dev/null @@ -1,91 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "topk_for_cagra/topk_core.cuh" - -namespace raft::neighbors::cagra::detail { -namespace single_cta_search { - -template -struct topk_by_radix_sort_base { - static constexpr std::uint32_t smem_size = MAX_INTERNAL_TOPK * 2 + 2048 + 8; - static constexpr std::uint32_t state_bit_lenght = 0; - static constexpr std::uint32_t vecLen = 2; // TODO -}; -template -struct topk_by_radix_sort : topk_by_radix_sort_base {}; - -template -struct topk_by_radix_sort> - : topk_by_radix_sort_base { - __device__ void operator()(uint32_t topk, - uint32_t batch_size, - uint32_t len_x, - const uint32_t* _x, - const IdxT* _in_vals, - uint32_t* _y, - IdxT* _out_vals, - uint32_t* work, - uint32_t* _hints, - bool sort, - uint32_t* _smem) - { - std::uint8_t* const state = reinterpret_cast(work); - topk_cta_11_core::state_bit_lenght, - topk_by_radix_sort_base::vecLen, - 64, - 32, - IdxT>(topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); - } -}; - -#define TOP_FUNC_PARTIAL_SPECIALIZATION(V) \ - template \ - struct topk_by_radix_sort< \ - MAX_INTERNAL_TOPK, \ - IdxT, \ - std::enable_if_t<((MAX_INTERNAL_TOPK <= V) && (2 * MAX_INTERNAL_TOPK > V))>> \ - : topk_by_radix_sort_base { \ - __device__ void operator()(uint32_t topk, \ - uint32_t batch_size, \ - uint32_t len_x, \ - const uint32_t* _x, \ - const IdxT* _in_vals, \ - uint32_t* _y, \ - IdxT* _out_vals, \ - uint32_t* work, \ - uint32_t* _hints, \ - bool sort, \ - uint32_t* _smem) \ - { \ - assert(blockDim.x >= V / 4); \ - std::uint8_t* state = (std::uint8_t*)work; \ - topk_cta_11_core::state_bit_lenght, \ - topk_by_radix_sort_base::vecLen, \ - V, \ - V / 4, \ - IdxT>( \ - topk, len_x, _x, _in_vals, _y, _out_vals, state, _hints, sort, _smem); \ - } \ - }; -TOP_FUNC_PARTIAL_SPECIALIZATION(128); -TOP_FUNC_PARTIAL_SPECIALIZATION(256); -TOP_FUNC_PARTIAL_SPECIALIZATION(512); -TOP_FUNC_PARTIAL_SPECIALIZATION(1024); - -} // namespace single_cta_search -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h deleted file mode 100644 index ac1a746090..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -#include - -namespace raft::neighbors::cagra::detail { - -// -size_t _cuann_find_topk_bufferSize(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - cudaDataType_t sampleDtype = CUDA_R_32F); - -// -template -void _cuann_find_topk(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const ValT* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - ValT* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort = false, - uint32_t* hint = NULL, - cudaStream_t stream = 0); - -#ifdef __CUDA_ARCH__ -#define CUDA_DEVICE_HOST_FUNC __device__ -#else -#define CUDA_DEVICE_HOST_FUNC -#endif -// -CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 128) -{ - if (size % unit) { size += unit - (size % unit); } - return size; -} -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh deleted file mode 100644 index aedeb1be67..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ /dev/null @@ -1,1040 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once -#include "topk.h" - -#include - -#include -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -// -__device__ inline uint32_t convert(uint32_t x) -{ - if (x & 0x80000000) { - return x ^ 0xffffffff; - } else { - return x ^ 0x80000000; - } -} - -// -__device__ inline uint16_t convert(uint16_t x) -{ - if (x & 0x8000) { - return x ^ 0xffff; - } else { - return x ^ 0x8000; - } -} - -// -struct u32_vector { - uint1 x1; - uint2 x2; - uint4 x4; - ulonglong4 x8; -}; - -// -struct u16_vector { - ushort1 x1; - ushort2 x2; - ushort4 x4; - uint4 x8; -}; - -// -template -__device__ inline void load_u32_vector(struct u32_vector& vec, const uint32_t* x, int i) -{ - if (vecLen == 1) { - vec.x1 = ((uint1*)(x + i))[0]; - } else if (vecLen == 2) { - vec.x2 = ((uint2*)(x + i))[0]; - } else if (vecLen == 4) { - vec.x4 = ((uint4*)(x + i))[0]; - } else if (vecLen == 8) { - vec.x8 = ((ulonglong4*)(x + i))[0]; - } -} - -// -template -__device__ inline void load_u16_vector(struct u16_vector& vec, const uint16_t* x, int i) -{ - if (vecLen == 1) { - vec.x1 = ((ushort1*)(x + i))[0]; - } else if (vecLen == 2) { - vec.x2 = ((ushort2*)(x + i))[0]; - } else if (vecLen == 4) { - vec.x4 = ((ushort4*)(x + i))[0]; - } else if (vecLen == 8) { - vec.x8 = ((uint4*)(x + i))[0]; - } -} - -// -template -__device__ inline uint32_t get_element_from_u32_vector(struct u32_vector& vec, int i) -{ - uint32_t xi; - if (vecLen == 1) { - xi = convert(vec.x1.x); - } else if (vecLen == 2) { - if (i == 0) - xi = convert(vec.x2.x); - else - xi = convert(vec.x2.y); - } else if (vecLen == 4) { - if (i == 0) - xi = convert(vec.x4.x); - else if (i == 1) - xi = convert(vec.x4.y); - else if (i == 2) - xi = convert(vec.x4.z); - else - xi = convert(vec.x4.w); - } else if (vecLen == 8) { - if (i == 0) - xi = convert((uint32_t)(vec.x8.x & 0xffffffff)); - else if (i == 1) - xi = convert((uint32_t)(vec.x8.x >> 32)); - else if (i == 2) - xi = convert((uint32_t)(vec.x8.y & 0xffffffff)); - else if (i == 3) - xi = convert((uint32_t)(vec.x8.y >> 32)); - else if (i == 4) - xi = convert((uint32_t)(vec.x8.z & 0xffffffff)); - else if (i == 5) - xi = convert((uint32_t)(vec.x8.z >> 32)); - else if (i == 6) - xi = convert((uint32_t)(vec.x8.w & 0xffffffff)); - else - xi = convert((uint32_t)(vec.x8.w >> 32)); - } - return xi; -} - -// -template -__device__ inline uint16_t get_element_from_u16_vector(struct u16_vector& vec, int i) -{ - uint16_t xi; - if (vecLen == 1) { - xi = convert(vec.x1.x); - } else if (vecLen == 2) { - if (i == 0) - xi = convert(vec.x2.x); - else - xi = convert(vec.x2.y); - } else if (vecLen == 4) { - if (i == 0) - xi = convert(vec.x4.x); - else if (i == 1) - xi = convert(vec.x4.y); - else if (i == 2) - xi = convert(vec.x4.z); - else - xi = convert(vec.x4.w); - } else if (vecLen == 8) { - if (i == 0) - xi = convert((uint16_t)(vec.x8.x & 0xffff)); - else if (i == 1) - xi = convert((uint16_t)(vec.x8.x >> 16)); - else if (i == 2) - xi = convert((uint16_t)(vec.x8.y & 0xffff)); - else if (i == 3) - xi = convert((uint16_t)(vec.x8.y >> 16)); - else if (i == 4) - xi = convert((uint16_t)(vec.x8.z & 0xffff)); - else if (i == 5) - xi = convert((uint16_t)(vec.x8.z >> 16)); - else if (i == 6) - xi = convert((uint16_t)(vec.x8.w & 0xffff)); - else - xi = convert((uint16_t)(vec.x8.w >> 16)); - } - return xi; -} - -template -__device__ inline void block_scan(const T input, T& output) -{ - switch (blockDim.x) { - case 32: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 64: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 128: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 256: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 512: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - case 1024: { - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - BlockScanT(temp_storage).InclusiveSum(input, output); - } break; - default: break; - } -} - -// -template -__device__ inline void update_histogram(int itr, - uint32_t thread_id, - uint32_t num_threads, - uint32_t hint, - uint32_t threshold, - uint32_t& num_bins, - uint32_t& shift, - const T* x, // [nx,] - uint32_t nx, - uint32_t* hist, // [num_bins] - uint8_t* state, - uint32_t* output, // [topk] - uint32_t* output_count) -{ - if (sizeof(T) == 4) { - // 32-bit (uint32_t) - // itr:0, calculate histogram with 11 bits from bit-21 to bit-31 - // itr:1, calculate histogram with 11 bits from bit-10 to bit-20 - // itr:2, calculate histogram with 10 bits from bit-0 to bit-9 - if (itr == 0) { - shift = 21; - num_bins = 2048; - } else if (itr == 1) { - shift = 10; - num_bins = 2048; - } else { - shift = 0; - num_bins = 1024; - } - } else if (sizeof(T) == 2) { - // 16-bit (uint16_t) - // itr:0, calculate histogram with 8 bits from bit-8 to bit-15 - // itr:1, calculate histogram with 8 bits from bit-0 to bit-7 - if (itr == 0) { - shift = 8; - num_bins = 256; - } else { - shift = 0; - num_bins = 256; - } - } else { - return; - } - if (itr > 0) { - for (int i = threadIdx.x; i < num_bins; i += blockDim.x) { - hist[i] = 0; - } - __syncthreads(); - } - - // (*) Note that 'thread_id' may be different from 'threadIdx.x', - // and 'num_threads' may be different from 'blockDim.x' - int ii = 0; - for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { - uint8_t iState = 0; - if ((stateBitLen == 8) && (itr > 0)) { - iState = state[thread_id + (num_threads * ii)]; - if (iState == (uint8_t)0xff) continue; - } -#pragma unroll - for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - const int iv = i + (num_threads * v); - if (iv >= nx) break; - - struct u32_vector x_u32_vec; - struct u16_vector x_u16_vec; - if (sizeof(T) == 4) { - load_u32_vector(x_u32_vec, (const uint32_t*)x, iv); - } else { - load_u16_vector(x_u16_vec, (const uint16_t*)x, iv); - } -#pragma unroll - for (int u = 0; u < vecLen; u++) { - const int ivu = iv + u; - if (ivu >= nx) break; - - uint8_t mask = (uint8_t)0x1 << (v + u); - if ((stateBitLen == 8) && (iState & mask)) continue; - - uint32_t xi; - if (sizeof(T) == 4) { - xi = get_element_from_u32_vector(x_u32_vec, u); - } else { - xi = get_element_from_u16_vector(x_u16_vec, u); - } - if ((xi > hint) && (itr == 0)) { - if (stateBitLen == 8) { iState |= mask; } - } else if (xi < threshold) { - if (stateBitLen == 8) { - // If the condition is already met, record the index. - output[atomicAdd(output_count, 1)] = ivu; - iState |= mask; - } - } else { - const uint32_t k = (xi - threshold) >> shift; // 0 <= k - if (k >= num_bins) { - if (stateBitLen == 8) { iState |= mask; } - } else if (k + 1 < num_bins) { - // Update histogram - atomicAdd(&(hist[k + 1]), 1); - } - } - } - } - if (stateBitLen == 8) { state[thread_id + (num_threads * ii)] = iState; } - } - __syncthreads(); -} - -template -__device__ inline void select_best_index_for_next_threshold_core(uint32_t& my_index, - uint32_t& my_csum, - const unsigned num_bins, - const uint32_t* const hist, - const uint32_t nx_below_threshold, - const uint32_t max_threshold, - const uint32_t threshold, - const uint32_t shift, - const uint32_t topk) -{ - typedef cub::BlockScan BlockScanT; - __shared__ typename BlockScanT::TempStorage temp_storage; - if (num_bins == 2048) { - constexpr int n_data = 2048 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; - break; - } - } else if (num_bins == 1024) { - constexpr int n_data = 1024 / blockDim_x; - uint32_t csum[n_data]; - for (int i = 0; i < n_data; i++) { - csum[i] = hist[i + (n_data * threadIdx.x)]; - } - BlockScanT(temp_storage).InclusiveSum(csum, csum); - for (int i = n_data - 1; i >= 0; i--) { - if (nx_below_threshold + csum[i] > topk) continue; - const uint32_t index = i + (n_data * threadIdx.x); - if (threshold + (index << shift) > max_threshold) continue; - my_index = index; - my_csum = csum[i]; - break; - } - } -} - -// -__device__ inline void select_best_index_for_next_threshold( - const uint32_t topk, - const uint32_t threshold, - const uint32_t max_threshold, - const uint32_t nx_below_threshold, - const uint32_t num_bins, - const uint32_t shift, - const uint32_t* const hist, // [num_bins] - uint32_t* const best_index, - uint32_t* const best_csum) -{ - // Scan the histogram ('hist') and compute csum. Then, find the largest - // index under the condition that the sum of the number of elements found - // so far ('nx_below_threshold') and the csum value does not exceed the - // topk value. - uint32_t my_index = 0xffffffff; - uint32_t my_csum = 0; - if (num_bins <= blockDim.x) { - uint32_t csum = 0; - if (threadIdx.x < num_bins) { csum = hist[threadIdx.x]; } - detail::block_scan(csum, csum); - if (threadIdx.x < num_bins) { - const uint32_t index = threadIdx.x; - if ((nx_below_threshold + csum <= topk) && (threshold + (index << shift) <= max_threshold)) { - my_index = index; - my_csum = csum; - } - } - } else { - switch (blockDim.x) { - case 64: - select_best_index_for_next_threshold_core<64>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 128: - select_best_index_for_next_threshold_core<128>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 256: - select_best_index_for_next_threshold_core<256>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 512: - select_best_index_for_next_threshold_core<512>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - case 1024: - select_best_index_for_next_threshold_core<1024>(my_index, - my_csum, - num_bins, - hist, - nx_below_threshold, - max_threshold, - threshold, - shift, - topk); - break; - } - } - if (threadIdx.x < num_bins) { - const int laneid = 31 - __clz(__ballot_sync(0xffffffff, (my_index != 0xffffffff))); - if ((threadIdx.x & 0x1f) == laneid) { - const uint32_t old_index = atomicMax(best_index, my_index); - if (old_index < my_index) { atomicMax(best_csum, my_csum); } - } - } - __syncthreads(); -} - -// -template -__device__ inline void output_index_below_threshold(const uint32_t topk, - const uint32_t thread_id, - const uint32_t num_threads, - const uint32_t threshold, - const uint32_t nx_below_threshold, - const T* const x, // [nx,] - const uint32_t nx, - const uint8_t* state, - uint32_t* const output, // [topk] - uint32_t* const output_count, - uint32_t* const output_count_eq) -{ - int ii = 0; - for (int i = thread_id * vecLen; i < nx; i += num_threads * max(vecLen, stateBitLen), ii++) { - uint8_t iState = 0; - if (stateBitLen == 8) { - iState = state[thread_id + (num_threads * ii)]; - if (iState == (uint8_t)0xff) continue; - } -#pragma unroll - for (int v = 0; v < max(vecLen, stateBitLen); v += vecLen) { - const int iv = i + (num_threads * v); - if (iv >= nx) break; - - struct u32_vector u32_vec; - struct u16_vector u16_vec; - if (sizeof(T) == 4) { - load_u32_vector(u32_vec, (const uint32_t*)x, iv); - } else { - load_u16_vector(u16_vec, (const uint16_t*)x, iv); - } -#pragma unroll - for (int u = 0; u < vecLen; u++) { - const int ivu = iv + u; - if (ivu >= nx) break; - - const uint8_t mask = (uint8_t)0x1 << (v + u); - if ((stateBitLen == 8) && (iState & mask)) continue; - - uint32_t xi; - if (sizeof(T) == 4) { - xi = get_element_from_u32_vector(u32_vec, u); - } else { - xi = get_element_from_u16_vector(u16_vec, u); - } - if (xi < threshold) { - output[atomicAdd(output_count, 1)] = ivu; - } else if (xi == threshold) { - // (*) If the value is equal to the threshold, the index - // processed first is recorded. Cause of non-determinism. - if (nx_below_threshold + atomicAdd(output_count_eq, 1) < topk) { - output[atomicAdd(output_count, 1)] = ivu; - } - } - } - } - } -} - -// -template -__device__ inline void swap(T& val1, T& val2) -{ - const T val0 = val1; - val1 = val2; - val2 = val0; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2) -{ - if (key1 > key2) { - swap(key1, key2); - return true; - } - return false; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2) -{ - if (key1 > key2) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -// -template -__device__ inline bool swap_if_needed(K& key1, K& key2, V& val1, V& val2, bool ascending) -{ - if (key1 == key2) { return false; } - if ((key1 > key2) == ascending) { - swap(key1, key2); - swap(val1, val2); - return true; - } - return false; -} - -// -template -__device__ inline T max_value_of(); -template <> -__device__ inline float max_value_of() -{ - return FLT_MAX; -} -template <> -__device__ inline uint32_t max_value_of() -{ - return ~0u; -} - -template -__device__ __host__ inline uint32_t get_state_size(uint32_t len_x) -{ -#ifdef __CUDA_ARCH__ - const uint32_t num_threads = blockDim.x; -#else - const uint32_t num_threads = BLOCK_SIZE; -#endif - if (stateBitLen == 8) { - uint32_t numElements_perThread = (len_x + num_threads - 1) / num_threads; - uint32_t numState_perThread = (numElements_perThread + stateBitLen - 1) / stateBitLen; - return numState_perThread * num_threads; - } - return 0; -} - -// -template -__device__ inline void topk_cta_11_core(uint32_t topk, - uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] - const ValT* _in_vals, // [size_batch, ld_iv,] - uint32_t* _y, // [size_batch, ld_y,] - ValT* _out_vals, // [size_batch, ld_ov,] - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hint, - bool sort, - uint32_t* _smem) -{ - uint32_t* const smem_out_vals = _smem; - uint32_t* const hist = &(_smem[2 * maxTopk]); - uint32_t* const best_index = &(_smem[2 * maxTopk + 2048]); - uint32_t* const best_csum = &(_smem[2 * maxTopk + 2048 + 3]); - - const uint32_t num_threads = blockDim.x; - const uint32_t thread_id = threadIdx.x; - uint32_t nx = len_x; - const uint32_t* const x = _x; - const ValT* in_vals = NULL; - if (_in_vals) { in_vals = _in_vals; } - uint32_t* y = NULL; - if (_y) { y = _y; } - ValT* out_vals = NULL; - if (_out_vals) { out_vals = _out_vals; } - uint8_t* state = _state; - const uint32_t hint = (_hint == NULL ? ~0u : *_hint); - - // Initialize shared memory - for (int i = 2 * maxTopk + thread_id; i < 2 * maxTopk + 2048 + 8; i += num_threads) { - _smem[i] = 0; - } - uint32_t* const output_count = &(_smem[2 * maxTopk + 2048 + 6]); - uint32_t* const output_count_eq = &(_smem[2 * maxTopk + 2048 + 7]); - uint32_t threshold = 0; - uint32_t nx_below_threshold = 0; - __syncthreads(); - - // - // Search for the maximum threshold that satisfies "(x < threshold).sum() <= topk". - // -#pragma unroll - for (int j = 0; j < 3; j += 1) { - uint32_t num_bins; - uint32_t shift; - - update_histogram(j, - thread_id, - num_threads, - hint, - threshold, - num_bins, - shift, - x, - nx, - hist, - state, - smem_out_vals, - output_count); - select_best_index_for_next_threshold(topk, - threshold, - hint, - nx_below_threshold, - num_bins, - shift, - hist, - best_index + j, - best_csum + j); - - threshold += (best_index[j] << shift); - nx_below_threshold += best_csum[j]; - if (nx_below_threshold == topk) break; - } - - if ((_hint != NULL) && (thread_id == 0)) { *_hint = min(threshold, hint); } - - // - // Output index that satisfies "x[i] < threshold". - // - output_index_below_threshold(topk, - thread_id, - num_threads, - threshold, - nx_below_threshold, - x, - nx, - state, - smem_out_vals, - output_count, - output_count_eq); - __syncthreads(); - -#ifdef CUANN_DEBUG - if (thread_id == 0 && output_count[0] < topk) { - RAFT_LOG_DEBUG( - "# i_batch:%d, topk:%d, output_count:%d, nx_below_threshold:%d, threshold:%08x\n", - i_batch, - topk, - output_count[0], - nx_below_threshold, - threshold); - } -#endif - - if (!sort) { - for (int k = thread_id; k < topk; k += blockDim.x) { - const uint32_t i = smem_out_vals[k]; - if (y) { y[k] = x[i]; } - if (out_vals) { - if (in_vals) { - out_vals[k] = in_vals[i]; - } else { - out_vals[k] = i; - } - } - } - return; - } - - constexpr int numTopkPerThread = maxTopk / numSortThreads; - float my_keys[numTopkPerThread]; - ValT my_vals[numTopkPerThread]; - - // Read keys and values to registers - if (thread_id < numSortThreads) { - for (int i = 0; i < numTopkPerThread; i++) { - const int k = thread_id + (numSortThreads * i); - if (k < topk) { - const int j = smem_out_vals[k]; - my_keys[i] = ((float*)x)[j]; - if (in_vals) { - my_vals[i] = in_vals[j]; - } else { - my_vals[i] = j; - } - } else { - my_keys[i] = FLT_MAX; - my_vals[i] = ~static_cast(0); - } - } - } - - uint32_t mask = 1; - - // Sorting by thread - if (thread_id < numSortThreads) { - const bool ascending = ((thread_id & mask) == 0); - if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); - } else { - for (int j = 0; j < numTopkPerThread / 2; j += 1) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); - } -#pragma unroll - for (int i = 1; i < numTopkPerThread - 1; i += 2) { - swap_if_needed( - my_keys[i], my_keys[i + 1], my_vals[i], my_vals[i + 1], ascending); - } - } - } - } - - // Bitonic Sorting - while (mask < numSortThreads) { - uint32_t next_mask = mask << 1; - - for (uint32_t curr_mask = mask; curr_mask > 0; curr_mask >>= 1) { - const bool ascending = ((thread_id & curr_mask) == 0) == ((thread_id & next_mask) == 0); - if (curr_mask >= 32) { - // inter warp - ValT* const smem_vals = reinterpret_cast(_smem); // [maxTopk] - float* const smem_keys = - reinterpret_cast(smem_vals + maxTopk); // [numTopkPerThread, numSortThreads] - __syncthreads(); - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - smem_keys[thread_id + (numSortThreads * i)] = my_keys[i]; - smem_vals[thread_id + (numSortThreads * i)] = my_vals[i]; - } - } - __syncthreads(); - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = smem_keys[(thread_id ^ curr_mask) + (numSortThreads * i)]; - ValT opp_val = smem_vals[(thread_id ^ curr_mask) + (numSortThreads * i)]; - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } - } else { - // intra warp - if (thread_id < numSortThreads) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - float opp_key = __shfl_xor_sync(0xffffffff, my_keys[i], curr_mask); - ValT opp_val = __shfl_xor_sync(0xffffffff, my_vals[i], curr_mask); - swap_if_needed(my_keys[i], opp_key, my_vals[i], opp_val, ascending); - } - } - } - } - - if (thread_id < numSortThreads) { - const bool ascending = ((thread_id & next_mask) == 0); - if (numTopkPerThread == 3) { - swap_if_needed(my_keys[0], my_keys[1], my_vals[0], my_vals[1], ascending); - swap_if_needed(my_keys[0], my_keys[2], my_vals[0], my_vals[2], ascending); - swap_if_needed(my_keys[1], my_keys[2], my_vals[1], my_vals[2], ascending); - } else { -#pragma unroll - for (uint32_t curr_mask = numTopkPerThread / 2; curr_mask > 0; curr_mask >>= 1) { -#pragma unroll - for (int i = 0; i < numTopkPerThread; i++) { - const int j = i ^ curr_mask; - if (i > j) continue; - swap_if_needed(my_keys[i], my_keys[j], my_vals[i], my_vals[j], ascending); - } - } - } - } - mask = next_mask; - } - - // Write sorted keys and values - if (thread_id < numSortThreads) { - for (int i = 0; i < numTopkPerThread; i++) { - const int k = i + (numTopkPerThread * thread_id); - if (k < topk) { - if (y) { y[k] = reinterpret_cast(my_keys)[i]; } - if (out_vals) { out_vals[k] = my_vals[i]; } - } - } - } -} - -namespace { - -// -constexpr std::uint32_t NUM_THREADS = 1024; // DO NOT CHANGE -constexpr std::uint32_t STATE_BIT_LENGTH = 8; // 0: state not used, 8: state used -constexpr std::uint32_t MAX_VEC_LENGTH = 4; // 1, 2, 4 or 8 - -// -// -int _get_vecLen(uint32_t maxSamples, int maxVecLen = MAX_VEC_LENGTH) -{ - int vecLen = min(maxVecLen, (int)MAX_VEC_LENGTH); - while ((maxSamples % vecLen) != 0) { - vecLen /= 2; - } - return vecLen; -} -} // unnamed namespace - -template -__launch_bounds__(1024, 1) RAFT_KERNEL - kern_topk_cta_11(uint32_t topk, - uint32_t size_batch, - uint32_t len_x, - const uint32_t* _x, // [size_batch, ld_x,] - uint32_t ld_x, - const ValT* _in_vals, // [size_batch, ld_iv,] - uint32_t ld_iv, - uint32_t* _y, // [size_batch, ld_y,] - uint32_t ld_y, - ValT* _out_vals, // [size_batch, ld_ov,] - uint32_t ld_ov, - uint8_t* _state, // [size_batch, ...,] - uint32_t* _hints, // [size_batch,] - bool sort) -{ - const uint32_t i_batch = blockIdx.x; - if (i_batch >= size_batch) return; - - constexpr uint32_t smem_len = 2 * maxTopk + 2048 + 8; - static_assert(maxTopk * (1 + utils::size_of() / utils::size_of()) <= smem_len, - "maxTopk * sizeof(ValT) must be smaller or equal to 8192 byte"); - __shared__ uint32_t _smem[smem_len]; - - topk_cta_11_core( - topk, - len_x, - (_x == NULL ? NULL : _x + i_batch * ld_x), - (_in_vals == NULL ? NULL : _in_vals + i_batch * ld_iv), - (_y == NULL ? NULL : _y + i_batch * ld_y), - (_out_vals == NULL ? NULL : _out_vals + i_batch * ld_ov), - (_state == NULL ? NULL : _state + i_batch * get_state_size(len_x)), - (_hints == NULL ? NULL : _hints + i_batch), - sort, - _smem); -} - -// -size_t inline _cuann_find_topk_bufferSize(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - cudaDataType_t sampleDtype) -{ - constexpr int numThreads = NUM_THREADS; - constexpr int stateBitLen = STATE_BIT_LENGTH; - assert(stateBitLen == 0 || stateBitLen == 8); - - size_t workspaceSize = 1; - // state - if (stateBitLen == 8) { - workspaceSize = _cuann_aligned( - sizeof(uint8_t) * get_state_size(numElements) * sizeBatch); - } - - return workspaceSize; -} - -template -inline void _cuann_find_topk(uint32_t topK, - uint32_t sizeBatch, - uint32_t numElements, - const float* inputKeys, // [sizeBatch, ldIK,] - uint32_t ldIK, // (*) ldIK >= numElements - const ValT* inputVals, // [sizeBatch, ldIV,] - uint32_t ldIV, // (*) ldIV >= numElements - float* outputKeys, // [sizeBatch, ldOK,] - uint32_t ldOK, // (*) ldOK >= topK - ValT* outputVals, // [sizeBatch, ldOV,] - uint32_t ldOV, // (*) ldOV >= topK - void* workspace, - bool sort, - uint32_t* hints, - cudaStream_t stream) -{ - assert(ldIK >= numElements); - assert(ldIV >= numElements); - assert(ldOK >= topK); - assert(ldOV >= topK); - - constexpr int numThreads = NUM_THREADS; - constexpr int stateBitLen = STATE_BIT_LENGTH; - assert(stateBitLen == 0 || stateBitLen == 8); - - uint8_t* state = NULL; - if (stateBitLen == 8) { state = (uint8_t*)workspace; } - - dim3 threads(numThreads, 1, 1); - dim3 blocks(sizeBatch, 1, 1); - - void (*cta_kernel)(uint32_t, - uint32_t, - uint32_t, - const uint32_t*, - uint32_t, - const ValT*, - uint32_t, - uint32_t*, - uint32_t, - ValT*, - uint32_t, - uint8_t*, - uint32_t*, - bool) = nullptr; - - // V:vecLen, K:maxTopk, T:numSortThreads -#define SET_KERNEL_VKT(V, K, T, ValT) \ - do { \ - assert(numThreads >= T); \ - assert((K % T) == 0); \ - assert((K / T) <= 4); \ - cta_kernel = kern_topk_cta_11; \ - } while (0) - - // V: vecLen -#define SET_KERNEL_V(V, ValT) \ - do { \ - if (topK <= 32) { \ - SET_KERNEL_VKT(V, 32, 32, ValT); \ - } else if (topK <= 64) { \ - SET_KERNEL_VKT(V, 64, 32, ValT); \ - } else if (topK <= 96) { \ - SET_KERNEL_VKT(V, 96, 32, ValT); \ - } else if (topK <= 128) { \ - SET_KERNEL_VKT(V, 128, 32, ValT); \ - } else if (topK <= 192) { \ - SET_KERNEL_VKT(V, 192, 64, ValT); \ - } else if (topK <= 256) { \ - SET_KERNEL_VKT(V, 256, 64, ValT); \ - } else if (topK <= 384) { \ - SET_KERNEL_VKT(V, 384, 128, ValT); \ - } else if (topK <= 512) { \ - SET_KERNEL_VKT(V, 512, 128, ValT); \ - } else if (topK <= 768) { \ - SET_KERNEL_VKT(V, 768, 256, ValT); \ - } else if (topK <= 1024) { \ - SET_KERNEL_VKT(V, 1024, 256, ValT); \ - } \ - /* else if (topK <= 1536) { SET_KERNEL_VKT(V, 1536, 512); } */ \ - /* else if (topK <= 2048) { SET_KERNEL_VKT(V, 2048, 512); } */ \ - /* else if (topK <= 3072) { SET_KERNEL_VKT(V, 3072, 1024); } */ \ - /* else if (topK <= 4096) { SET_KERNEL_VKT(V, 4096, 1024); } */ \ - else { \ - RAFT_FAIL("topk must be lower than or equal to 1024"); \ - } \ - } while (0) - - int _vecLen = _get_vecLen(ldIK, 2); - if (_vecLen == 2) { - SET_KERNEL_V(2, ValT); - } else if (_vecLen == 1) { - SET_KERNEL_V(1, ValT); - } - - cta_kernel<<>>(topK, - sizeBatch, - numElements, - (const uint32_t*)inputKeys, - ldIK, - inputVals, - ldIV, - (uint32_t*)outputKeys, - ldOK, - outputVals, - ldOV, - state, - hints, - sort); - - return; -} -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp deleted file mode 100644 index ece95a7cb7..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ /dev/null @@ -1,289 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -#include - -#include -#include - -#include -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace utils { -template -inline cudaDataType_t get_cuda_data_type(); -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_32F; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_16F; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_8I; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_8U; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_32U; -} -template <> -inline cudaDataType_t get_cuda_data_type() -{ - return CUDA_R_64U; -} - -template -constexpr unsigned size_of(); -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 1; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 1; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 2; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 4; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 8; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 16; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 32; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 4; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 2; -} -template <> -_RAFT_HOST_DEVICE constexpr unsigned size_of() -{ - return 4; -} - -// max values for data types -template -union fp_conv { - BS_T bs; - FP_T fp; -}; -template -_RAFT_HOST_DEVICE inline T get_max_value(); -template <> -_RAFT_HOST_DEVICE inline float get_max_value() -{ - return FLT_MAX; -}; -template <> -_RAFT_HOST_DEVICE inline half get_max_value() -{ - return fp_conv{.bs = 0x7aff}.fp; -}; -template <> -_RAFT_HOST_DEVICE inline std::uint32_t get_max_value() -{ - return 0xffffffffu; -}; -template <> -_RAFT_HOST_DEVICE inline std::uint64_t get_max_value() -{ - return 0xfffffffffffffffflu; -}; - -template -struct constexpr_max { - static const int value = A; -}; - -template -struct constexpr_max A), bool>> { - static const int value = B; -}; - -template -struct gen_index_msb_1_mask { - static constexpr IdxT value = static_cast(1) << (utils::size_of() * 8 - 1); -}; -} // namespace utils - -/** - * Utility to sync memory from a host_matrix_view to a device_matrix_view - * - * In certain situations (UVM/HMM/ATS) host memory might be directly accessible on the - * device, and no extra allocations need to be performed. This class checks - * if the host_matrix_view is already accessible on the device, and only creates device - * memory and copies over if necessary. In memory limited situations this is preferable - * to having both a host and device copy - * TODO: once the mdbuffer changes here https://github.com/wphicks/raft/blob/fea-mdbuffer - * have been merged, we should remove this class and switch over to using mdbuffer for this - */ -template -class device_matrix_view_from_host { - public: - device_matrix_view_from_host(raft::resources const& res, host_matrix_view host_view) - : host_view_(host_view) - { - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); - device_ptr = reinterpret_cast(attr.devicePointer); - if (device_ptr == NULL) { - // allocate memory and copy over - device_mem_.emplace( - raft::make_device_matrix(res, host_view.extent(0), host_view.extent(1))); - raft::copy(device_mem_->data_handle(), - host_view.data_handle(), - host_view.extent(0) * host_view.extent(1), - resource::get_cuda_stream(res)); - device_ptr = device_mem_->data_handle(); - } - } - - device_matrix_view view() - { - return make_device_matrix_view(device_ptr, host_view_.extent(0), host_view_.extent(1)); - } - - T* data_handle() { return device_ptr; } - - bool allocated_memory() const { return device_mem_.has_value(); } - - private: - std::optional> device_mem_; - host_matrix_view host_view_; - T* device_ptr; -}; - -/** - * Utility to sync memory from a device_matrix_view to a host_matrix_view - * - * In certain situations (UVM/HMM/ATS) device memory might be directly accessible on the - * host, and no extra allocations need to be performed. This class checks - * if the device_matrix_view is already accessible on the host, and only creates host - * memory and copies over if necessary. In memory limited situations this is preferable - * to having both a host and device copy - * TODO: once the mdbuffer changes here https://github.com/wphicks/raft/blob/fea-mdbuffer - * have been merged, we should remove this class and switch over to using mdbuffer for this - */ -template -class host_matrix_view_from_device { - public: - host_matrix_view_from_device(raft::resources const& res, device_matrix_view device_view) - : device_view_(device_view) - { - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, device_view.data_handle())); - host_ptr = reinterpret_cast(attr.hostPointer); - if (host_ptr == NULL) { - // allocate memory and copy over - host_mem_.emplace( - raft::make_host_matrix(device_view.extent(0), device_view.extent(1))); - raft::copy(host_mem_->data_handle(), - device_view.data_handle(), - device_view.extent(0) * device_view.extent(1), - resource::get_cuda_stream(res)); - host_ptr = host_mem_->data_handle(); - } - } - - host_matrix_view view() - { - return make_host_matrix_view(host_ptr, device_view_.extent(0), device_view_.extent(1)); - } - - T* data_handle() { return host_ptr; } - - bool allocated_memory() const { return host_mem_.has_value(); } - - private: - std::optional> host_mem_; - device_matrix_view device_view_; - T* host_ptr; -}; - -// Copy matrix src to dst. pad rows with 0 if necessary to make them 16 byte aligned. -template -void copy_with_padding(raft::resources const& res, - raft::device_matrix& dst, - mdspan, row_major, data_accessor> src, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource()) -{ - size_t padded_dim = round_up_safe(src.extent(1) * sizeof(T), 16) / sizeof(T); - - if ((dst.extent(0) != src.extent(0)) || (static_cast(dst.extent(1)) != padded_dim)) { - // clear existing memory before allocating to prevent OOM errors on large datasets - if (dst.size()) { dst = make_device_matrix(res, 0, 0); } - dst = make_device_mdarray(res, mr, make_extents(src.extent(0), padded_dim)); - } - if (dst.extent(1) == src.extent(1)) { - raft::copy(dst.data_handle(), src.data_handle(), src.size(), resource::get_cuda_stream(res)); - } else { - // copy with padding - RAFT_CUDA_TRY(cudaMemsetAsync( - dst.data_handle(), 0, dst.size() * sizeof(T), resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), - sizeof(T) * dst.extent(1), - src.data_handle(), - sizeof(T) * src.extent(1), - sizeof(T) * src.extent(1), - src.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - } -} -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp deleted file mode 100644 index a6a6ae59a5..0000000000 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ /dev/null @@ -1,192 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../dataset.hpp" - -#include -#include -#include - -#include - -#include -#include - -namespace raft::neighbors::detail { - -using dataset_instance_tag = uint32_t; -constexpr dataset_instance_tag kSerializeEmptyDataset = 1; -constexpr dataset_instance_tag kSerializeStridedDataset = 2; -constexpr dataset_instance_tag kSerializeVPQDataset = 3; - -template -void serialize(const raft::resources& res, std::ostream& os, const empty_dataset& dataset) -{ - serialize_scalar(res, os, dataset.suggested_dim); -} - -template -void serialize(const raft::resources& res, - std::ostream& os, - const strided_dataset& dataset) -{ - auto n_rows = dataset.n_rows(); - auto dim = dataset.dim(); - auto stride = dataset.stride(); - serialize_scalar(res, os, n_rows); - serialize_scalar(res, os, dim); - serialize_scalar(res, os, stride); - // Remove padding before saving the dataset - auto src = dataset.view(); - auto dst = make_host_matrix(n_rows, dim); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(), - sizeof(DataT) * dim, - src.data_handle(), - sizeof(DataT) * stride, - sizeof(DataT) * dim, - n_rows, - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, dst.view()); -} - -template -void serialize(const raft::resources& res, - std::ostream& os, - const vpq_dataset& dataset) -{ - serialize_scalar(res, os, dataset.n_rows()); - serialize_scalar(res, os, dataset.dim()); - serialize_scalar(res, os, dataset.vq_n_centers()); - serialize_scalar(res, os, dataset.pq_n_centers()); - serialize_scalar(res, os, dataset.pq_len()); - serialize_scalar(res, os, dataset.encoded_row_length()); - serialize_mdspan(res, os, make_const_mdspan(dataset.vq_code_book.view())); - serialize_mdspan(res, os, make_const_mdspan(dataset.pq_code_book.view())); - serialize_mdspan(res, os, make_const_mdspan(dataset.data.view())); -} - -template -void serialize(const raft::resources& res, std::ostream& os, const dataset& dataset) -{ - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeEmptyDataset); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeStridedDataset); - serialize_scalar(res, os, CUDA_R_32F); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeStridedDataset); - serialize_scalar(res, os, CUDA_R_16F); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeStridedDataset); - serialize_scalar(res, os, CUDA_R_8I); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeStridedDataset); - serialize_scalar(res, os, CUDA_R_8U); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeVPQDataset); - serialize_scalar(res, os, CUDA_R_32F); - return serialize(res, os, *x); - } - if (auto x = dynamic_cast*>(&dataset); x != nullptr) { - serialize_scalar(res, os, kSerializeVPQDataset); - serialize_scalar(res, os, CUDA_R_16F); - return serialize(res, os, *x); - } - RAFT_FAIL("unsupported dataset type."); -} - -template -auto deserialize_empty(raft::resources const& res, std::istream& is) - -> std::unique_ptr> -{ - auto suggested_dim = deserialize_scalar(res, is); - return std::make_unique>(suggested_dim); -} - -template -auto deserialize_strided(raft::resources const& res, std::istream& is) - -> std::unique_ptr> -{ - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto stride = deserialize_scalar(res, is); - auto host_array = make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, host_array.view()); - return make_strided_dataset(res, host_array, stride); -} - -template -auto deserialize_vpq(raft::resources const& res, std::istream& is) - -> std::unique_ptr> -{ - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto vq_n_centers = deserialize_scalar(res, is); - auto pq_n_centers = deserialize_scalar(res, is); - auto pq_len = deserialize_scalar(res, is); - auto encoded_row_length = deserialize_scalar(res, is); - - auto vq_code_book = make_device_matrix(res, vq_n_centers, dim); - auto pq_code_book = make_device_matrix(res, pq_n_centers, pq_len); - auto data = make_device_matrix(res, n_rows, encoded_row_length); - - deserialize_mdspan(res, is, vq_code_book.view()); - deserialize_mdspan(res, is, pq_code_book.view()); - deserialize_mdspan(res, is, data.view()); - - return std::make_unique>( - std::move(vq_code_book), std::move(pq_code_book), std::move(data)); -} - -template -auto deserialize_dataset(raft::resources const& res, std::istream& is) - -> std::unique_ptr> -{ - switch (deserialize_scalar(res, is)) { - case kSerializeEmptyDataset: return deserialize_empty(res, is); - case kSerializeStridedDataset: - switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: return deserialize_strided(res, is); - case CUDA_R_16F: return deserialize_strided(res, is); - case CUDA_R_8I: return deserialize_strided(res, is); - case CUDA_R_8U: return deserialize_strided(res, is); - default: break; - } - case kSerializeVPQDataset: - switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: return deserialize_vpq(res, is); - case CUDA_R_16F: return deserialize_vpq(res, is); - default: break; - } - default: break; - } - RAFT_FAIL("Failed to deserialize dataset: unsupported combination of instance tags."); -} - -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw.hpp b/cpp/include/raft/neighbors/detail/hnsw.hpp deleted file mode 100644 index bd4e6608de..0000000000 --- a/cpp/include/raft/neighbors/detail/hnsw.hpp +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "hnsw_types.hpp" - -#include -#include - -#include -#include - -#include - -namespace raft::neighbors::hnsw::detail { - -template -void get_search_knn_results(hnswlib::HierarchicalNSW::type> const* idx, - const T* query, - int k, - uint64_t* indices, - float* distances) -{ - auto result = idx->searchKnn(query, k); - assert(result.size() >= static_cast(k)); - - for (int i = k - 1; i >= 0; --i) { - indices[i] = result.top().second; - distances[i] = result.top().first; - result.pop(); - } -} - -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - idx.set_ef(params.ef); - auto const* hnswlib_index = - reinterpret_cast::type> const*>( - idx.get_index()); - - // when num_threads == 0, automatically maximize parallelism - if (params.num_threads) { -#pragma omp parallel for num_threads(params.num_threads) - for (int64_t i = 0; i < queries.extent(0); ++i) { - get_search_knn_results(hnswlib_index, - queries.data_handle() + i * queries.extent(1), - neighbors.extent(1), - neighbors.data_handle() + i * neighbors.extent(1), - distances.data_handle() + i * distances.extent(1)); - } - } else { -#pragma omp parallel for - for (int64_t i = 0; i < queries.extent(0); ++i) { - get_search_knn_results(hnswlib_index, - queries.data_handle() + i * queries.extent(1), - neighbors.extent(1), - neighbors.data_handle() + i * neighbors.extent(1), - distances.data_handle() + i * distances.extent(1)); - } - } -} - -} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp b/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp deleted file mode 100644 index 4c3728f8fc..0000000000 --- a/cpp/include/raft/neighbors/detail/hnsw_serialize.hpp +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../hnsw_types.hpp" -#include "hnsw_types.hpp" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -namespace raft::neighbors::hnsw::detail { - -template -std::unique_ptr> deserialize(raft::resources const& handle, - const std::string& filename, - int dim, - raft::distance::DistanceType metric) -{ - return std::unique_ptr>(new index_impl(filename, dim, metric)); -} - -} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/hnsw_types.hpp b/cpp/include/raft/neighbors/detail/hnsw_types.hpp deleted file mode 100644 index 8d601f59ae..0000000000 --- a/cpp/include/raft/neighbors/detail/hnsw_types.hpp +++ /dev/null @@ -1,108 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../hnsw_types.hpp" - -#include -#include - -#include -#include - -#include -#include -#include - -namespace raft::neighbors::hnsw::detail { - -/** - * @addtogroup cagra_hnswlib Build CAGRA index and search with hnswlib - * @{ - */ - -template -struct hnsw_dist_t { - using type = void; -}; - -template <> -struct hnsw_dist_t { - using type = float; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - -template <> -struct hnsw_dist_t { - using type = int; -}; - -template -struct index_impl : index { - public: - /** - * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index - * - * @param[in] filepath path to the index - * @param[in] dim dimensions of the training dataset - * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") - */ - index_impl(std::string filepath, int dim, raft::distance::DistanceType metric) - : index{dim, metric} - { - if constexpr (std::is_same_v) { - if (metric == raft::distance::L2Expanded) { - space_ = std::make_unique(dim); - } else if (metric == raft::distance::InnerProduct) { - space_ = std::make_unique(dim); - } - } else if constexpr (std::is_same_v or std::is_same_v) { - if (metric == raft::distance::L2Expanded) { - space_ = std::make_unique>(dim); - } - } - - RAFT_EXPECTS(space_ != nullptr, "Unsupported metric type was used"); - - appr_alg_ = std::make_unique::type>>( - space_.get(), filepath); - - appr_alg_->base_layer_only = true; - } - - /** - @brief Get hnswlib index - */ - auto get_index() const -> void const* override { return appr_alg_.get(); } - - /** - @brief Set ef for search - */ - void set_ef(int ef) const override { appr_alg_->ef_ = ef; } - - private: - std::unique_ptr::type>> appr_alg_; - std::unique_ptr::type>> space_; -}; - -/**@}*/ - -} // namespace raft::neighbors::hnsw::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_common.cuh b/cpp/include/raft/neighbors/detail/ivf_common.cuh deleted file mode 100644 index df0319e181..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_common.cuh +++ /dev/null @@ -1,326 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include // matrix::detail::select::warpsort::warp_sort_distributed - -#include - -namespace raft::neighbors::ivf::detail { - -/** - * Default value returned by `search` when the `n_probes` is too small and top-k is too large. - * One may encounter it if the combined size of probed clusters is smaller than the requested - * number of results per query. - */ -template -constexpr static IdxT kOutOfBoundsRecord = std::numeric_limits::max(); - -template -struct dummy_block_sort_t { - using queue_t = - matrix::detail::select::warpsort::warp_sort_distributed; - template - __device__ dummy_block_sort_t(int k, Args...){}; -}; - -/** - * For each query, we calculate a cumulative sum of the cluster sizes that we probe, and return that - * in chunk_indices. Essentially this is a segmented inclusive scan of the cluster sizes. The total - * number of samples per query (sum of the cluster sizes that we probe) is returned in n_samples. - */ -template -__launch_bounds__(BlockDim) RAFT_KERNEL - calc_chunk_indices_kernel(uint32_t n_probes, - const uint32_t* cluster_sizes, // [n_clusters] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t* n_samples // [n_queries] - ) -{ - using block_scan = cub::BlockScan; - __shared__ typename block_scan::TempStorage shm; - - // locate the query data - clusters_to_probe += n_probes * blockIdx.x; - chunk_indices += n_probes * blockIdx.x; - - // block scan - const uint32_t n_probes_aligned = Pow2::roundUp(n_probes); - uint32_t total = 0; - for (uint32_t probe_ix = threadIdx.x; probe_ix < n_probes_aligned; probe_ix += BlockDim) { - auto label = probe_ix < n_probes ? clusters_to_probe[probe_ix] : 0u; - auto chunk = probe_ix < n_probes ? cluster_sizes[label] : 0u; - if (threadIdx.x == 0) { chunk += total; } - block_scan(shm).InclusiveSum(chunk, chunk, total); - __syncthreads(); - if (probe_ix < n_probes) { chunk_indices[probe_ix] = chunk; } - } - // save the total size - if (threadIdx.x == 0) { n_samples[blockIdx.x] = total; } -} - -struct calc_chunk_indices { - public: - struct configured { - void* kernel; - dim3 block_dim; - dim3 grid_dim; - uint32_t n_probes; - - inline void operator()(const uint32_t* cluster_sizes, - const uint32_t* clusters_to_probe, - uint32_t* chunk_indices, - uint32_t* n_samples, - rmm::cuda_stream_view stream) - { - void* args[] = // NOLINT - {&n_probes, &cluster_sizes, &clusters_to_probe, &chunk_indices, &n_samples}; - RAFT_CUDA_TRY(cudaLaunchKernel(kernel, grid_dim, block_dim, args, 0, stream)); - } - }; - - static inline auto configure(uint32_t n_probes, uint32_t n_queries) -> configured - { - return try_block_dim<1024>(n_probes, n_queries); - } - - private: - template - static auto try_block_dim(uint32_t n_probes, uint32_t n_queries) -> configured - { - if constexpr (BlockDim >= WarpSize * 2) { - if (BlockDim >= n_probes * 2) { return try_block_dim<(BlockDim / 2)>(n_probes, n_queries); } - } - return {reinterpret_cast(calc_chunk_indices_kernel), - dim3(BlockDim, 1, 1), - dim3(n_queries, 1, 1), - n_probes}; - } -}; - -/** - * Look up the chunk id corresponding to the sample index. - * - * Each query vector was compared to all the vectors from n_probes clusters, and sample_ix is an - * ordered number of one of such vectors. This function looks up to which chunk it belongs, - * and returns the index within the chunk (which is also an index within a cluster). - * - * @param[inout] sample_ix - * input: the offset of the sample in the batch; - * output: the offset inside the chunk (probe) / selected cluster. - * @param[in] n_probes number of probes - * @param[in] chunk_indices offsets of the chunks within the batch [n_probes] - * @return chunk index (== n_probes when the input index is not in the valid range, - * which can happen if there is not enough data to output in the selected clusters). - */ -__device__ inline auto find_chunk_ix(uint32_t& sample_ix, // NOLINT - uint32_t n_probes, - const uint32_t* chunk_indices) -> uint32_t -{ - uint32_t ix_min = 0; - uint32_t ix_max = n_probes; - do { - uint32_t i = (ix_min + ix_max) / 2; - if (chunk_indices[i] <= sample_ix) { - ix_min = i + 1; - } else { - ix_max = i; - } - } while (ix_min < ix_max); - if (ix_min > 0) { sample_ix -= chunk_indices[ix_min - 1]; } - return ix_min; -} - -template -__launch_bounds__(BlockDim) RAFT_KERNEL - postprocess_neighbors_kernel(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -{ - const uint64_t i = threadIdx.x + BlockDim * uint64_t(blockIdx.x); - const uint32_t query_ix = i / uint64_t(topk); - if (query_ix >= n_queries) { return; } - const uint32_t k = i % uint64_t(topk); - neighbors_in += query_ix * topk; - neighbors_out += query_ix * topk; - chunk_indices += query_ix * n_probes; - clusters_to_probe += query_ix * n_probes; - uint32_t data_ix = neighbors_in[k]; - const uint32_t chunk_ix = find_chunk_ix(data_ix, n_probes, chunk_indices); - const bool valid = chunk_ix < n_probes; - neighbors_out[k] = - valid ? db_indices[clusters_to_probe[chunk_ix]][data_ix] : kOutOfBoundsRecord; -} - -/** - * Transform found sample indices into the corresponding database indices - * (as stored in index.indices()). - * The sample indices are the record indices as they appear in the database view formed by the - * probed clusters / defined by the `chunk_indices`. - * We assume the searched sample sizes (for a single query) fit into `uint32_t`. - */ -template -void postprocess_neighbors(IdxT* neighbors_out, // [n_queries, topk] - const uint32_t* neighbors_in, // [n_queries, topk] - const IdxT* const* db_indices, // [n_clusters][..] - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const uint32_t* chunk_indices, // [n_queries, n_probes] - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk, - rmm::cuda_stream_view stream) -{ - constexpr int kPNThreads = 256; - const int pn_blocks = raft::div_rounding_up_unsafe(n_queries * topk, kPNThreads); - postprocess_neighbors_kernel - <<>>(neighbors_out, - neighbors_in, - db_indices, - clusters_to_probe, - chunk_indices, - n_queries, - n_probes, - topk); -} - -/** - * Post-process the scores depending on the metric type; - * translate the element type if necessary. - */ -template -void postprocess_distances(ScoreOutT* out, // [n_queries, topk] - const ScoreInT* in, // [n_queries, topk] - distance::DistanceType metric, - uint32_t n_queries, - uint32_t topk, - float scaling_factor, - bool account_for_max_close, - rmm::cuda_stream_view stream) -{ - constexpr bool needs_cast = !std::is_same::value; - const bool needs_copy = ((void*)in) != ((void*)out); - size_t len = size_t(n_queries) * size_t(topk); - switch (metric) { - case distance::DistanceType::L2Unexpanded: - case distance::DistanceType::L2Expanded: { - if (scaling_factor != 1.0) { - linalg::unaryOp( - out, - in, - len, - raft::compose_op(raft::mul_const_op{scaling_factor * scaling_factor}, - raft::cast_op{}), - stream); - } else if (needs_cast || needs_copy) { - linalg::unaryOp(out, in, len, raft::cast_op{}, stream); - } - } break; - case distance::DistanceType::L2SqrtUnexpanded: - case distance::DistanceType::L2SqrtExpanded: { - if (scaling_factor != 1.0) { - linalg::unaryOp(out, - in, - len, - raft::compose_op{raft::mul_const_op{scaling_factor}, - raft::sqrt_op{}, - raft::cast_op{}}, - stream); - } else if (needs_cast) { - linalg::unaryOp( - out, in, len, raft::compose_op{raft::sqrt_op{}, raft::cast_op{}}, stream); - } else { - linalg::unaryOp(out, in, len, raft::sqrt_op{}, stream); - } - } break; - case distance::DistanceType::InnerProduct: { - float factor = (account_for_max_close ? -1.0 : 1.0) * scaling_factor * scaling_factor; - if (factor != 1.0) { - linalg::unaryOp( - out, - in, - len, - raft::compose_op(raft::mul_const_op{factor}, raft::cast_op{}), - stream); - } else if (needs_cast || needs_copy) { - linalg::unaryOp(out, in, len, raft::cast_op{}, stream); - } - } break; - default: RAFT_FAIL("Unexpected metric."); - } -} - -/** Update the state of the dependent index members. */ -template -void recompute_internal_state(const raft::resources& res, Index& index) -{ - auto stream = resource::get_cuda_stream(res); - auto tmp_res = resource::get_workspace_resource(res); - rmm::device_uvector sorted_sizes(index.n_lists(), stream, tmp_res); - - // Actualize the list pointers - auto data_ptrs = index.data_ptrs(); - auto inds_ptrs = index.inds_ptrs(); - for (uint32_t label = 0; label < index.n_lists(); label++) { - auto& list = index.lists()[label]; - const auto data_ptr = list ? list->data.data_handle() : nullptr; - const auto inds_ptr = list ? list->indices.data_handle() : nullptr; - copy(&data_ptrs(label), &data_ptr, 1, stream); - copy(&inds_ptrs(label), &inds_ptr, 1, stream); - } - - // Sort the cluster sizes in the descending order. - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortKeysDescending(nullptr, - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, tmp_res); - cub::DeviceRadixSort::SortKeysDescending(cub_workspace.data(), - cub_workspace_size, - index.list_sizes().data_handle(), - sorted_sizes.data(), - index.n_lists(), - begin_bit, - end_bit, - stream); - // copy the results to CPU - std::vector sorted_sizes_host(index.n_lists()); - copy(sorted_sizes_host.data(), sorted_sizes.data(), index.n_lists(), stream); - resource::sync_stream(res); - - // accumulate the sorted cluster sizes - auto accum_sorted_sizes = index.accum_sorted_sizes(); - accum_sorted_sizes(0) = 0; - for (uint32_t label = 0; label < sorted_sizes_host.size(); label++) { - accum_sorted_sizes(label + 1) = accum_sorted_sizes(label) + sorted_sizes_host[label]; - } -} - -} // namespace raft::neighbors::ivf::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh deleted file mode 100644 index 55184cc615..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_build.cuh +++ /dev/null @@ -1,538 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -namespace raft::neighbors::ivf_flat::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -template -auto clone(const raft::resources& res, const index& source) -> index -{ - auto stream = resource::get_cuda_stream(res); - - // Allocate the new index - index target(res, - source.metric(), - source.n_lists(), - source.adaptive_centers(), - source.conservative_memory_allocation(), - source.dim()); - - // Copy the independent parts - copy(target.list_sizes().data_handle(), - source.list_sizes().data_handle(), - source.list_sizes().size(), - stream); - copy(target.centers().data_handle(), - source.centers().data_handle(), - source.centers().size(), - stream); - if (source.center_norms().has_value()) { - target.allocate_center_norms(res); - copy(target.center_norms()->data_handle(), - source.center_norms()->data_handle(), - source.center_norms()->size(), - stream); - } - // Copy shared pointers - target.lists() = source.lists(); - - // Make sure the device pointers point to the new lists - ivf::detail::recompute_internal_state(res, target); - - return target; -} - -/** - * @brief Record the dataset into the index, one source row at a time. - * - * The index consists of the dataset rows, grouped by their labels (into clusters/lists). - * Within each cluster (list), the data is grouped into blocks of `WarpSize` interleaved - * vectors. Note, the total index length is slightly larger than the dataset length, because - * each cluster is padded by `WarpSize` elements - * - * CUDA launch grid: - * X dimension must cover the dataset (n_rows), YZ are not used; - * there are no dependencies between threads, hence no constraints on the block size. - * - * @tparam T element type. - * @tparam IdxT type of the indices in the source source_vecs - * @tparam LabelT label type - * @tparam gather_src if false, then we build the index from vectors source_vecs[i,:], otherwise - * we use source_vecs[source_ixs[i],:]. In both cases i=0..n_rows-1. - * - * @param[in] labels device pointer to the cluster ids for each row [n_rows] - * @param[in] source_vecs device pointer to the input data [n_rows, dim] - * @param[in] source_ixs device pointer to the input indices [n_rows] - * @param[out] list_data_ptrs device pointer to the index data of size [n_lists][index_size, dim] - * @param[out] list_index_ptrs device pointer to the source ids corr. to the output [n_lists] - * [index_size] - * @param[out] list_sizes_ptr device pointer to the cluster sizes [n_lists]; - * it's used as an atomic counter, and must be initialized with zeros. - * @param n_rows source length - * @param dim the dimensionality of the data - * @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`. - * - */ -template -RAFT_KERNEL build_index_kernel(const LabelT* labels, - const T* source_vecs, - const IdxT* source_ixs, - T** list_data_ptrs, - IdxT** list_index_ptrs, - uint32_t* list_sizes_ptr, - IdxT n_rows, - uint32_t dim, - uint32_t veclen, - IdxT batch_offset = 0) -{ - const IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; - if (i >= n_rows) { return; } - - auto list_id = labels[i]; - auto inlist_id = atomicAdd(list_sizes_ptr + list_id, 1); - auto* list_index = list_index_ptrs[list_id]; - auto* list_data = list_data_ptrs[list_id]; - - // Record the source vector id in the index - list_index[inlist_id] = source_ixs == nullptr ? i + batch_offset : source_ixs[i]; - - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = Pow2; - auto group_offset = interleaved_group::roundDown(inlist_id); - auto ingroup_id = interleaved_group::mod(inlist_id) * veclen; - - // Point to the location of the interleaved group of vectors - list_data += group_offset * dim; - - // Point to the source vector - if constexpr (gather_src) { - source_vecs += source_ixs[i] * dim; - } else { - source_vecs += i * dim; - } - // Interleave dimensions of the source vector while recording it. - // NB: such `veclen` is selected, that `dim % veclen == 0` - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - list_data[l * kIndexGroupSize + ingroup_id + j] = source_vecs[l + j]; - } - } -} - -/** See raft::neighbors::ivf_flat::extend docs */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - using LabelT = uint32_t; - RAFT_EXPECTS(index != nullptr, "index cannot be empty."); - - auto stream = resource::get_cuda_stream(handle); - auto n_lists = index->n_lists(); - auto dim = index->dim(); - list_spec list_device_spec{index->dim(), - index->conservative_memory_allocation()}; - common::nvtx::range fun_scope( - "ivf_flat::extend(%zu, %u)", size_t(n_rows), dim); - - RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, - "You must pass data indices when the index is non-empty."); - - auto new_labels = raft::make_device_vector(handle, n_rows); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - auto orig_centroids_view = - raft::make_device_matrix_view(index->centers().data_handle(), n_lists, dim); - // Calculate the batch size for the input data if it's not accessible directly from the device - constexpr size_t kReasonableMaxBatchSize = 65536; - size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - - // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches(new_vectors, - n_rows, - index->dim(), - max_batch_size, - stream, - resource::get_workspace_resource(handle)); - - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_labels.data_handle() + batch.offset(), batch.size()); - raft::cluster::kmeans_balanced::predict(handle, - kmeans_params, - batch_data_view, - orig_centroids_view, - batch_labels_view, - utils::mapping{}); - } - - auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); - copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); - - // Calculate the centers and sizes on the new data, starting from the original values - if (index->adaptive_centers()) { - auto centroids_view = raft::make_device_matrix_view( - index->centers().data_handle(), index->centers().extent(0), index->centers().extent(1)); - auto list_sizes_view = - raft::make_device_vector_view, IdxT>( - list_sizes_ptr, n_lists); - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_labels.data_handle() + batch.offset(), batch.size()); - raft::cluster::kmeans_balanced::helpers::calc_centers_and_sizes(handle, - batch_data_view, - batch_labels_view, - centroids_view, - list_sizes_view, - false, - utils::mapping{}); - } - } else { - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(list_sizes_ptr), - IdxT(n_lists), - new_labels.data_handle(), - n_rows, - 1, - stream); - raft::linalg::add( - list_sizes_ptr, list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - } - - // Calculate and allocate new list data - std::vector new_list_sizes(n_lists); - std::vector old_list_sizes(n_lists); - { - copy(old_list_sizes.data(), old_list_sizes_dev.data_handle(), n_lists, stream); - copy(new_list_sizes.data(), list_sizes_ptr, n_lists, stream); - resource::sync_stream(handle); - auto& lists = index->lists(); - for (uint32_t label = 0; label < n_lists; label++) { - ivf::resize_list(handle, - lists[label], - list_device_spec, - new_list_sizes[label], - Pow2::roundUp(old_list_sizes[label])); - } - } - // Update the pointers and the sizes - ivf::detail::recompute_internal_state(handle, *index); - // Copy the old sizes, so we can start from the current state of the index; - // we'll rebuild the `list_sizes_ptr` in the following kernel, using it as an atomic counter. - raft::copy(list_sizes_ptr, old_list_sizes_dev.data_handle(), n_lists, stream); - - utils::batch_load_iterator vec_indices( - new_indices, n_rows, 1, max_batch_size, stream, resource::get_workspace_resource(handle)); - utils::batch_load_iterator idx_batch = vec_indices.begin(); - size_t next_report_offset = 0; - size_t d_report_offset = n_rows * 5 / 100; - for (const auto& batch : vec_batches) { - auto batch_data_view = - raft::make_device_matrix_view(batch.data(), batch.size(), index->dim()); - // Kernel to insert the new vectors - const dim3 block_dim(256); - const dim3 grid_dim(raft::ceildiv(batch.size(), block_dim.x)); - build_index_kernel - <<>>(new_labels.data_handle() + batch.offset(), - batch_data_view.data_handle(), - idx_batch->data(), - index->data_ptrs().data_handle(), - index->inds_ptrs().data_handle(), - list_sizes_ptr, - batch.size(), - dim, - index->veclen(), - batch.offset()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - if (batch.offset() > next_report_offset) { - float progress = batch.offset() * 100.0f / n_rows; - RAFT_LOG_DEBUG("ivf_flat::extend added vectors %zu, %6.1f%% complete", - static_cast(batch.offset()), - progress); - next_report_offset += d_report_offset; - } - ++idx_batch; - } - // Precompute the centers vector norms for L2Expanded distance - if (!index->center_norms().has_value()) { - index->allocate_center_norms(handle); - if (index->center_norms().has_value()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); - } - } else if (index->center_norms().has_value() && index->adaptive_centers()) { - raft::linalg::rowNorm(index->center_norms()->data_handle(), - index->centers().data_handle(), - dim, - n_lists, - raft::linalg::L2Norm, - true, - stream); - RAFT_LOG_TRACE_VEC(index->center_norms()->data_handle(), std::min(dim, 20)); - } -} - -/** See raft::neighbors::ivf_flat::extend docs */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - auto ext_index = clone(handle, orig_index); - detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); - return ext_index; -} - -/** See raft::neighbors::ivf_flat::build docs */ -template -inline auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - auto stream = resource::get_cuda_stream(handle); - common::nvtx::range fun_scope( - "ivf_flat::build(%zu, %u)", size_t(n_rows), dim); - static_assert(std::is_same_v || std::is_same_v || std::is_same_v, - "unsupported data type"); - RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - - index index(handle, params, dim); - utils::memzero( - index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); - - // Train the kmeans clustering - { - auto trainset_ratio = std::max( - 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); - auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); - // TODO: a proper sampling - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - auto trainset_const_view = - raft::make_device_matrix_view(trainset.data(), n_rows_train, index.dim()); - auto centers_view = raft::make_device_matrix_view( - index.centers().data_handle(), index.n_lists(), index.dim()); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = index.metric(); - raft::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); - } - - // add the data if necessary - if (params.add_data_on_build) { - detail::extend(handle, &index, dataset, nullptr, n_rows); - } - return index; -} - -/** - * Build an index that can be used in refinement operation. - * - * See raft::neighbors::refine for details on the refinement operation. - * - * The returned index cannot be used for a regular ivf_flat::search. The index misses information - * about coarse clusters. Instead, the neighbor candidates are assumed to form clusters, one for - * each query. The candidate vectors are gathered into the index dataset, that can be later used - * in ivfflat_interleaved_scan. - * - * @param[in] handle the raft handle - * @param[inout] refinement_index - * @param[in] dataset device pointer to dataset vectors, size [n_rows, dim]. Note that n_rows is - * not known to this function, but each candidate_idx has to be smaller than n_rows. - * @param[in] candidate_idx device pointer to neighbor candidates, size [n_queries, n_candidates] - * @param[in] n_candidates of neighbor_candidates - */ -template -inline void fill_refinement_index(raft::resources const& handle, - index* refinement_index, - const T* dataset, - const IdxT* candidate_idx, - IdxT n_queries, - uint32_t n_candidates) -{ - using LabelT = uint32_t; - - auto stream = resource::get_cuda_stream(handle); - uint32_t n_lists = n_queries; - common::nvtx::range fun_scope( - "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); - - rmm::device_uvector new_labels(n_queries * n_candidates, stream); - auto new_labels_view = - raft::make_device_vector_view(new_labels.data(), n_queries * n_candidates); - linalg::map_offset( - handle, - new_labels_view, - raft::compose_op(raft::cast_op(), raft::div_const_op(n_candidates))); - - auto list_sizes_ptr = refinement_index->list_sizes().data_handle(); - // We do not fill centers and center norms, since we will not run coarse search. - - // Allocate new memory - auto& lists = refinement_index->lists(); - list_spec list_device_spec{refinement_index->dim(), false}; - for (uint32_t label = 0; label < n_lists; label++) { - ivf::resize_list(handle, lists[label], list_device_spec, n_candidates, uint32_t(0)); - } - // Update the pointers and the sizes - ivf::detail::recompute_internal_state(handle, *refinement_index); - - RAFT_CUDA_TRY(cudaMemsetAsync(list_sizes_ptr, 0, n_lists * sizeof(uint32_t), stream)); - - const dim3 block_dim(256); - const dim3 grid_dim(raft::ceildiv(n_queries * n_candidates, block_dim.x)); - build_index_kernel - <<>>(new_labels.data(), - dataset, - candidate_idx, - refinement_index->data_ptrs().data_handle(), - refinement_index->inds_ptrs().data_handle(), - list_sizes_ptr, - n_queries * n_candidates, - refinement_index->dim(), - refinement_index->veclen()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -RAFT_KERNEL pack_interleaved_list_kernel(const T* codes, - T* list_data, - uint32_t n_rows, - uint32_t dim, - uint32_t veclen, - std::variant offset_or_indices) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const uint32_t dst_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + tid - : std::get(offset_or_indices)[tid]; - if (tid < n_rows) { codepacker::pack_1(codes + tid * dim, list_data, dim, veclen, dst_ix); } -} - -template -RAFT_KERNEL unpack_interleaved_list_kernel( - const T* list_data, - T* codes, - uint32_t n_rows, - uint32_t dim, - uint32_t veclen, - std::variant offset_or_indices) -{ - uint32_t tid = blockIdx.x * blockDim.x + threadIdx.x; - const uint32_t src_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + tid - : std::get(offset_or_indices)[tid]; - if (tid < n_rows) { codepacker::unpack_1(list_data, codes + tid * dim, dim, veclen, src_ix); } -} - -template -void pack_list_data( - raft::resources const& res, - device_matrix_view codes, - uint32_t veclen, - std::variant offset_or_indices, - device_mdspan::list_extents, row_major> list_data) -{ - uint32_t n_rows = codes.extent(0); - uint32_t dim = codes.extent(1); - if (n_rows == 0 || dim == 0) return; - static constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto stream = resource::get_cuda_stream(res); - pack_interleaved_list_kernel<<>>( - codes.data_handle(), list_data.data_handle(), n_rows, dim, veclen, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void unpack_list_data( - raft::resources const& res, - device_mdspan::list_extents, row_major> list_data, - uint32_t veclen, - std::variant offset_or_indices, - device_matrix_view codes) -{ - uint32_t n_rows = codes.extent(0); - uint32_t dim = codes.extent(1); - if (n_rows == 0 || dim == 0) return; - static constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto stream = resource::get_cuda_stream(res); - unpack_interleaved_list_kernel<<>>( - list_data.data_handle(), codes.data_handle(), n_rows, dim, veclen, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh deleted file mode 100644 index 140a9f17c8..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include // rmm:cuda_stream_view - -#include - -#include // uintX_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_flat::detail { - -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool; - -template -void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const uint32_t queries_offset, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const uint32_t max_samples, - const uint32_t* chunk_indices, - const bool select_min, - IvfSampleFilterT sample_filter, - uint32_t* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( \ - T, AccT, IdxT, IvfSampleFilterT) \ - extern template void \ - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const uint32_t queries_offset, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const uint32_t max_samples, \ - const uint32_t* chunk_indices, \ - const bool select_min, \ - IvfSampleFilterT sample_filter, \ - uint32_t* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream) - -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - float, float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - half, half, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - int8_t, int32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan( - uint8_t, uint32_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh deleted file mode 100644 index 9cd8b70148..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh +++ /dev/null @@ -1,1188 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // RAFT_LOG_TRACE -#include -#include -#include -#include -#include -#include -#include -#include // RAFT_CUDA_TRY -#include -#include -#include -#include - -#include - -namespace raft::neighbors::ivf_flat::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -constexpr int kThreadsPerBlock = 128; - -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k) -> bool -{ - return k <= matrix::detail::select::warpsort::kMaxCapacity; -} - -/** - * @brief Copy `n` elements per block from one place to another. - * - * @param[out] out target pointer (unique per block) - * @param[in] in source pointer - * @param n number of elements to copy - */ -template -__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) -{ - constexpr int VecElems = VecBytes / sizeof(T); // NOLINT - using align_bytes = Pow2<(size_t)VecBytes>; - if constexpr (VecElems > 1) { - using align_elems = Pow2; - if (!align_bytes::areSameAlignOffsets(out, in)) { - return copy_vectorized<(VecBytes >> 1), T>(out, in, n); - } - { // process unaligned head - uint32_t head = align_bytes::roundUp(in) - in; - if (head > 0) { - copy_vectorized(out, in, head); - n -= head; - in += head; - out += head; - } - } - { // process main part vectorized - using vec_t = typename IOType::Type; - copy_vectorized( - reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); - } - { // process unaligned tail - uint32_t tail = align_elems::mod(n); - if (tail > 0) { - n -= tail; - copy_vectorized(out + n, in + n, tail); - } - } - } - if constexpr (VecElems <= 1) { - for (int i = threadIdx.x; i < n; i += blockDim.x) { - out[i] = in[i]; - } - } -} - -/** - * @brief Load a part of a vector from the index and from query, compute the (part of the) distance - * between them, and aggregate it using the provided Lambda; one structure per thread, per query, - * and per index item. - * - * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) - * @tparam Lambda computing the part of the distance for one dimension and aggregating it: - * void (AccT& acc, AccT x, AccT y) - * @tparam Veclen size of the vectorized load - * @tparam T type of the data in the query and the index - * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit - * values) - */ -template -struct loadAndComputeDist { - Lambda compute_dist; - AccT& dist; - - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in shared memory. - * Every thread here processes exactly kUnroll * Veclen elements independently of others. - */ - template - __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, - const T* query_shared, - IdxT loadIndex, - IdxT shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); - T queryRegs[Veclen]; - lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in the global memory and is different for every - * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into - * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). - */ - template - __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, - const T* query, - IdxT baseLoadIndex, - const int lane_id) - { - T queryReg = query[baseLoadIndex + lane_id]; - constexpr int stride = kUnroll * Veclen; - constexpr int totalIter = WarpSize / stride; - constexpr int gmemStride = stride * kIndexGroupSize; -#pragma unroll - for (int i = 0; i < totalIter; ++i, data += gmemStride) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); - const int d = (i * kUnroll + j) * Veclen; -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. - */ - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) - { - const int loadDim = dimBlocks + lane_id; - T queryReg = loadDim < dim ? query[loadDim] : T{0}; - const int loadDataIdx = lane_id * Veclen; - for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { - T enc[Veclen]; - ldg(enc, data + loadDataIdx); -#pragma unroll - for (int k = 0; k < Veclen; k++) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); - } - } - } -}; - -// This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - loadIndex = loadIndex * veclen_int; -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); - uint32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * uint8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen_int = uint8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { - uint32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); - compute_dist(dist, q, enc[k]); - } - } - } -}; - -// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while -// using above common template of int2/int4 -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 4; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = query_shared[shmemIndex + j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = query[baseLoadIndex + lane_id]; - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 1; - int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; - uint32_t q = shfl(queryReg, d, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -// This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); - int32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - - int32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * int8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - int32_t q = shfl(queryReg, d + k, WarpSize); - compute_dist(dist, q, encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen_int = int8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { - int32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; - compute_dist(dist, q, enc[k]); - } - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - int32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; - int32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - int32_t queryReg = query[baseLoadIndex + lane_id]; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist( - dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); - } - } - } - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 1; - const int loadDim = dimBlocks + lane_id; - int32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); - } - } -}; - -// switch to dummy blocksort when Capacity is 0 this explicit dummy is chosen -// to support access to warpsort constants like ::queue_t::kDummy -template -struct flat_block_sort { - using type = matrix::detail::select::warpsort:: - block_sort; -}; - -template -struct flat_block_sort<0, Ascending, T, IdxT> - : ivf::detail::dummy_block_sort_t { - using type = ivf::detail::dummy_block_sort_t; -}; - -template -using block_sort_t = typename flat_block_sort::type; - -/** - * Scan clusters for nearest neighbors of the query vectors. - * See `ivfflat_interleaved_scan` for more information. - * - * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. - * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is - * calculated, and the top-k nearest neighbors are selected. - * - * @param compute_dist distance function - * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a - * block; this number must be a multiple of `WarpSize * Veclen`. - * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] - * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] - * @param[in] list_indices index.indices - * @param[in] list_data index.data - * @param[in] list_sizes index.list_sizes - * @param[in] list_offsets index.list_offsets - * @param n_probes - * @param k - * @param dim - * @param sample_filter - * @param[out] neighbors - * @param[out] distances - */ -template -RAFT_KERNEL __launch_bounds__(kThreadsPerBlock) - interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, - const uint32_t query_smem_elems, - const T* query, - const uint32_t* coarse_index, - const T* const* list_data_ptrs, - const uint32_t* list_sizes, - const uint32_t queries_offset, - const uint32_t n_probes, - const uint32_t k, - const uint32_t max_samples, - const uint32_t* chunk_indices, - const uint32_t dim, - IvfSampleFilterT sample_filter, - uint32_t* neighbors, - float* distances) -{ - extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; - constexpr bool kManageLocalTopK = Capacity > 0; - // Using shared memory for the (part of the) query; - // This allows to save on global memory bandwidth when reading index and query - // data at the same time. - // Its size is `query_smem_elems`. - T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); - // Make the query input and output point to this block's shared query - { - const int query_id = blockIdx.y; - query += query_id * dim; - if constexpr (kManageLocalTopK) { - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; - } else { - distances += query_id * uint64_t(max_samples); - } - chunk_indices += (n_probes * query_id); - coarse_index += query_id * n_probes; - } - - // Copy a part of the query into shared memory for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); - __syncthreads(); - - using local_topk_t = block_sort_t; - local_topk_t queue(k); - { - using align_warp = Pow2; - const int lane_id = align_warp::mod(threadIdx.x); - - // How many full warps needed to compute the distance (without remainder) - const uint32_t full_warps_along_dim = align_warp::roundDown(dim); - - const uint32_t shm_assisted_dim = - (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; - - // Every CUDA block scans one cluster at a time. - for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - - // The number of vectors in each cluster(list); [nlist] - const uint32_t list_length = list_sizes[list_id]; - - // The number of interleaved groups to be processed - const uint32_t num_groups = - align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 - - uint32_t sample_offset = 0; - if (probe_id > 0) { sample_offset = chunk_indices[probe_id - 1]; } - assert(list_length == chunk_indices[probe_id] - sample_offset); - assert(sample_offset + list_length <= max_samples); - - constexpr int kUnroll = WarpSize / Veclen; - constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; - // Every warp reads WarpSize vectors and computes the distances to them. - // Then, the distances and corresponding ids are distributed among the threads, - // and each thread adds one (id, dist) pair to the filtering queue. - for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; - group_id += kNumWarps) { - AccT dist = 0; - // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; - - // This is the vector a given lane/thread handles - const uint32_t vec_id = group_id * WarpSize + lane_id; - const bool valid = - vec_id < list_length && sample_filter(queries_offset + blockIdx.y, list_id, vec_id); - - // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = 0; pos < shm_assisted_dim; - pos += WarpSize, data += kIndexGroupSize * WarpSize) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); - for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - } - - // Enqueue one element per thread - const float val = valid ? static_cast(dist) : local_topk_t::queue_t::kDummy; - if constexpr (kManageLocalTopK) { - queue.add(val, sample_offset + vec_id); - } else { - if (vec_id < list_length) distances[sample_offset + vec_id] = val; - } - } - - // fill up unused slots for current query - if constexpr (!kManageLocalTopK) { - if (probe_id + 1 == n_probes) { - for (uint32_t i = threadIdx.x + sample_offset + list_length; i < max_samples; - i += blockDim.x) { - distances[i] = local_topk_t::queue_t::kDummy; - } - } - } - } - } - - // finalize and store selected neighbours - if constexpr (kManageLocalTopK) { - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); - } -} - -/** - * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size - */ -template -uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) -{ - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int num_sms; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); - - size_t min_grid_size = num_sms * num_blocks_per_sm; - size_t min_grid_x = ceildiv(min_grid_size, numQueries); - return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); -} - -template -void launch_kernel(Lambda lambda, - PostLambda post_process, - const index& index, - const T* queries, - const uint32_t* coarse_index, - const uint32_t num_queries, - const uint32_t queries_offset, - const uint32_t n_probes, - const uint32_t k, - const uint32_t max_samples, - const uint32_t* chunk_indices, - IvfSampleFilterT sample_filter, - uint32_t* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - RAFT_EXPECTS(Veclen == index.veclen(), - "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = interleaved_scan_kernel; - const int max_query_smem = 16384; - int query_smem_elems = - std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - - if constexpr (Capacity > 0) { - constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); - } - - // power-of-two less than cuda limit (for better addr alignment) - constexpr uint32_t kMaxGridY = 32768; - - if (grid_dim_x == 0) { - grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); - return; - } - - for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { - uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); - dim3 grid_dim(grid_dim_x, grid_dim_y, 1); - dim3 block_dim(kThreadsPerBlock); - RAFT_LOG_TRACE( - "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " - "smem_size = %d", - grid_dim.x, - grid_dim.y, - block_dim.x, - n_probes, - smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - queries_offset + query_offset, - n_probes, - k, - max_samples, - chunk_indices, - index.dim(), - sample_filter, - neighbors, - distances); - queries += grid_dim_y * index.dim(); - if constexpr (Capacity > 0) { - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; - } else { - distances += grid_dim_y * max_samples; - } - chunk_indices += grid_dim_y * n_probes; - coarse_index += grid_dim_y * n_probes; - } -} - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - const auto diff = x - y; - acc += diff * diff; - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) - { - if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = dp4a(diff, diff, acc); - } else { - const auto diff = __usad(x, y, 0u); - acc += diff * diff; - } - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of dp4a is used, because the difference - // between two int8 numbers can be greater than 127 and therefore represented as a negative - // number in int8. Casting from int8 to int32 would yield incorrect results, while casting - // from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - -template -struct inner_prod_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { - acc = dp4a(x, y, acc); - } else { - acc += x * y; - } - } -}; - -/** Select the distance computation function and forward the rest of the arguments. */ -template -void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) -{ - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); - } -} - -/** - * Lift the `capacity` and `veclen` parameters to the template level, - * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. - */ -template (1, 16 / sizeof(T))> -struct select_interleaved_scan_kernel { - /** - * Recursively reduce the `Capacity` and `Veclen` parameters until they match the - * corresponding runtime arguments. - * By default, this recursive process starts with maximum possible values of the - * two parameters and ends with both values equal to 1. - */ - template - static inline void run(int k_max, int veclen, bool select_min, Args&&... args) - { - if constexpr (Capacity > 0) { - if (k_max == 0 || k_max > Capacity) { - return select_interleaved_scan_kernel::run( - k_max, veclen, select_min, std::forward(args)...); - } - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { - return select_interleaved_scan_kernel::run(k_max, - veclen, - select_min, - std::forward(args)...); - } - } - if constexpr (Veclen > 1) { - if (veclen % Veclen != 0) { - return select_interleaved_scan_kernel::run( - k_max, 1, select_min, std::forward(args)...); - } - } - // NB: this is the limitation of the warpsort structures that use a huge number of - // registers (used in the main kernel here). - RAFT_EXPECTS(Capacity == 0 || k_max == Capacity, - "Capacity must be either 0 or a power-of-two not bigger than the maximum " - "allowed size matrix::detail::select::warpsort::kMaxCapacity (%d).", - matrix::detail::select::warpsort::kMaxCapacity); - RAFT_EXPECTS( - veclen == Veclen, - "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); - if (select_min) { - launch_with_fixed_consts( - std::forward(args)...); - } else { - launch_with_fixed_consts( - std::forward(args)...); - } - } -}; - -/** - * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. - * - * @tparam T value type - * @tparam AccT accumulated type - * @tparam IdxT type of the indices - * - * @param index previously built ivf-flat index - * @param[in] queries device pointer to the query vectors [batch_size, dim] - * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] - * @param n_queries batch size - * @param[in] queries_offset - * An offset of the current query batch. It is used for feeding sample_filter with the - * correct query index. - * @param metric type of the measured distance - * @param n_probes number of nearest clusters to query - * @param k number of nearest neighbors. - * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. - * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given - * metric. - * @param[out] neighbors device pointer to the result indices for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[out] distances device pointer to the result distances for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; - * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) - * @param stream - * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to - * provide a green light for every sample. - */ -template -void ivfflat_interleaved_scan(const index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const uint32_t queries_offset, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const uint32_t max_samples, - const uint32_t* chunk_indices, - const bool select_min, - IvfSampleFilterT sample_filter, - uint32_t* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - const int capacity = bound_by_power_of_two(k); - - auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter( - index.inds_ptrs().data_handle(), sample_filter); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - queries_offset, - n_probes, - k, - max_samples, - chunk_indices, - filter_adapter, - neighbors, - distances, - grid_dim_x, - stream); -} - -} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh deleted file mode 100644 index 63f341dd9a..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) -#include "ivf_flat_interleaved_scan-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_interleaved_scan-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh deleted file mode 100644 index c14b0e810f..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include - -#include - -#include // uintX_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_flat::detail { - -template -void search(raft::resources const& handle, - const search_params& params, - const raft::neighbors::ivf_flat::index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_flat::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT, IvfSampleFilterT) \ - extern template void raft::neighbors::ivf_flat::detail::search( \ - raft::resources const& handle, \ - const search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::device_async_resource_ref mr, \ - IvfSampleFilterT sample_filter) - -instantiate_raft_neighbors_ivf_flat_detail_search( - float, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - half, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - int8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); -instantiate_raft_neighbors_ivf_flat_detail_search( - uint8_t, int64_t, raft::neighbors::filtering::none_ivf_sample_filter); - -#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh deleted file mode 100644 index 388dd60f14..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ /dev/null @@ -1,329 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // RAFT_LOG_TRACE -#include -#include // raft::resources -#include // is_min_close, DistanceType -#include // raft::linalg::gemm -#include // raft::linalg::norm -#include // raft::linalg::unary_op -#include // matrix::detail::select_k -#include // raft::neighbors::detail::ivf -#include // interleaved_scan -#include // raft::neighbors::ivf_flat::index -#include // none_ivf_sample_filter -#include // utils::mapping - -#include - -namespace raft::neighbors::ivf_flat::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -template -void search_impl(raft::resources const& handle, - const raft::neighbors::ivf_flat::index& index, - const T* queries, - uint32_t n_queries, - uint32_t queries_offset, - uint32_t k, - uint32_t n_probes, - uint32_t max_samples, - bool select_min, - IdxT* neighbors, - AccT* distances, - rmm::device_async_resource_ref search_mr, - IvfSampleFilterT sample_filter) -{ - auto stream = resource::get_cuda_stream(handle); - - std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); - - // The norm of query - rmm::device_uvector query_norm_dev(n_queries, stream, search_mr); - // The distance value of cluster(list) and queries - rmm::device_uvector distance_buffer_dev(n_queries * index.n_lists(), stream, search_mr); - // The topk distance value of cluster(list) and queries - rmm::device_uvector coarse_distances_dev(n_queries_probes, stream, search_mr); - // The topk index of cluster(list) and queries - rmm::device_uvector coarse_indices_dev(n_queries_probes, stream, search_mr); - - // Optional structures if postprocessing is required - // The topk distance value of candidate vectors from each cluster(list) - rmm::device_uvector distances_tmp_dev(0, stream, search_mr); - // Number of samples for each query - rmm::device_uvector num_samples(0, stream, search_mr); - // Offsets per probe for each query - rmm::device_uvector chunk_index(0, stream, search_mr); - - // The topk index of candidate vectors from each cluster(list), local index offset - // also we might need additional storage for select_k - rmm::device_uvector indices_tmp_dev(0, stream, search_mr); - rmm::device_uvector neighbors_uint32_buf(0, stream, search_mr); - - size_t float_query_size; - if constexpr (std::is_integral_v) { - float_query_size = n_queries * index.dim(); - } else { - float_query_size = 0; - } - rmm::device_uvector converted_queries_dev(float_query_size, stream, search_mr); - float* converted_queries_ptr = converted_queries_dev.data(); - - if constexpr (std::is_same_v) { - converted_queries_ptr = const_cast(queries); - } else { - linalg::unaryOp( - converted_queries_ptr, queries, n_queries * index.dim(), utils::mapping{}, stream); - } - - float alpha = 1.0f; - float beta = 0.0f; - - // todo(lsugy): raft distance? (if performance is similar/better than gemm) - switch (index.metric()) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: { - alpha = -2.0f; - beta = 1.0f; - raft::linalg::rowNorm(query_norm_dev.data(), - converted_queries_ptr, - static_cast(index.dim()), - static_cast(n_queries), - raft::linalg::L2Norm, - true, - stream); - utils::outer_add(query_norm_dev.data(), - (IdxT)n_queries, - index.center_norms()->data_handle(), - (IdxT)index.n_lists(), - distance_buffer_dev.data(), - stream); - RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - break; - } - default: { - alpha = 1.0f; - beta = 0.0f; - } - } - - linalg::gemm(handle, - true, - false, - index.n_lists(), - n_queries, - index.dim(), - &alpha, - index.centers().data_handle(), - index.dim(), - converted_queries_ptr, - index.dim(), - &beta, - distance_buffer_dev.data(), - index.n_lists(), - stream); - - RAFT_LOG_TRACE_VEC(distance_buffer_dev.data(), std::min(20, index.n_lists())); - matrix::detail::select_k(handle, - distance_buffer_dev.data(), - nullptr, - n_queries, - index.n_lists(), - n_probes, - coarse_distances_dev.data(), - coarse_indices_dev.data(), - select_min); - RAFT_LOG_TRACE_VEC(coarse_indices_dev.data(), n_probes); - RAFT_LOG_TRACE_VEC(coarse_distances_dev.data(), n_probes); - - uint32_t grid_dim_x = 0; - if (n_probes > 1) { - // query the gridDimX size to store probes topK output - ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( - index, - nullptr, - nullptr, - n_queries, - queries_offset, - index.metric(), - n_probes, - k, - 0, - nullptr, - select_min, - sample_filter, - nullptr, - nullptr, - grid_dim_x, - stream); - } else { - grid_dim_x = 1; - } - - num_samples.resize(n_queries, stream); - chunk_index.resize(n_queries_probes, stream); - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - num_samples.data(), - stream); - - auto distances_dev_ptr = distances; - - uint32_t* neighbors_uint32 = nullptr; - if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { - neighbors_uint32 = reinterpret_cast(neighbors); - } else { - neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), stream); - neighbors_uint32 = neighbors_uint32_buf.data(); - } - - uint32_t* indices_dev_ptr = nullptr; - - bool manage_local_topk = is_local_topk_feasible(k); - if (!manage_local_topk || grid_dim_x > 1) { - auto target_size = std::size_t(n_queries) * (manage_local_topk ? grid_dim_x * k : max_samples); - - distances_tmp_dev.resize(target_size, stream); - if (manage_local_topk) indices_tmp_dev.resize(target_size, stream); - - distances_dev_ptr = distances_tmp_dev.data(); - indices_dev_ptr = indices_tmp_dev.data(); - } else { - indices_dev_ptr = neighbors_uint32; - } - - ivfflat_interleaved_scan::value_t, IdxT, IvfSampleFilterT>( - index, - queries, - coarse_indices_dev.data(), - n_queries, - queries_offset, - index.metric(), - n_probes, - k, - max_samples, - chunk_index.data(), - select_min, - sample_filter, - indices_dev_ptr, - distances_dev_ptr, - grid_dim_x, - stream); - - RAFT_LOG_TRACE_VEC(distances_dev_ptr, 2 * k); - if (indices_dev_ptr != nullptr) { RAFT_LOG_TRACE_VEC(indices_dev_ptr, 2 * k); } - - // Merge topk values from different blocks - if (!manage_local_topk || grid_dim_x > 1) { - matrix::detail::select_k(handle, - distances_tmp_dev.data(), - indices_tmp_dev.data(), - n_queries, - manage_local_topk ? (k * grid_dim_x) : max_samples, - k, - distances, - neighbors_uint32, - select_min, - false, - matrix::SelectAlgo::kAuto, - manage_local_topk ? nullptr : num_samples.data()); - } - if (!manage_local_topk) { - // post process distances && neighbor IDs - ivf::detail::postprocess_distances( - distances, distances, index.metric(), n_queries, k, 1.0, false, stream); - } - ivf::detail::postprocess_neighbors(neighbors, - neighbors_uint32, - index.inds_ptrs().data_handle(), - coarse_indices_dev.data(), - chunk_index.data(), - n_queries, - n_probes, - k, - stream); -} - -/** See raft::neighbors::ivf_flat::search docs */ -template -inline void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource(), - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - common::nvtx::range fun_scope( - "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); - - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - auto n_probes = std::min(params.n_probes, index.n_lists()); - bool manage_local_topk = is_local_topk_feasible(k); - - uint32_t max_samples = 0; - if (!manage_local_topk) { - IdxT ms = - Pow2<128 / sizeof(float)>::roundUp(std::max(index.accum_sorted_sizes()(n_probes), k)); - RAFT_EXPECTS(ms <= IdxT(std::numeric_limits::max()), - "The maximum sample size is too big."); - max_samples = ms; - } - - // a batch size heuristic: try to keep the workspace within the specified size - constexpr uint64_t kExpectedWsSize = 1024 * 1024 * 1024; - uint64_t max_ws_size = std::min(resource::get_workspace_free_bytes(handle), kExpectedWsSize); - - uint64_t ws_size_per_query = 4ull * (2 * n_probes + index.n_lists() + index.dim() + 1) + - (manage_local_topk ? ((sizeof(IdxT) + 4) * n_probes * k) - : (4ull * (max_samples + n_probes + 1))); - - const uint32_t max_queries = - std::min(n_queries, raft::div_rounding_up_safe(max_ws_size, ws_size_per_query)); - - for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { - uint32_t queries_batch = min(max_queries, n_queries - offset_q); - - search_impl(handle, - index, - queries + offset_q * index.dim(), - queries_batch, - offset_q, - k, - n_probes, - max_samples, - raft::distance::is_min_close(index.metric()), - neighbors + offset_q * k, - distances + offset_q * k, - mr, - sample_filter); - } -} - -} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh deleted file mode 100644 index 7b03ebeab6..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ivf_flat_search-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat_search-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh deleted file mode 100644 index 3897b83aa6..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::ivf_flat::detail { - -// Serialization version -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int serialization_version = 4; - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index_ IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, std::ostream& os, const index& index_) -{ - RAFT_LOG_DEBUG( - "Saving IVF-Flat index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); - - std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); - dtype_string.resize(4); - os << dtype_string; - - serialize_scalar(handle, os, serialization_version); - serialize_scalar(handle, os, index_.size()); - serialize_scalar(handle, os, index_.dim()); - serialize_scalar(handle, os, index_.n_lists()); - serialize_scalar(handle, os, index_.metric()); - serialize_scalar(handle, os, index_.adaptive_centers()); - serialize_scalar(handle, os, index_.conservative_memory_allocation()); - serialize_mdspan(handle, os, index_.centers()); - if (index_.center_norms()) { - bool has_norms = true; - serialize_scalar(handle, os, has_norms); - serialize_mdspan(handle, os, *index_.center_norms()); - } else { - bool has_norms = false; - serialize_scalar(handle, os, has_norms); - } - auto sizes_host = make_host_vector(index_.list_sizes().extent(0)); - copy(sizes_host.data_handle(), - index_.list_sizes().data_handle(), - sizes_host.size(), - resource::get_cuda_stream(handle)); - resource::sync_stream(handle); - serialize_mdspan(handle, os, sizes_host.view()); - - list_spec list_store_spec{index_.dim(), true}; - for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::serialize_list(handle, - os, - index_.lists()[label], - list_store_spec, - Pow2::roundUp(sizes_host(label))); - } - resource::sync_stream(handle); -} - -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index_) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(handle, of, index_); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } -} - -/** Load an index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * @param[in] index_ IVF-Flat index - * - */ -template -auto deserialize(raft::resources const& handle, std::istream& is) -> index -{ - char dtype_string[4]; - is.read(dtype_string, 4); - - auto ver = deserialize_scalar(handle, is); - if (ver != serialization_version) { - RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); - } - auto n_rows = deserialize_scalar(handle, is); - auto dim = deserialize_scalar(handle, is); - auto n_lists = deserialize_scalar(handle, is); - auto metric = deserialize_scalar(handle, is); - bool adaptive_centers = deserialize_scalar(handle, is); - bool cma = deserialize_scalar(handle, is); - - index index_ = index(handle, metric, n_lists, adaptive_centers, cma, dim); - - deserialize_mdspan(handle, is, index_.centers()); - bool has_norms = deserialize_scalar(handle, is); - if (has_norms) { - index_.allocate_center_norms(handle); - if (!index_.center_norms()) { - RAFT_FAIL("Error inconsistent center norms"); - } else { - auto center_norms = index_.center_norms().value(); - deserialize_mdspan(handle, is, center_norms); - } - } - deserialize_mdspan(handle, is, index_.list_sizes()); - - list_spec list_device_spec{index_.dim(), cma}; - list_spec list_store_spec{index_.dim(), true}; - for (uint32_t label = 0; label < index_.n_lists(); label++) { - ivf::deserialize_list(handle, is, index_.lists()[label], list_store_spec, list_device_spec); - } - resource::sync_stream(handle); - - ivf::detail::recompute_internal_state(handle, index_); - - return index_; -} - -template -auto deserialize(raft::resources const& handle, const std::string& filename) -> index -{ - std::ifstream is(filename, std::ios::in | std::ios::binary); - - if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(handle, is); - - is.close(); - - return index; -} -} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh deleted file mode 100644 index 24574642ef..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_build.cuh +++ /dev/null @@ -1,1836 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_pq::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -using internal_extents_t = int64_t; // The default mdspan extent type used internally. - -template -__launch_bounds__(BlockDim) RAFT_KERNEL copy_warped_kernel( - T* out, uint32_t ld_out, const S* in, uint32_t ld_in, uint32_t n_cols, size_t n_rows) -{ - using warp = Pow2; - size_t row_ix = warp::div(size_t(threadIdx.x) + size_t(BlockDim) * size_t(blockIdx.x)); - uint32_t i = warp::mod(threadIdx.x); - if (row_ix >= n_rows) return; - out += row_ix * ld_out; - in += row_ix * ld_in; - auto f = utils::mapping{}; - for (uint32_t col_ix = i; col_ix < n_cols; col_ix += warp::Value) { - auto x = f(in[col_ix]); - __syncwarp(); - out[col_ix] = x; - } -} - -/** - * Copy the data one warp-per-row: - * - * 1. load the data per-warp - * 2. apply the `utils::mapping{}` - * 3. sync within warp - * 4. store the data. - * - * Assuming sizeof(T) >= sizeof(S) and the data is properly aligned (see the usage in `build`), this - * allows to re-structure the data within rows in-place. - */ -template -void copy_warped(T* out, - uint32_t ld_out, - const S* in, - uint32_t ld_in, - uint32_t n_cols, - size_t n_rows, - rmm::cuda_stream_view stream) -{ - constexpr uint32_t kBlockDim = 128; - dim3 threads(kBlockDim, 1, 1); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockDim / WarpSize), 1, 1); - copy_warped_kernel - <<>>(out, ld_out, in, ld_in, n_cols, n_rows); -} - -/** - * @brief Fill-in a random orthogonal transformation matrix. - * - * @param handle - * @param force_random_rotation - * @param n_rows - * @param n_cols - * @param[out] rotation_matrix device pointer to a row-major matrix of size [n_rows, n_cols]. - * @param rng random number generator state - */ -inline void make_rotation_matrix(raft::resources const& handle, - bool force_random_rotation, - uint32_t n_rows, - uint32_t n_cols, - float* rotation_matrix, - raft::random::RngState rng = raft::random::RngState(7ULL)) -{ - common::nvtx::range fun_scope( - "ivf_pq::make_rotation_matrix(%u * %u)", n_rows, n_cols); - auto stream = resource::get_cuda_stream(handle); - bool inplace = n_rows == n_cols; - uint32_t n = std::max(n_rows, n_cols); - if (force_random_rotation || !inplace) { - rmm::device_uvector buf(inplace ? 0 : n * n, stream); - float* mat = inplace ? rotation_matrix : buf.data(); - raft::random::normal(handle, rng, mat, n * n, 0.0f, 1.0f); - linalg::detail::qrGetQ_inplace(handle, mat, n, n, stream); - if (!inplace) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(rotation_matrix, - sizeof(float) * n_cols, - mat, - sizeof(float) * n, - sizeof(float) * n_cols, - n_rows, - cudaMemcpyDefault, - stream)); - } - } else { - uint32_t stride = n + 1; - auto rotation_matrix_view = - raft::make_device_vector_view(rotation_matrix, n * n); - linalg::map_offset(handle, rotation_matrix_view, [stride] __device__(uint32_t i) { - return static_cast(i % stride == 0u); - }); - } -} - -/** - * @brief Compute residual vectors from the source dataset given by selected indices. - * - * The residual has the form `rotation_matrix %* (dataset[row_ids, :] - center)` - * - */ -template -void select_residuals(raft::resources const& handle, - float* residuals, - IdxT n_rows, - uint32_t dim, - uint32_t rot_dim, - const float* rotation_matrix, // [rot_dim, dim] - const float* center, // [dim] - const T* dataset, // [.., dim] - const IdxT* row_ids, // [n_rows] - rmm::device_async_resource_ref device_memory - -) -{ - auto stream = resource::get_cuda_stream(handle); - rmm::device_uvector tmp(size_t(n_rows) * size_t(dim), stream, device_memory); - // Note: the number of rows of the input dataset isn't actually n_rows, but matrix::gather doesn't - // need to know it, any strictly positive number would work. - cub::TransformInputIterator, const T*> mapping_itr( - dataset, utils::mapping{}); - raft::matrix::gather(mapping_itr, (IdxT)dim, n_rows, row_ids, n_rows, tmp.data(), stream); - - raft::matrix::linewise_op(handle, - make_device_matrix_view(tmp.data(), n_rows, dim), - make_device_matrix_view(tmp.data(), n_rows, dim), - true, - raft::sub_op{}, - make_device_vector_view(center, dim)); - - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - rot_dim, - n_rows, - dim, - &alpha, - rotation_matrix, - dim, - tmp.data(), - dim, - &beta, - residuals, - rot_dim, - stream); -} - -/** - * @brief Compute residual vectors from the source dataset given by selected indices. - * - * The residual has the form - * `rotation_matrix %* (dataset[:, :] - centers[labels[:], 0:dim])` - * - */ -template -void flat_compute_residuals( - raft::resources const& handle, - float* residuals, // [n_rows, rot_dim] - IdxT n_rows, - device_matrix_view rotation_matrix, // [rot_dim, dim] - device_matrix_view centers, // [n_lists, dim_ext] - const T* dataset, // [n_rows, dim] - std::variant labels, // [n_rows] - rmm::device_async_resource_ref device_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto dim = rotation_matrix.extent(1); - auto rot_dim = rotation_matrix.extent(0); - rmm::device_uvector tmp(n_rows * dim, stream, device_memory); - auto tmp_view = raft::make_device_vector_view(tmp.data(), tmp.size()); - linalg::map_offset(handle, tmp_view, [centers, dataset, labels, dim] __device__(size_t i) { - auto row_ix = i / dim; - auto el_ix = i % dim; - auto label = std::holds_alternative(labels) - ? std::get(labels) - : std::get(labels)[row_ix]; - return utils::mapping{}(dataset[i]) - centers(label, el_ix); - }); - - float alpha = 1.0f; - float beta = 0.0f; - linalg::gemm(handle, - true, - false, - rot_dim, - n_rows, - dim, - &alpha, - rotation_matrix.data_handle(), - dim, - tmp.data(), - dim, - &beta, - residuals, - rot_dim, - stream); -} - -template -__launch_bounds__(BlockDim) RAFT_KERNEL - fill_indices_kernel(IdxT n_rows, IdxT* data_indices, IdxT* data_offsets, const uint32_t* labels) -{ - const auto i = IdxT(BlockDim) * IdxT(blockIdx.x) + IdxT(threadIdx.x); - if (i >= n_rows) { return; } - data_indices[atomicAdd(data_offsets + labels[i], 1)] = i; -} - -/** - * @brief Calculate cluster offsets and arrange data indices into clusters. - * - * @param n_rows - * @param n_lists - * @param[in] labels output of k-means prediction [n_rows] - * @param[in] cluster_sizes [n_lists] - * @param[out] cluster_offsets [n_lists+1] - * @param[out] data_indices [n_rows] - * - * @return size of the largest cluster - */ -template -auto calculate_offsets_and_indices(IdxT n_rows, - uint32_t n_lists, - const uint32_t* labels, - const uint32_t* cluster_sizes, - IdxT* cluster_offsets, - IdxT* data_indices, - rmm::cuda_stream_view stream) -> uint32_t -{ - auto exec_policy = rmm::exec_policy(stream); - // Calculate the offsets - IdxT cumsum = 0; - update_device(cluster_offsets, &cumsum, 1, stream); - thrust::inclusive_scan( - exec_policy, cluster_sizes, cluster_sizes + n_lists, cluster_offsets + 1, add_op{}); - update_host(&cumsum, cluster_offsets + n_lists, 1, stream); - uint32_t max_cluster_size = - *thrust::max_element(exec_policy, cluster_sizes, cluster_sizes + n_lists); - stream.synchronize(); - RAFT_EXPECTS(cumsum == n_rows, "cluster sizes do not add up."); - RAFT_LOG_DEBUG("Max cluster size %d", max_cluster_size); - rmm::device_uvector data_offsets_buf(n_lists, stream); - auto data_offsets = data_offsets_buf.data(); - copy(data_offsets, cluster_offsets, n_lists, stream); - constexpr uint32_t n_threads = 128; // NOLINT - const IdxT n_blocks = raft::div_rounding_up_unsafe(n_rows, n_threads); - fill_indices_kernel - <<>>(n_rows, data_indices, data_offsets, labels); - return max_cluster_size; -} - -template -void set_centers(raft::resources const& handle, index* index, const float* cluster_centers) -{ - auto stream = resource::get_cuda_stream(handle); - auto* device_memory = resource::get_workspace_resource(handle); - - // combine cluster_centers and their norms - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(), - sizeof(float) * index->dim_ext(), - cluster_centers, - sizeof(float) * index->dim(), - sizeof(float) * index->dim(), - index->n_lists(), - cudaMemcpyDefault, - stream)); - - rmm::device_uvector center_norms(index->n_lists(), stream, device_memory); - raft::linalg::rowNorm(center_norms.data(), - cluster_centers, - index->dim(), - index->n_lists(), - raft::linalg::L2Norm, - true, - stream); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(), - sizeof(float) * index->dim_ext(), - center_norms.data(), - sizeof(float), - sizeof(float), - index->n_lists(), - cudaMemcpyDefault, - stream)); - - // Rotate cluster_centers - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - index->rot_dim(), - index->n_lists(), - index->dim(), - &alpha, - index->rotation_matrix().data_handle(), - index->dim(), - cluster_centers, - index->dim(), - &beta, - index->centers_rot().data_handle(), - index->rot_dim(), - resource::get_cuda_stream(handle)); -} - -template -void transpose_pq_centers(const resources& handle, - index& index, - const float* pq_centers_source) -{ - auto stream = resource::get_cuda_stream(handle); - auto extents = index.pq_centers().extents(); - static_assert(extents.rank() == 3); - auto extents_source = - make_extents(extents.extent(0), extents.extent(2), extents.extent(1)); - auto span_source = - make_mdspan(pq_centers_source, extents_source); - auto pq_centers_view = raft::make_device_vector_view( - index.pq_centers().data_handle(), index.pq_centers().size()); - linalg::map_offset(handle, pq_centers_view, [span_source, extents] __device__(size_t i) { - uint32_t ii[3]; - for (int r = 2; r > 0; r--) { - ii[r] = i % extents.extent(r); - i /= extents.extent(r); - } - ii[0] = i; - return span_source(ii[0], ii[2], ii[1]); - }); -} - -template -void train_per_subset(raft::resources const& handle, - index& index, - size_t n_rows, - const float* trainset, // [n_rows, dim] - const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::device_async_resource_ref managed_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto device_memory = resource::get_workspace_resource(handle); - - rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector sub_trainset(n_rows * size_t(index.pq_len()), stream, device_memory); - rmm::device_uvector sub_labels(n_rows, stream, device_memory); - - rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); - - for (uint32_t j = 0; j < index.pq_dim(); j++) { - common::nvtx::range pq_per_subspace_scope( - "ivf_pq::build::per_subspace[%u]", j); - - // Get the rotated cluster centers for each training vector. - // This will be subtracted from the input vectors afterwards. - utils::copy_selected( - n_rows, - index.pq_len(), - index.centers_rot().data_handle() + index.pq_len() * j, - labels, - index.rot_dim(), - sub_trainset.data(), - index.pq_len(), - stream); - - // sub_trainset is the slice of: rotate(trainset) - centers_rot - float alpha = 1.0; - float beta = -1.0; - linalg::gemm(handle, - true, - false, - index.pq_len(), - n_rows, - index.dim(), - &alpha, - index.rotation_matrix().data_handle() + index.dim() * index.pq_len() * j, - index.dim(), - trainset, - index.dim(), - &beta, - sub_trainset.data(), - index.pq_len(), - stream); - - // train PQ codebook for this subspace - auto sub_trainset_view = raft::make_device_matrix_view( - sub_trainset.data(), n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( - pq_centers_tmp.data() + index.pq_book_size() * index.pq_len() * j, - index.pq_book_size(), - index.pq_len()); - auto sub_labels_view = - raft::make_device_vector_view(sub_labels.data(), n_rows); - auto cluster_sizes_view = raft::make_device_vector_view( - pq_cluster_sizes.data(), index.pq_book_size()); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = kmeans_n_iters; - kmeans_params.metric = raft::distance::DistanceType::L2Expanded; - raft::cluster::kmeans_balanced::helpers::build_clusters(handle, - kmeans_params, - sub_trainset_view, - centers_tmp_view, - sub_labels_view, - cluster_sizes_view, - utils::mapping{}); - } - transpose_pq_centers(handle, index, pq_centers_tmp.data()); -} - -template -void train_per_cluster(raft::resources const& handle, - index& index, - size_t n_rows, - const float* trainset, // [n_rows, dim] - const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::device_async_resource_ref managed_memory) -{ - auto stream = resource::get_cuda_stream(handle); - auto device_memory = resource::get_workspace_resource(handle); - - rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); - rmm::device_uvector cluster_sizes(index.n_lists(), stream, managed_memory); - rmm::device_uvector indices_buf(n_rows, stream, device_memory); - rmm::device_uvector offsets_buf(index.n_lists() + 1, stream, managed_memory); - - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(cluster_sizes.data()), - index.n_lists(), - labels, - n_rows, - 1, - stream); - - auto cluster_offsets = offsets_buf.data(); - auto indices = indices_buf.data(); - uint32_t max_cluster_size = calculate_offsets_and_indices( - IdxT(n_rows), index.n_lists(), labels, cluster_sizes.data(), cluster_offsets, indices, stream); - - rmm::device_uvector pq_labels( - size_t(max_cluster_size) * size_t(index.pq_dim()), stream, device_memory); - rmm::device_uvector pq_cluster_sizes(index.pq_book_size(), stream, device_memory); - rmm::device_uvector rot_vectors( - size_t(max_cluster_size) * size_t(index.rot_dim()), stream, device_memory); - - resource::sync_stream(handle); // make sure cluster offsets are up-to-date - for (uint32_t l = 0; l < index.n_lists(); l++) { - auto cluster_size = cluster_sizes.data()[l]; - if (cluster_size == 0) continue; - common::nvtx::range pq_per_cluster_scope( - "ivf_pq::build::per_cluster[%u](size = %u)", l, cluster_size); - - select_residuals(handle, - rot_vectors.data(), - IdxT(cluster_size), - index.dim(), - index.rot_dim(), - index.rotation_matrix().data_handle(), - index.centers().data_handle() + size_t(l) * size_t(index.dim_ext()), - trainset, - indices + cluster_offsets[l], - device_memory); - - // limit the cluster size to bound the training time. - // [sic] we interpret the data as pq_len-dimensional - size_t big_enough = 256ul * std::max(index.pq_book_size(), index.pq_dim()); - size_t available_rows = size_t(cluster_size) * size_t(index.pq_dim()); - auto pq_n_rows = uint32_t(std::min(big_enough, available_rows)); - // train PQ codebook for this cluster - auto rot_vectors_view = raft::make_device_matrix_view( - rot_vectors.data(), pq_n_rows, index.pq_len()); - auto centers_tmp_view = raft::make_device_matrix_view( - pq_centers_tmp.data() + static_cast(index.pq_book_size()) * - static_cast(index.pq_len()) * static_cast(l), - index.pq_book_size(), - index.pq_len()); - auto pq_labels_view = - raft::make_device_vector_view(pq_labels.data(), pq_n_rows); - auto pq_cluster_sizes_view = raft::make_device_vector_view( - pq_cluster_sizes.data(), index.pq_book_size()); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = kmeans_n_iters; - kmeans_params.metric = raft::distance::DistanceType::L2Expanded; - raft::cluster::kmeans_balanced::helpers::build_clusters(handle, - kmeans_params, - rot_vectors_view, - centers_tmp_view, - pq_labels_view, - pq_cluster_sizes_view, - utils::mapping{}); - } - transpose_pq_centers(handle, index, pq_centers_tmp.data()); -} - -/** - * A helper function: given the dataset in the rotated space - * [n_rows, rot_dim] = [n_rows, pq_dim * pq_len], - * reinterpret the last dimension as two: [n_rows, pq_dim, pq_len] - * - * @tparam T - * @tparam IdxT - * - * @param vectors input data [n_rows, rot_dim] - * @param pq_centers codebook (used to infer the structure - pq_len) - * @return reinterpreted vectors [n_rows, pq_dim, pq_len] - */ -template -static __device__ auto reinterpret_vectors( - device_matrix_view vectors, - device_mdspan, row_major> pq_centers) - -> device_mdspan, row_major> -{ - const uint32_t pq_len = pq_centers.extent(1); - const uint32_t pq_dim = vectors.extent(1) / pq_len; - using layout_t = typename decltype(vectors)::layout_type; - using accessor_t = typename decltype(vectors)::accessor_type; - return mdspan, layout_t, accessor_t>( - vectors.data_handle(), extent_3d{vectors.extent(0), pq_dim, pq_len}); -} - -/** - * A consumer for the `run_on_list` and `run_on_vector` that just flattens PQ codes - * one-per-byte. That is, independent of the code width (pq_bits), one code uses - * the whole byte, hence one vectors uses pq_dim bytes. - */ -struct unpack_codes { - device_matrix_view out_codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_codes the destination for the read codes. - */ - __device__ inline unpack_codes(device_matrix_view out_codes) - : out_codes{out_codes} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - out_codes(i, j) = code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL unpack_list_data_kernel( - device_matrix_view out_codes, - device_mdspan::list_extents, row_major> in_list_data, - std::variant offset_or_indices) -{ - const uint32_t pq_dim = out_codes.extent(1); - auto unpack_action = unpack_codes{out_codes}; - run_on_list(in_list_data, offset_or_indices, out_codes.extent(0), pq_dim, unpack_action); -} - -/** - * Unpack flat PQ codes from an existing list by the given offset. - * - * @param[out] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] list_data the packed ivf::list data. - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void unpack_list_data( - device_matrix_view codes, - device_mdspan::list_extents, row_major> list_data, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return unpack_list_data_kernel; - case 5: return unpack_list_data_kernel; - case 6: return unpack_list_data_kernel; - case 7: return unpack_list_data_kernel; - case 8: return unpack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(codes, list_data, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Unpack the list data; see the public interface for the api and usage. */ -template -void unpack_list_data(raft::resources const& res, - const index& index, - device_matrix_view out_codes, - uint32_t label, - std::variant offset_or_indices) -{ - unpack_list_data(out_codes, - index.lists()[label]->data.view(), - offset_or_indices, - index.pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * A consumer for the `run_on_vector` that just flattens PQ codes - * into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte. - */ -template -struct unpack_contiguous { - uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `run_on_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Write j-th component (code) of the i-th vector into the output array. */ - __host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - bitfield_view_t code_view{codes + i * code_size}; - code_view[j] = code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel( - uint8_t* out_codes, - device_mdspan::list_extents, row_major> in_list_data, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices) -{ - run_on_list( - in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous(out_codes, pq_dim)); -} - -/** - * Unpack flat PQ codes from an existing list by the given offset. - * - * @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] - * @param[in] list_data the packed ivf::list data. - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void unpack_contiguous_list_data( - uint8_t* codes, - device_mdspan::list_extents, row_major> list_data, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return unpack_contiguous_list_data_kernel; - case 5: return unpack_contiguous_list_data_kernel; - case 6: return unpack_contiguous_list_data_kernel; - case 7: return unpack_contiguous_list_data_kernel; - case 8: return unpack_contiguous_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(codes, list_data, n_rows, pq_dim, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** Unpack the list data; see the public interface for the api and usage. */ -template -void unpack_contiguous_list_data(raft::resources const& res, - const index& index, - uint8_t* out_codes, - uint32_t n_rows, - uint32_t label, - std::variant offset_or_indices) -{ - unpack_contiguous_list_data(out_codes, - index.lists()[label]->data.view(), - n_rows, - index.pq_dim(), - offset_or_indices, - index.pq_bits(), - resource::get_cuda_stream(res)); -} - -/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data. - */ -struct reconstruct_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - uint32_t pq_len; - device_mdspan, row_major> pq_centers; - device_mdspan, row_major> centers_rot; - device_mdspan, row_major> out_vectors; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[out] out_vectors the destination for the decoded vectors. - * @param[in] pq_centers the codebook - * @param[in] centers_rot - * @param[in] codebook_kind - * @param[in] cluster_ix label/id of the cluster. - */ - __device__ inline reconstruct_vectors( - device_matrix_view out_vectors, - device_mdspan, row_major> pq_centers, - device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_len{pq_centers.extent(1)}, - pq_centers{pq_centers}, - centers_rot{reinterpret_vectors(centers_rot, pq_centers)}, - out_vectors{reinterpret_vectors(out_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). - */ - __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j) - { - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - for (uint32_t k = 0; k < pq_len; k++) { - out_vectors(i, j, k) = pq_centers(partition_ix, k, code) + centers_rot(cluster_ix, j, k); - } - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL reconstruct_list_data_kernel( - device_matrix_view out_vectors, - device_mdspan::list_extents, row_major> in_list_data, - device_mdspan, row_major> pq_centers, - device_matrix_view centers_rot, - codebook_gen codebook_kind, - uint32_t cluster_ix, - std::variant offset_or_indices) -{ - const uint32_t pq_dim = out_vectors.extent(1) / pq_centers.extent(1); - auto reconstruct_action = - reconstruct_vectors{out_vectors, pq_centers, centers_rot, codebook_kind, cluster_ix}; - run_on_list( - in_list_data, offset_or_indices, out_vectors.extent(0), pq_dim, reconstruct_action); -} - -/** Decode the list data; see the public interface for the api and usage. */ -template -void reconstruct_list_data(raft::resources const& res, - const index& index, - device_matrix_view out_vectors, - uint32_t label, - std::variant offset_or_indices) -{ - auto n_rows = out_vectors.extent(0); - if (n_rows == 0) { return; } - auto& list = index.lists()[label]; - if (std::holds_alternative(offset_or_indices)) { - auto n_skip = std::get(offset_or_indices); - // sic! I'm using the upper bound `list.size` instead of exact `list_sizes(label)` - // to avoid an extra device-host data copy and the stream sync. - RAFT_EXPECTS(n_skip + n_rows <= list->size.load(), - "offset + output size must be not bigger than the cluster size."); - } - - auto tmp = make_device_mdarray( - res, resource::get_workspace_resource(res), make_extents(n_rows, index.rot_dim())); - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return reconstruct_list_data_kernel; - case 5: return reconstruct_list_data_kernel; - case 6: return reconstruct_list_data_kernel; - case 7: return reconstruct_list_data_kernel; - case 8: return reconstruct_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index.pq_bits()); - kernel<<>>(tmp.view(), - list->data.view(), - index.pq_centers(), - index.centers_rot(), - index.codebook_kind(), - label, - offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - - float* out_float_ptr = nullptr; - rmm::device_uvector out_float_buf( - 0, resource::get_cuda_stream(res), resource::get_workspace_resource(res)); - if constexpr (std::is_same_v) { - out_float_ptr = out_vectors.data_handle(); - } else { - out_float_buf.resize(size_t{n_rows} * size_t{index.dim()}, resource::get_cuda_stream(res)); - out_float_ptr = out_float_buf.data(); - } - // Rotate the results back to the original space - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(res, - false, - false, - index.dim(), - n_rows, - index.rot_dim(), - &alpha, - index.rotation_matrix().data_handle(), - index.dim(), - tmp.data_handle(), - index.rot_dim(), - &beta, - out_float_ptr, - index.dim(), - resource::get_cuda_stream(res)); - // Transform the data to the original type, if necessary - if constexpr (!std::is_same_v) { - linalg::map(res, - out_vectors, - utils::mapping{}, - make_device_matrix_view(out_float_ptr, n_rows, index.dim())); - } -} - -/** - * A producer for the `write_list` and `write_vector` reads the codes byte-by-byte. That is, - * independent of the code width (pq_bits), one code uses the whole byte, hence one vectors uses - * pq_dim bytes. - */ -struct pass_codes { - device_matrix_view codes; - - /** - * Create a callable to be passed to `run_on_list`. - * - * @param[in] codes the source codes. - */ - __device__ inline pass_codes(device_matrix_view codes) - : codes{codes} - { - } - - /** Read j-th component (code) of the i-th vector from the source. */ - __device__ inline auto operator()(uint32_t i, uint32_t j) const -> uint8_t { return codes(i, j); } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL pack_list_data_kernel( - device_mdspan::list_extents, row_major> list_data, - device_matrix_view codes, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, codes.extent(0), codes.extent(1), pass_codes{codes}); -} - -/** - * Write flat PQ codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). - * - * @param[out] list_data the packed ivf::list data. - * @param[in] codes flat PQ codes, one code per byte [n_rows, pq_dim] - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void pack_list_data( - device_mdspan::list_extents, row_major> list_data, - device_matrix_view codes, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - auto n_rows = codes.extent(0); - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return pack_list_data_kernel; - case 5: return pack_list_data_kernel; - case 6: return pack_list_data_kernel; - case 7: return pack_list_data_kernel; - case 8: return pack_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(list_data, codes, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void pack_list_data(raft::resources const& res, - index* index, - device_matrix_view new_codes, - uint32_t label, - std::variant offset_or_indices) -{ - pack_list_data(index->lists()[label]->data.view(), - new_codes, - offset_or_indices, - index->pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * A producer for the `write_vector` reads tightly packed flat codes. That is, - * the codes are not expanded to one code-per-byte. - */ -template -struct pack_contiguous { - const uint8_t* codes; - uint32_t code_size; - - /** - * Create a callable to be passed to `write_vector`. - * - * @param[in] codes flat compressed PQ codes - */ - __host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim) - : codes{codes}, code_size{raft::ceildiv(pq_dim * PqBits, 8)} - { - } - - /** Read j-th component (code) of the i-th vector from the source. */ - __host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t - { - bitfield_view_t code_view{const_cast(codes + i * code_size)}; - return uint8_t(code_view[j]); - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel( - device_mdspan::list_extents, row_major> list_data, - const uint8_t* codes, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices) -{ - write_list( - list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous(codes, pq_dim)); -} - -/** - * Write flat PQ codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_rows). - * - * @param[out] list_data the packed ivf::list data. - * @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)] - * @param[in] offset_or_indices how many records in the list to skip or the exact indices. - * @param[in] pq_bits codebook size (1 << pq_bits) - * @param[in] stream - */ -inline void pack_contiguous_list_data( - device_mdspan::list_extents, row_major> list_data, - const uint8_t* codes, - uint32_t n_rows, - uint32_t pq_dim, - std::variant offset_or_indices, - uint32_t pq_bits, - rmm::cuda_stream_view stream) -{ - if (n_rows == 0) { return; } - - constexpr uint32_t kBlockSize = 256; - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [pq_bits]() { - switch (pq_bits) { - case 4: return pack_contiguous_list_data_kernel; - case 5: return pack_contiguous_list_data_kernel; - case 6: return pack_contiguous_list_data_kernel; - case 7: return pack_contiguous_list_data_kernel; - case 8: return pack_contiguous_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(); - kernel<<>>(list_data, codes, n_rows, pq_dim, offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -template -void pack_contiguous_list_data(raft::resources const& res, - index* index, - const uint8_t* new_codes, - uint32_t n_rows, - uint32_t label, - std::variant offset_or_indices) -{ - pack_contiguous_list_data(index->lists()[label]->data.view(), - new_codes, - n_rows, - index->pq_dim(), - offset_or_indices, - index->pq_bits(), - resource::get_cuda_stream(res)); -} - -/** - * - * A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals - * into lvl-2 PQ codes. - * Computing a PQ code means finding the closest cluster in a pq_dim-subspace. - * - * @tparam SubWarpSize - * how many threads work on a single vector; - * bounded by either WarpSize or pq_book_size. - * - * @param pq_centers - * - codebook_gen::PER_SUBSPACE: [pq_dim , pq_len, pq_book_size] - * - codebook_gen::PER_CLUSTER: [n_lists, pq_len, pq_book_size] - * @param new_vector a single input of length rot_dim, reinterpreted as [pq_dim, pq_len]. - * the input must be already transformed to floats, rotated, and the level 1 cluster - * center must be already substructed (i.e. this is the residual of a single input vector). - * @param codebook_kind - * @param j index along pq_dim "dimension" - * @param cluster_ix is used for PER_CLUSTER codebooks. - */ -/** - */ -template -struct encode_vectors { - codebook_gen codebook_kind; - uint32_t cluster_ix; - device_mdspan, row_major> pq_centers; - device_mdspan, row_major> in_vectors; - - __device__ inline encode_vectors( - device_mdspan, row_major> pq_centers, - device_matrix_view in_vectors, - codebook_gen codebook_kind, - uint32_t cluster_ix) - : codebook_kind{codebook_kind}, - cluster_ix{cluster_ix}, - pq_centers{pq_centers}, - in_vectors{reinterpret_vectors(in_vectors, pq_centers)} - { - } - - /** - * Decode j-th component of the i-th vector by its code and write it into a chunk of the output - * vectors (pq_len elements). - */ - __device__ inline auto operator()(IdxT i, uint32_t j) -> uint8_t - { - uint32_t lane_id = Pow2::mod(laneId()); - uint32_t partition_ix; - switch (codebook_kind) { - case codebook_gen::PER_CLUSTER: { - partition_ix = cluster_ix; - } break; - case codebook_gen::PER_SUBSPACE: { - partition_ix = j; - } break; - default: __builtin_unreachable(); - } - - const uint32_t pq_book_size = pq_centers.extent(2); - const uint32_t pq_len = pq_centers.extent(1); - float min_dist = std::numeric_limits::infinity(); - uint8_t code = 0; - // calculate the distance for each PQ cluster, find the minimum for each thread - for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { - // NB: the L2 quantifiers on residuals are always trained on L2 metric. - float d = 0.0f; - for (uint32_t k = 0; k < pq_len; k++) { - auto t = in_vectors(i, j, k) - pq_centers(partition_ix, k, l); - d += t * t; - } - if (d < min_dist) { - min_dist = d; - code = uint8_t(l); - } - } - // reduce among threads -#pragma unroll - for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { - const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); - const auto other_code = shfl_xor(code, stride, SubWarpSize); - if (other_dist < min_dist) { - min_dist = other_dist; - code = other_code; - } - } - return code; - } -}; - -template -__launch_bounds__(BlockSize) RAFT_KERNEL process_and_fill_codes_kernel( - device_matrix_view new_vectors, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - device_vector_view list_sizes, - device_vector_view inds_ptrs, - device_vector_view data_ptrs, - device_mdspan, row_major> pq_centers, - codebook_gen codebook_kind) -{ - constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); - using subwarp_align = Pow2; - const uint32_t lane_id = subwarp_align::mod(threadIdx.x); - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if (row_ix >= new_vectors.extent(0)) { return; } - - const uint32_t cluster_ix = new_labels[row_ix]; - uint32_t out_ix; - if (lane_id == 0) { out_ix = atomicAdd(&list_sizes(cluster_ix), 1); } - out_ix = shfl(out_ix, 0, kSubWarpSize); - - // write the label (one record per subwarp) - auto pq_indices = inds_ptrs(cluster_ix); - if (lane_id == 0) { - if (std::holds_alternative(src_offset_or_indices)) { - pq_indices[out_ix] = std::get(src_offset_or_indices) + row_ix; - } else { - pq_indices[out_ix] = std::get(src_offset_or_indices)[row_ix]; - } - } - - // write the codes (one record per subwarp): - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto pq_extents = list_spec{PqBits, pq_dim, true}.make_list_extents(out_ix + 1); - auto pq_dataset = - make_mdspan(data_ptrs[cluster_ix], pq_extents); - write_vector( - pq_dataset, - out_ix, - row_ix, - pq_dim, - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}); -} - -template -__launch_bounds__(BlockSize) RAFT_KERNEL encode_list_data_kernel( - device_mdspan::list_extents, row_major> list_data, - device_matrix_view new_vectors, - device_mdspan, row_major> pq_centers, - codebook_gen codebook_kind, - uint32_t cluster_ix, - std::variant offset_or_indices) -{ - constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); - const uint32_t pq_dim = new_vectors.extent(1) / pq_centers.extent(1); - auto encode_action = - encode_vectors{pq_centers, new_vectors, codebook_kind, cluster_ix}; - write_list( - list_data, offset_or_indices, new_vectors.extent(0), pq_dim, encode_action); -} - -template -void encode_list_data(raft::resources const& res, - index* index, - device_matrix_view new_vectors, - uint32_t label, - std::variant offset_or_indices) -{ - auto n_rows = new_vectors.extent(0); - if (n_rows == 0) { return; } - - auto mr = resource::get_workspace_resource(res); - - auto new_vectors_residual = - make_device_mdarray(res, mr, make_extents(n_rows, index->rot_dim())); - - flat_compute_residuals(res, - new_vectors_residual.data_handle(), - n_rows, - index->rotation_matrix(), - index->centers(), - new_vectors.data_handle(), - label, - mr); - - constexpr uint32_t kBlockSize = 256; - const uint32_t threads_per_vec = std::min(WarpSize, index->pq_book_size()); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return encode_list_data_kernel; - case 5: return encode_list_data_kernel; - case 6: return encode_list_data_kernel; - case 7: return encode_list_data_kernel; - case 8: return encode_list_data_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index->pq_bits()); - kernel<<>>(index->lists()[label]->data.view(), - new_vectors_residual.view(), - index->pq_centers(), - index->codebook_kind(), - label, - offset_or_indices); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * Assuming the index already has some data and allocated the space for more, write more data in it. - * There must be enough free space in `pq_dataset()` and `indices()`, as computed using - * `list_offsets()` and `list_sizes()`. - * - * NB: Since the pq_dataset is stored in the interleaved blocked format (see ivf_pq_types.hpp), one - * cannot just concatenate the old and the new codes; the positions for the codes are determined the - * same way as in the ivfpq_compute_similarity_kernel (see ivf_pq_search.cuh). - * - * @tparam T - * @tparam IdxT - * - * @param handle - * @param index - * @param[in] new_vectors - * a pointer to a row-major device array [index.dim(), n_rows]; - * @param[in] src_offset_or_indices - * references for the new data: - * either a starting index for the auto-indexing - * or a pointer to a device array of explicit indices [n_rows]; - * @param[in] new_labels - * cluster ids (first-level quantization) - a device array [n_rows]; - * @param n_rows - * the number of records to write in. - * @param mr - * a memory resource to use for device allocations - */ -template -void process_and_fill_codes(raft::resources const& handle, - index& index, - const T* new_vectors, - std::variant src_offset_or_indices, - const uint32_t* new_labels, - IdxT n_rows, - rmm::device_async_resource_ref mr) -{ - auto new_vectors_residual = - make_device_mdarray(handle, mr, make_extents(n_rows, index.rot_dim())); - - flat_compute_residuals(handle, - new_vectors_residual.data_handle(), - n_rows, - index.rotation_matrix(), - index.centers(), - new_vectors, - new_labels, - mr); - - constexpr uint32_t kBlockSize = 256; - const uint32_t threads_per_vec = std::min(WarpSize, index.pq_book_size()); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - dim3 threads(kBlockSize, 1, 1); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return process_and_fill_codes_kernel; - case 5: return process_and_fill_codes_kernel; - case 6: return process_and_fill_codes_kernel; - case 7: return process_and_fill_codes_kernel; - case 8: return process_and_fill_codes_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(index.pq_bits()); - kernel<<>>(new_vectors_residual.view(), - src_offset_or_indices, - new_labels, - index.list_sizes(), - index.inds_ptrs(), - index.data_ptrs(), - index.pq_centers(), - index.codebook_kind()); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * Helper function: allocate enough space in the list, compute the offset, at which to start - * writing, and fill-in indices. - * - * @return offset for writing the data - */ -template -auto extend_list_prepare(raft::resources const& res, - index* index, - device_vector_view new_indices, - uint32_t label) -> uint32_t -{ - uint32_t n_rows = new_indices.extent(0); - uint32_t offset; - // Allocate the lists to fit the new data - copy(&offset, index->list_sizes().data_handle() + label, 1, resource::get_cuda_stream(res)); - resource::sync_stream(res); - uint32_t new_size = offset + n_rows; - copy(index->list_sizes().data_handle() + label, &new_size, 1, resource::get_cuda_stream(res)); - auto spec = list_spec{ - index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; - auto& list = index->lists()[label]; - ivf::resize_list(res, list, spec, new_size, offset); - copy(list->indices.data_handle() + offset, - new_indices.data_handle(), - n_rows, - resource::get_cuda_stream(res)); - return offset; -} - -/** - * Extend one list of the index in-place, by the list label, skipping the classification and - * encoding steps. - * See the public interface for the api and usage. - */ -template -void extend_list_with_codes(raft::resources const& res, - index* index, - device_matrix_view new_codes, - device_vector_view new_indices, - uint32_t label) -{ - // Allocate memory and write indices - auto offset = extend_list_prepare(res, index, new_indices, label); - // Pack the data - pack_list_data(res, index, new_codes, label, offset); - // Update the pointers and the sizes - ivf::detail::recompute_internal_state(res, *index); -} - -/** - * Extend one list of the index in-place, by the list label, skipping the classification step. - * See the public interface for the api and usage. - */ -template -void extend_list(raft::resources const& res, - index* index, - device_matrix_view new_vectors, - device_vector_view new_indices, - uint32_t label) -{ - // Allocate memory and write indices - auto offset = extend_list_prepare(res, index, new_indices, label); - // Encode the data - encode_list_data(res, index, new_vectors, label, offset); - // Update the pointers and the sizes - ivf::detail::recompute_internal_state(res, *index); -} - -/** - * Remove all data from a single list. - * See the public interface for the api and usage. - */ -template -void erase_list(raft::resources const& res, index* index, uint32_t label) -{ - uint32_t zero = 0; - copy(index->list_sizes().data_handle() + label, &zero, 1, resource::get_cuda_stream(res)); - index->lists()[label].reset(); - ivf::detail::recompute_internal_state(res, *index); -} - -/** Copy the state of an index into a new index, but share the list data among the two. */ -template -auto clone(const raft::resources& res, const index& source) -> index -{ - auto stream = resource::get_cuda_stream(res); - - // Allocate the new index - index target(res, - source.metric(), - source.codebook_kind(), - source.n_lists(), - source.dim(), - source.pq_bits(), - source.pq_dim()); - - // Copy the independent parts - copy(target.list_sizes().data_handle(), - source.list_sizes().data_handle(), - source.list_sizes().size(), - stream); - copy(target.rotation_matrix().data_handle(), - source.rotation_matrix().data_handle(), - source.rotation_matrix().size(), - stream); - copy(target.pq_centers().data_handle(), - source.pq_centers().data_handle(), - source.pq_centers().size(), - stream); - copy(target.centers().data_handle(), - source.centers().data_handle(), - source.centers().size(), - stream); - copy(target.centers_rot().data_handle(), - source.centers_rot().data_handle(), - source.centers_rot().size(), - stream); - - // Copy shared pointers - target.lists() = source.lists(); - - // Make sure the device pointers point to the new lists - ivf::detail::recompute_internal_state(res, target); - - return target; -} - -/** - * Extend the index in-place. - * See raft::spatial::knn::ivf_pq::extend docs. - */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - common::nvtx::range fun_scope( - "ivf_pq::extend(%zu, %u)", size_t(n_rows), index->dim()); - - auto stream = resource::get_cuda_stream(handle); - const auto n_clusters = index->n_lists(); - - RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, - "You must pass data indices when the index is non-empty."); - - static_assert(std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v, - "Unsupported data type"); - - rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); - - // The spec defines how the clusters look like - auto spec = list_spec{ - index->pq_bits(), index->pq_dim(), index->conservative_memory_allocation()}; - // Try to allocate an index with the same parameters and the projected new size - // (which can be slightly larger than index->size() + n_rows, due to padding). - // If this fails, the index would be too big to fit in the device anyway. - std::optional> placeholder_list( - std::in_place_t{}, - handle, - list_spec{spec}, - n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); - - // Available device memory - size_t free_mem, total_mem; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem)); - - // Allocate a buffer for the new labels (classifying the new data) - rmm::device_uvector new_data_labels(n_rows, stream, device_memory); - free_mem -= sizeof(uint32_t) * n_rows; - - // Calculate the batch size for the input data if it's not accessible directly from the device - constexpr size_t kReasonableMaxBatchSize = 65536; - size_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - { - size_t size_factor = 0; - // we'll use two temporary buffers for converted inputs when computing the codes. - size_factor += (index->dim() + index->rot_dim()) * sizeof(float); - // ...and another buffer for indices - size_factor += sizeof(IdxT); - // if the input data is not accessible on device, we'd need a buffer for it. - switch (utils::check_pointer_residency(new_vectors)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: size_factor += index->dim() * sizeof(T); - } - // the same with indices - if (new_indices != nullptr) { - switch (utils::check_pointer_residency(new_indices)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: size_factor += sizeof(IdxT); - } - } - // make the batch size fit into the remaining memory - while (size_factor * max_batch_size > free_mem && max_batch_size > 128) { - max_batch_size >>= 1; - } - // If we're keeping the batches in device memory, update the available mem tracker. - free_mem -= size_factor * max_batch_size; - } - - // Predict the cluster labels for the new data, in batches if necessary - utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory); - // Release the placeholder memory, because we don't intend to allocate any more long-living - // temporary buffers before we allocate the index data. - // This memory could potentially speed up UVM accesses, if any. - placeholder_list.reset(); - { - // The cluster centers in the index are stored padded, which is not acceptable by - // the kmeans_balanced::predict. Thus, we need the restructuring copy. - rmm::device_uvector cluster_centers( - size_t(n_clusters) * size_t(index->dim()), stream, device_memory); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(cluster_centers.data(), - sizeof(float) * index->dim(), - index->centers().data_handle(), - sizeof(float) * index->dim_ext(), - sizeof(float) * index->dim(), - n_clusters, - cudaMemcpyDefault, - stream)); - for (const auto& batch : vec_batches) { - auto batch_data_view = raft::make_device_matrix_view( - batch.data(), batch.size(), index->dim()); - auto batch_labels_view = raft::make_device_vector_view( - new_data_labels.data() + batch.offset(), batch.size()); - auto centers_view = raft::make_device_matrix_view( - cluster_centers.data(), n_clusters, index->dim()); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = index->metric(); - raft::cluster::kmeans_balanced::predict(handle, - kmeans_params, - batch_data_view, - centers_view, - batch_labels_view, - utils::mapping{}); - } - } - - auto list_sizes = index->list_sizes().data_handle(); - // store the current cluster sizes, because we'll need them later - rmm::device_uvector orig_list_sizes(n_clusters, stream, device_memory); - copy(orig_list_sizes.data(), list_sizes, n_clusters, stream); - - // Get the combined cluster sizes - raft::stats::histogram(raft::stats::HistTypeAuto, - reinterpret_cast(list_sizes), - IdxT(n_clusters), - new_data_labels.data(), - n_rows, - 1, - stream); - linalg::add(list_sizes, list_sizes, orig_list_sizes.data(), n_clusters, stream); - - // Allocate the lists to fit the new data - { - std::vector new_cluster_sizes(n_clusters); - std::vector old_cluster_sizes(n_clusters); - copy(new_cluster_sizes.data(), list_sizes, n_clusters, stream); - copy(old_cluster_sizes.data(), orig_list_sizes.data(), n_clusters, stream); - resource::sync_stream(handle); - for (uint32_t label = 0; label < n_clusters; label++) { - ivf::resize_list( - handle, index->lists()[label], spec, new_cluster_sizes[label], old_cluster_sizes[label]); - } - } - - // Update the pointers and the sizes - ivf::detail::recompute_internal_state(handle, *index); - - // Recover old cluster sizes: they are used as counters in the fill-codes kernel - copy(list_sizes, orig_list_sizes.data(), n_clusters, stream); - - // By this point, the index state is updated and valid except it doesn't contain the new data - // Fill the extended index with the new data (possibly, in batches) - utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, device_memory); - for (const auto& vec_batch : vec_batches) { - const auto& idx_batch = *idx_batches++; - process_and_fill_codes(handle, - *index, - vec_batch.data(), - new_indices != nullptr - ? std::variant(idx_batch.data()) - : std::variant(IdxT(idx_batch.offset())), - new_data_labels.data() + vec_batch.offset(), - IdxT(vec_batch.size()), - device_memory); - } -} - -/** - * Create a new index that contains more data. - * See raft::spatial::knn::ivf_pq::extend docs. - */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - auto ext_index = clone(handle, orig_index); - detail::extend(handle, &ext_index, new_vectors, new_indices, n_rows); - return ext_index; -} - -/** See raft::spatial::knn::ivf_pq::build docs */ -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - common::nvtx::range fun_scope( - "ivf_pq::build(%zu, %u)", size_t(n_rows), dim); - static_assert(std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v, - "Unsupported data type"); - - RAFT_EXPECTS(n_rows > 0 && dim > 0, "empty dataset"); - RAFT_EXPECTS(n_rows >= params.n_lists, "number of rows can't be less than n_lists"); - - auto stream = resource::get_cuda_stream(handle); - - index index(handle, params, dim); - utils::memzero( - index.accum_sorted_sizes().data_handle(), index.accum_sorted_sizes().size(), stream); - utils::memzero(index.list_sizes().data_handle(), index.list_sizes().size(), stream); - utils::memzero(index.data_ptrs().data_handle(), index.data_ptrs().size(), stream); - utils::memzero(index.inds_ptrs().data_handle(), index.inds_ptrs().size(), stream); - - { - auto trainset_ratio = std::max( - 1, - size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); - size_t n_rows_train = n_rows / trainset_ratio; - - auto* device_memory = resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; - - // Besides just sampling, we transform the input dataset into floats to make it easier - // to use gemm operations from cublas. - rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); - // TODO: a proper sampling - if constexpr (std::is_same_v) { - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), - sizeof(T) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - } else { - size_t dim = index.dim(); - cudaPointerAttributes dataset_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&dataset_attr, dataset)); - if (dataset_attr.devicePointer != nullptr) { - // data is available on device: just run the kernel to copy and map the data - auto p = reinterpret_cast(dataset_attr.devicePointer); - auto trainset_view = - raft::make_device_vector_view(trainset.data(), dim * n_rows_train); - linalg::map_offset(handle, trainset_view, [p, trainset_ratio, dim] __device__(size_t i) { - auto col = i % dim; - return utils::mapping{}(p[(i - col) * size_t(trainset_ratio) + col]); - }); - } else { - // data is not available: first copy, then map inplace - auto trainset_tmp = reinterpret_cast(reinterpret_cast(trainset.data()) + - (sizeof(float) - sizeof(T)) * index.dim()); - // We copy the data in strides, one row at a time, and place the smaller rows of type T - // at the end of float rows. - RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset_tmp, - sizeof(float) * index.dim(), - dataset, - sizeof(T) * index.dim() * trainset_ratio, - sizeof(T) * index.dim(), - n_rows_train, - cudaMemcpyDefault, - stream)); - // Transform the input `{T -> float}`, one row per warp. - // The threads in each warp copy the data synchronously; this and the layout of the data - // (content is aligned to the end of the rows) together allow doing the transform in-place. - copy_warped(trainset.data(), - index.dim(), - trainset_tmp, - index.dim() * sizeof(float) / sizeof(T), - index.dim(), - n_rows_train, - stream); - } - } - - // NB: here cluster_centers is used as if it is [n_clusters, data_dim] not [n_clusters, - // dim_ext]! - rmm::device_uvector cluster_centers_buf( - index.n_lists() * index.dim(), stream, device_memory); - auto cluster_centers = cluster_centers_buf.data(); - - // Train balanced hierarchical kmeans clustering - auto trainset_const_view = raft::make_device_matrix_view( - trainset.data(), n_rows_train, index.dim()); - auto centers_view = raft::make_device_matrix_view( - cluster_centers, index.n_lists(), index.dim()); - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = index.metric(); - raft::cluster::kmeans_balanced::fit( - handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); - - // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, device_memory); - auto centers_const_view = raft::make_device_matrix_view( - cluster_centers, index.n_lists(), index.dim()); - auto labels_view = - raft::make_device_vector_view(labels.data(), n_rows_train); - raft::cluster::kmeans_balanced::predict(handle, - kmeans_params, - trainset_const_view, - centers_const_view, - labels_view, - utils::mapping()); - - // Make rotation matrix - make_rotation_matrix(handle, - params.force_random_rotation, - index.rot_dim(), - index.dim(), - index.rotation_matrix().data_handle()); - - set_centers(handle, &index, cluster_centers); - - // Train PQ codebooks - switch (index.codebook_kind()) { - case codebook_gen::PER_SUBSPACE: - train_per_subset(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory_upstream); - break; - case codebook_gen::PER_CLUSTER: - train_per_cluster(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory_upstream); - break; - default: RAFT_FAIL("Unreachable code"); - } - } - - // add the data if necessary - if (params.add_data_on_build) { - detail::extend(handle, &index, dataset, nullptr, n_rows); - } - return index; -} -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh deleted file mode 100644 index bd03409f66..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_codepacking.cuh +++ /dev/null @@ -1,213 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -/** A chunk of PQ-encoded vector managed by one CUDA thread. */ -using pq_vec_t = TxN_t::io_t; - -/** - * This type mimics the `uint8_t&` for the indexing operator of `bitfield_view_t`. - * - * @tparam Bits number of bits comprising the value. - */ -template -struct bitfield_ref_t { - static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); - constexpr static uint8_t kMask = static_cast((1u << Bits) - 1u); - uint8_t* ptr; - uint32_t offset; - - constexpr operator uint8_t() // NOLINT - { - auto pair = static_cast(ptr[0]); - if (offset + Bits > 8) { pair |= static_cast(ptr[1]) << 8; } - return static_cast((pair >> offset) & kMask); - } - - constexpr auto operator=(uint8_t code) -> bitfield_ref_t& - { - if (offset + Bits > 8) { - auto pair = static_cast(ptr[0]); - pair |= static_cast(ptr[1]) << 8; - pair &= ~(static_cast(kMask) << offset); - pair |= static_cast(code) << offset; - ptr[0] = static_cast(Pow2<256>::mod(pair)); - ptr[1] = static_cast(Pow2<256>::div(pair)); - } else { - ptr[0] = (ptr[0] & ~(kMask << offset)) | (code << offset); - } - return *this; - } -}; - -/** - * View a byte array as an array of unsigned integers of custom small bit size. - * - * @tparam Bits number of bits comprising a single element of the array. - */ -template -struct bitfield_view_t { - static_assert(Bits <= 8 && Bits > 0, "Bit code must fit one byte"); - uint8_t* raw; - - constexpr auto operator[](uint32_t i) -> bitfield_ref_t - { - uint32_t bit_offset = i * Bits; - return bitfield_ref_t{raw + Pow2<8>::div(bit_offset), Pow2<8>::mod(bit_offset)}; - } -}; - -/** - * Process a single vector in a list. - * - * @tparam PqBits - * @tparam Action tells how to process a single vector (e.g. reconstruct or just unpack) - * - * @param[in] in_list_data the encoded cluster data. - * @param[in] in_ix in-cluster index of the vector to be decoded (one-per-thread). - * @param[in] out_ix the output index passed to the action - * @param[in] pq_dim - * @param action a callable action to be invoked on each PQ code (component of the encoding) - * type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim). - */ -template -__device__ void run_on_vector( - device_mdspan::list_extents, row_major> in_list_data, - uint32_t in_ix, - uint32_t out_ix, - uint32_t pq_dim, - Action action) -{ - using group_align = Pow2; - const uint32_t group_ix = group_align::div(in_ix); - const uint32_t ingroup_ix = group_align::mod(in_ix); - - pq_vec_t code_chunk; - bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - for (uint32_t j = 0, i = 0; j < pq_dim; i++) { - // read the chunk - code_chunk = *reinterpret_cast(&in_list_data(group_ix, i, ingroup_ix, 0)); - // read the codes, one/pq_dim at a time -#pragma unroll - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - // read a piece of the reconstructed vector - action(code_view[k], out_ix, j); - } - } -} - -/** - * Process a single vector in a list. - * - * @tparam PqBits - * @tparam SubWarpSize how many threads work on the same ix (only the first thread writes data). - * @tparam IdxT type of the index passed to the action - * @tparam Action tells how to process a single vector (e.g. encode or just pack) - * - * @param[in] out_list_data the encoded cluster data. - * @param[in] out_ix in-cluster index of the vector to be processed (one-per-SubWarpSize threads). - * @param[in] in_ix the input index passed to the action (one-per-SubWarpSize threads). - * @param[in] pq_dim - * @param action a callable action to be invoked on each PQ code (component of the encoding) - * type: (uint32_t in_ix, uint32_t j) -> uint8_t, where j = [0..pq_dim). - */ -template -__device__ void write_vector( - device_mdspan::list_extents, row_major> out_list_data, - uint32_t out_ix, - IdxT in_ix, - uint32_t pq_dim, - Action action) -{ - const uint32_t lane_id = Pow2::mod(threadIdx.x); - - using group_align = Pow2; - const uint32_t group_ix = group_align::div(out_ix); - const uint32_t ingroup_ix = group_align::mod(out_ix); - - pq_vec_t code_chunk; - bitfield_view_t code_view{reinterpret_cast(&code_chunk)}; - constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits; - for (uint32_t j = 0, i = 0; j < pq_dim; i++) { - // clear the chunk - if (lane_id == 0) { code_chunk = pq_vec_t{}; } - // write the codes, one/pq_dim at a time -#pragma unroll - for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) { - // write a single code - uint8_t code = action(in_ix, j); - if (lane_id == 0) { code_view[k] = code; } - } - // write the chunk to the list - if (lane_id == 0) { - *reinterpret_cast(&out_list_data(group_ix, i, ingroup_ix, 0)) = code_chunk; - } - } -} - -/** Process the given indices or a block of a single list (cluster). */ -template -__device__ void run_on_list( - device_mdspan::list_extents, row_major> in_list_data, - std::variant offset_or_indices, - uint32_t len, - uint32_t pq_dim, - Action action) -{ - for (uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; ix < len; ix += blockDim.x) { - const uint32_t src_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + ix - : std::get(offset_or_indices)[ix]; - run_on_vector(in_list_data, src_ix, ix, pq_dim, action); - } -} - -/** Process the given indices or a block of a single list (cluster). */ -template -__device__ void write_list( - device_mdspan::list_extents, row_major> out_list_data, - std::variant offset_or_indices, - uint32_t len, - uint32_t pq_dim, - Action action) -{ - using subwarp_align = Pow2; - uint32_t stride = subwarp_align::div(blockDim.x); - uint32_t ix = subwarp_align::div(threadIdx.x + blockDim.x * blockIdx.x); - for (; ix < len; ix += stride) { - const uint32_t dst_ix = std::holds_alternative(offset_or_indices) - ? std::get(offset_or_indices) + ix - : std::get(offset_or_indices)[ix]; - write_vector(out_list_data, dst_ix, ix, pq_dim, action); - } -} - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh deleted file mode 100644 index 5e1a9b46d6..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ /dev/null @@ -1,220 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // RAFT_WEAK_FUNCTION -#include // raft::distance::DistanceType -#include // raft::neighbors::ivf_pq::detail::fp_8bit -#include // raft::neighbors::ivf_pq::codebook_gen -#include // none_ivf_sample_filter -#include // RAFT_EXPLICIT - -#include // rmm::cuda_stream_view - -#include // __half - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_pq::detail { - -// is_local_topk_feasible is not inline here, because we would have to define it -// here as well. That would run the risk of the definitions here and in the -// -inl.cuh header diverging. -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) - -> bool; - -template -RAFT_KERNEL compute_similarity_kernel(uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -template -struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; -}; - -template -void compute_similarity_run(selected s, - rmm::cuda_stream_view stream, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) RAFT_EXPLICIT; - -/** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ -template -auto compute_similarity_select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) - -> selected RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_pq::detail - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - extern template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - extern template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - half, - half, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - half, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - float, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); -instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( - float, - raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>, - raft::neighbors::filtering::ivf_to_sample_filter< - int64_t COMMA raft::neighbors::filtering::none_ivf_sample_filter>); - -#undef COMMA - -#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh deleted file mode 100644 index 462d134b8e..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh +++ /dev/null @@ -1,942 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::distance::DistanceType -#include // matrix::detail::select::warpsort::warp_sort_distributed -#include // dummy_block_sort_t -#include // codebook_gen -#include // none_ivf_sample_filter -#include // RAFT_CUDA_TRY -#include // raft::atomicMin -#include // raft::Pow2 -#include // raft::TxN_t - -#include // rmm::cuda_stream_view - -namespace raft::neighbors::ivf_pq::detail { - -/** - * Maximum value of k for the fused calculate & select in ivfpq. - * - * If runtime value of k is larger than this, the main search operation - * is split into two kernels (per batch, first calculate distance, then select top-k). - */ -static constexpr int kMaxCapacity = 128; -static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), - "kMaxCapacity must be a power of two, not smaller than the WarpSize."); - -// using weak attribute here, because it may be compiled multiple times. -auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) - -> bool -{ - if (k > kMaxCapacity) { return false; } // warp_sort not possible - if (n_queries * n_probes <= 16) { return false; } // overall amount of work is too small - return true; -} - -template -struct pq_block_sort { - using type = matrix::detail::select::warpsort::block_sort< - matrix::detail::select::warpsort::warp_sort_distributed_ext, - Capacity, - true, - T, - IdxT>; - - static auto get_mem_required(uint32_t k_max) - { - if (k_max == 0 || k_max > Capacity) { - return pq_block_sort<0, T, IdxT>::get_mem_required(k_max); - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { - return pq_block_sort<(Capacity / 2), T, IdxT>::get_mem_required(k_max); - } - } - return type::queue_t::mem_required; - } -}; - -template -struct pq_block_sort<0, T, IdxT> : ivf::detail::dummy_block_sort_t { - using type = ivf::detail::dummy_block_sort_t; - static auto mem_required(uint32_t) -> size_t { return 0; } - static auto get_mem_required(uint32_t) { return mem_required; } -}; - -template -using block_sort_t = typename pq_block_sort::type; - -/** - * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` - * (which does not take into account `reservedSharedMemPerBlock`), - * given by a desired schmem-L1 split and a per-block memory requirement in bytes. - * - * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's - * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the - * carveout to zero, it will choose a non-zero config that will allow to run at least one active - * block per SM. - * - * @param shmem_fraction - * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. - * @param shmem_per_block - * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. - * @param dev_props - * device properties. - * @return - * a carveout value in percents [0, 100]. - */ -constexpr inline auto estimate_carveout(double shmem_fraction, - size_t shmem_per_block, - const cudaDeviceProp& dev_props) -> int -{ - using shmem_unit = Pow2<128>; - size_t m = shmem_unit::roundUp(shmem_per_block); - size_t r = dev_props.reservedSharedMemPerBlock; - size_t s = dev_props.sharedMemPerMultiprocessor; - return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); -} - -/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ -template -__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, - typename VecT::math_t& pq_code, - const VecT& pq_codes, - const LutT*& lut_head, - const LutT*& lut_end) -{ - if constexpr (CheckBounds) { - if (lut_head >= lut_end) { return; } - } - constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); - constexpr uint32_t kPqShift = 1u << PqBits; - constexpr uint32_t kPqMask = kPqShift - 1u; - if constexpr (BitsLeft >= PqBits) { - uint8_t code = pq_code & kPqMask; - pq_code >>= PqBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } else if constexpr (Ix < VecT::Ratio) { - uint8_t code = pq_code; - pq_code = pq_codes.val.data[Ix]; - constexpr uint32_t kRemBits = PqBits - BitsLeft; - constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; - code |= (pq_code & kRemMask) << BitsLeft; - pq_code >>= kRemBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); - } -} - -/* Compute the similarity for one vector in the pq_dataset */ -template -__device__ auto ivfpq_compute_score(uint32_t pq_dim, - const typename VecT::io_t* pq_head, - const LutT* lut_scores, - OutT early_stop_limit) -> OutT -{ - constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; - auto lut_head = lut_scores; - auto lut_end = lut_scores + (pq_dim << PqBits); - VecT pq_codes; - OutT score{0}; - for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { - *pq_codes.vectorized_data() = *pq_head; - pq_head += kIndexGroupSize; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). - if (score >= early_stop_limit) { return score; } - } - if (pq_dim > 0) { - *pq_codes.vectorized_data() = *pq_head; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } - return score; -} - -/** - * The main kernel that computes similarity scores across multiple queries and probes. - * When `Capacity > 0`, it also selects top K candidates for each query and probe - * (which need to be merged across probes afterwards). - * - * Each block processes a (query, probe) pair: it calculates the distance between the single query - * vector and all the dataset vector in the cluster that we are probing. - * - * @tparam OutT - * The output type - distances. - * @tparam LutT - * The lookup table element type (lut_scores). - * @tparam PqBits - * The bit length of an encoded vector element after compression by PQ - * (NB: pq_book_size = 1 << PqBits). - * @tparam Capacity - * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. - * @tparam PrecompBaseDiff - * Defines whether we should precompute part of the distance and keep it in shared memory - * before the main part (score calculation) to increase memory usage efficiency in the latter. - * For L2, this is the distance between the query and the cluster center. - * @tparam EnableSMemLut - * Defines whether to use the shared memory for the lookup table (`lut_scores`). - * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) - * at the cost of reducing global memory reading throughput. - * - * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). - * @param n_probes the number of clusters to search for each query - * @param pq_dim - * The dimensionality of an encoded vector after compression by PQ. - * @param n_queries the number of queries. - * @param queries_offset - * An offset of the current query batch. It is used for feeding sample_filter with the - * correct query index. - * @param metric the distance type. - * @param codebook_kind Defines the way PQ codebooks have been trained. - * @param topk the `k` in the select top-k. - * @param max_samples the size of the output for a single query. - * @param cluster_centers - * The device pointer to the cluster centers in the original space (NB: after rotation) - * [n_clusters, dim]. - * @param pq_centers - * The device pointer to the cluster centers in the PQ space - * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len]. - * @param pq_dataset - * The device pointer to the PQ index (data) [n_rows, ...]. - * @param cluster_labels - * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. - * @param _chunk_indices - * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. - * @param queries - * The device pointer to the queries (NB: after rotation) [n_queries, dim]. - * @param index_list - * An optional device pointer to the enforced order of search [n_queries, n_probes]. - * One can pass reordered indices here to try to improve data reading locality. - * @param query_kth - * query_kths keep the current state of the filtering - atomically updated distances to the - * k-th closest neighbors for each query [n_queries]. - * @param sample_filter - * A filter that selects samples for a given query. Use an instance of none_ivf_sample_filter to - * provide a green light for every sample. - * @param lut_scores - * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. - * Ignored when `EnableSMemLut == true`. - * @param _out_scores - * The device pointer to the output scores - * [n_queries, max_samples] or [n_queries, n_probes, topk]. - * @param _out_indices - * The device pointer to the output indices [n_queries, n_probes, topk]. - * These are the indices of the records as they appear in the database view formed by the probed - * clusters / defined by the `_chunk_indices`. - * The indices can have values within the range [0, max_samples). - * Ignored when `Capacity == 0`. - */ -template -RAFT_KERNEL compute_similarity_kernel(uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - /* Shared memory: - - * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) - * lut_end+: - * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 - * topk::warp_sort::mem_required - local topk temporary buffer (if necessary) - * topk::block_sort: some amount of shared memory, but overlaps with the rest: - block_sort only needs shared memory for `.done()` operation, which can come very last. - */ - extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT - constexpr bool kManageLocalTopK = Capacity > 0; - - constexpr uint32_t PqShift = 1u << PqBits; // NOLINT - constexpr uint32_t PqMask = PqShift - 1u; // NOLINT - - const uint32_t pq_len = dim / pq_dim; - const uint32_t lut_size = pq_dim * PqShift; - - if constexpr (EnableSMemLut) { - lut_scores = reinterpret_cast(smem_buf); - } else { - lut_scores += lut_size * blockIdx.x; - } - - uint8_t* lut_end = nullptr; - if constexpr (EnableSMemLut) { - lut_end = reinterpret_cast(lut_scores + lut_size); - } else { - lut_end = smem_buf; - } - - for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { - if (ib >= gridDim.x) { - // sync shared memory accesses on the second and further iterations - __syncthreads(); - } - uint32_t query_ix; - uint32_t probe_ix; - if (index_list == nullptr) { - query_ix = ib % n_queries; - probe_ix = ib / n_queries; - } else { - auto ordered_ix = index_list[ib]; - query_ix = ordered_ix / n_probes; - probe_ix = ordered_ix % n_probes; - } - - const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); - const float* query = queries + (dim * query_ix); - OutT* out_scores; - uint32_t* out_indices = nullptr; - if constexpr (kManageLocalTopK) { - // Store topk calculated distances to out_scores (and its indices to out_indices) - const uint64_t out_offset = probe_ix + n_probes * query_ix; - out_scores = _out_scores + out_offset * topk; - out_indices = _out_indices + out_offset * topk; - } else { - // Store all calculated distances to out_scores - out_scores = _out_scores + uint64_t(max_samples) * query_ix; - } - uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; - const float* cluster_center = cluster_centers + dim * label; - const float* pq_center; - if (codebook_kind == codebook_gen::PER_SUBSPACE) { - pq_center = pq_centers; - } else { - pq_center = pq_centers + (pq_len << PqBits) * label; - } - - if constexpr (PrecompBaseDiff) { - // Reduce number of memory reads later by pre-computing parts of the score - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - reinterpret_cast(lut_end)[i] = query[i] - cluster_center[i]; - } - } break; - case distance::DistanceType::InnerProduct: { - float2 pvals; - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - pvals.x = query[i]; - pvals.y = cluster_center[i] * pvals.x; - reinterpret_cast(lut_end)[i] = pvals; - } - } break; - default: __builtin_unreachable(); - } - __syncthreads(); - } - - { - // Create a lookup table - // For each subspace, the lookup table stores the distance between the actual query vector - // (projected into the subspace) and all possible pq vectors in that subspace. - for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { - const uint32_t i_pq = i >> PqBits; - uint32_t j = i_pq * pq_len; - const uint32_t j_end = pq_len + j; - auto cur_pq_center = pq_center + (i & PqMask) + - (codebook_kind == codebook_gen::PER_SUBSPACE ? j * PqShift : 0u); - float score = 0.0; - do { - float pq_c = *cur_pq_center; - cur_pq_center += PqShift; - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - float diff; - if constexpr (PrecompBaseDiff) { - diff = reinterpret_cast(lut_end)[j]; - } else { - diff = query[j] - cluster_center[j]; - } - diff -= pq_c; - score += diff * diff; - } break; - case distance::DistanceType::InnerProduct: { - // NB: we negate the scores as we hardcoded select-topk to always compute the minimum - float q; - if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(lut_end)[j]; - q = pvals.x; - score -= pvals.y; - } else { - q = query[j]; - score -= q * cluster_center[j]; - } - score -= q * pq_c; - } break; - default: __builtin_unreachable(); - } - } while (++j < j_end); - lut_scores[i] = LutT(score); - } - } - - // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved - // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved - // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest - // possible access by thread warps. - // - // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. - // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. - // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record - // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, - // so that the warp can achieve the best coalesced read throughput. - using group_align = Pow2; - using vec_align = Pow2; - using local_topk_t = block_sort_t; - using op_t = uint32_t; - using vec_t = TxN_t; - - uint32_t sample_offset = 0; - if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } - uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; - uint32_t n_samples_aligned = group_align::roundUp(n_samples); - constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; - uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; - auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + - group_align::mod(threadIdx.x) * vec_align::Value; - pq_line_width *= blockDim.x; - - constexpr OutT kDummy = upper_bound(); - OutT query_kth = kDummy; - if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } - OutT early_stop_limit = kDummy; - switch (metric) { - // If the metric is non-negative, we can use the query_kth approximation as an early stop - // threshold to skip some iterations when computing the score. Add such metrics here. - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - early_stop_limit = query_kth; - } break; - default: break; - } - - // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score - __threadfence_block(); - __syncthreads(); - local_topk_t block_topk(topk, lut_end, query_kth); - - // Compute a distance for each sample - for (uint32_t i = threadIdx.x; i < n_samples_aligned; - i += blockDim.x, pq_thread_data += pq_line_width) { - OutT score = kDummy; - bool valid = i < n_samples; - // Check bounds and that the sample is acceptable for the query - if (valid && sample_filter(queries_offset + query_ix, label, i)) { - score = ivfpq_compute_score( - pq_dim, - reinterpret_cast(pq_thread_data), - lut_scores, - early_stop_limit); - } - if constexpr (kManageLocalTopK) { - block_topk.add(score, sample_offset + i); - } else { - if (valid) { out_scores[sample_offset + i] = score; } - } - } - if constexpr (kManageLocalTopK) { - // sync threads before the topk merging operation, because we reuse smem_buf - __syncthreads(); - block_topk.done(smem_buf); - block_topk.store(out_scores, out_indices); - if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } - } else { - // fill in the rest of the out_scores with dummy values - if (probe_ix + 1 == n_probes) { - for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; - i += blockDim.x) { - out_scores[i] = kDummy; - } - } - } - } -} - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -// The config struct lifts the runtime parameters to the template parameters -template -struct compute_similarity_kernel_config { - public: - static auto get(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - return kernel_choose_bits(pq_bits, k_max); - } - - private: - static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - switch (pq_bits) { - case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); - case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); - case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); - case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); - case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - } - - template - static auto kernel_try_capacity(uint32_t k_max) - -> compute_similarity_kernel_t - { - if constexpr (Capacity > 0) { - if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } - } - return compute_similarity_kernel; - } -}; - -// A standalone accessor function was necessary to make sure template -// instantiation work correctly. This accessor function is not used anymore and -// may be removed. -template -auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t -{ - return compute_similarity_kernel_config::get(pq_bits, k_max); -} - -/** Estimate the occupancy for the given kernel on the given device. */ -template -struct occupancy_t { - using shmem_unit = Pow2<128>; - - int blocks_per_sm = 0; - double occupancy = 0.0; - double shmem_use = 1.0; - - inline occupancy_t() = default; - inline occupancy_t(size_t smem, - uint32_t n_threads, - compute_similarity_kernel_t kernel, - const cudaDeviceProp& dev_props) - { - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); - occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); - shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / - double(dev_props.sharedMemPerMultiprocessor); - } -}; - -template -struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; -}; - -template -void compute_similarity_run(selected s, - rmm::cuda_stream_view stream, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - uint32_t queries_offset, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - IvfSampleFilterT sample_filter, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - s.kernel<<>>(dim, - n_probes, - pq_dim, - n_queries, - queries_offset, - metric, - codebook_kind, - topk, - max_samples, - cluster_centers, - pq_centers, - pq_dataset, - cluster_labels, - _chunk_indices, - queries, - index_list, - query_kths, - sample_filter, - lut_scores, - _out_scores, - _out_indices); - RAFT_CHECK_CUDA(stream); -} - -/** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ -template -auto compute_similarity_select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -> selected -{ - // Shared memory for storing the lookup table - size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); - // Shared memory for storing pre-computed pieces to speedup the lookup table construction - // (e.g. the distance between a cluster center and the query for L2). - size_t bdf_mem = sizeof(float) * precomp_data_count; - - // Shared memory used by the fused top-k during cluster scanning; - // may overlap with the precomputed distance array - struct ltk_add_mem_t { - size_t (*mem_required)(uint32_t); - - ltk_add_mem_t(bool manage_local_topk, uint32_t topk) - : mem_required(pq_block_sort::get_mem_required( - manage_local_topk ? topk : 0)) - { - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return mem_required(n_threads); - } - } ltk_add_mem{manage_local_topk, topk}; - - // Shared memory for the fused top-k component; - // may overlap with all other uses of shared memory - struct ltk_reduce_mem_t { - uint32_t subwarp_size; - uint32_t topk; - bool manage_local_topk; - ltk_reduce_mem_t(bool manage_local_topk, uint32_t topk) - : manage_local_topk(manage_local_topk), topk(topk) - { - subwarp_size = WarpSize; - while (topk * 2 <= subwarp_size) { - subwarp_size /= 2; - } - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return manage_local_topk - ? matrix::detail::select::warpsort::template calc_smem_size_for_block_wide( - n_threads / subwarp_size, topk) - : 0; - } - } ltk_reduce_mem{manage_local_topk, topk}; - - struct total_shared_mem_t { - ltk_add_mem_t& ltk_add_mem; - ltk_reduce_mem_t& ltk_reduce_mem; - size_t lut_mem; - size_t bdf_mem; - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return std::max(ltk_reduce_mem(n_threads), - lut_mem + std::max(bdf_mem, ltk_add_mem(n_threads))); - } - }; - - // Total amount of work; should be enough to occupy the GPU. - uint32_t n_blocks = n_queries * n_probes; - - // The minimum block size we may want: - // 1. It's a power-of-two for efficient L1 caching of pq_centers values - // (multiples of `1 << pq_bits`). - // 2. It should be large enough to fully utilize an SM. - uint32_t n_threads_min = WarpSize; - while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < - dev_props.maxThreadsPerMultiProcessor) { - n_threads_min *= 2; - } - // Further increase the minimum block size to make sure full device occupancy - // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) - while (int(n_blocks * n_threads_min) < - dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - // Even further, increase it to allow less blocks per SM if there not enough queries. - // With this, we reduce the chance of different clusters being processed by two blocks - // on the same SM and thus improve the data locality for L1 caching. - while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - - // Granularity of changing the number of threads when computing the maximum block size. - // It's good to have it multiple of the PQ book width. - uint32_t n_threads_gty = round_up_safe(1u << pq_bits, WarpSize); - - /* - Shared memory / L1 cache balance is the main limiter of this kernel. - The more blocks per SM we launch, the more shared memory we need. Besides that, we have - three versions of the kernel varying in performance and shmem usage. - - We try the most demanding and the fastest kernel first, trying to maximize occupancy with - the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further - optimize occupancy and data locality for the L1 cache. - */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; - auto topk_or_zero = manage_local_topk ? topk : 0u; - std::array candidates{ - std::make_tuple(conf_fast(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, bdf_mem}, - true), - std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, lut_mem, 0}, - true), - std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), - total_shared_mem_t{ltk_add_mem, ltk_reduce_mem, 0, bdf_mem}, - false)}; - - // we may allow slightly lower than 100% occupancy; - constexpr double kTargetOccupancy = 0.75; - // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; - for (auto [kernel, smem_size_f, lut_is_in_shmem] : candidates) { - if (smem_size_f(WarpSize) > dev_props.sharedMemPerBlockOptin) { - // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. - continue; - } - - // First, we set the carveout hint to the preferred value. The driver will increase this if - // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, - // this carveout value will limit the calculated occupancy. When we're done selecting the best - // launch configuration, we will tighten the carveout once more, based on the final memory - // usage and occupancy. - const int max_carveout = - estimate_carveout(preferred_shmem_carveout, smem_size_f(WarpSize), dev_props); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); - - // Get the theoretical maximum possible number of threads per block - cudaFuncAttributes kernel_attrs; - RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); - uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); - - // Actual required shmem depens on the number of threads - size_t smem_size = smem_size_f(n_threads); - - // Make sure the kernel can get enough shmem. - cudaError_t cuda_status = - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (cuda_status != cudaSuccess) { - RAFT_EXPECTS( - cuda_status == cudaGetLastError(), - "Tried to reset the expected cuda error code, but it didn't match the expectation"); - // Failed to request enough shmem for the kernel. Skip the candidate. - continue; - } - - occupancy_t cur(smem_size, n_threads, kernel, dev_props); - if (cur.blocks_per_sm <= 0) { - // For some reason, we still cannot make this kernel run. Skip the candidate. - continue; - } - - { - // Try to reduce the number of threads to increase occupancy and data locality - auto n_threads_tmp = n_threads_min; - while (n_threads_tmp * 2 < n_threads) { - n_threads_tmp *= 2; - } - if (n_threads_tmp < n_threads) { - while (n_threads_tmp >= n_threads_min) { - auto smem_size_tmp = smem_size_f(n_threads_tmp); - occupancy_t tmp( - smem_size_tmp, n_threads_tmp, kernel, dev_props); - bool select_it = false; - if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { - // Normally, the smaller the block the better for L1 cache hit rate. - // Hence, the occupancy should be "just good enough" - select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); - } else if (lut_is_in_shmem) { - // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), - // the locality is not going to improve with increasing the number of blocks per SM. - // Hence, the only metric here is the occupancy. - bool improves_occupancy = tmp.occupancy > cur.occupancy; - // Otherwise, the performance still improves with a smaller block size, - // given there is enough work to do - bool improves_parallelism = - tmp.occupancy == cur.occupancy && - 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; - select_it = improves_occupancy || improves_parallelism; - } else { - // If we don't use shared memory for the lookup table, increasing the number of blocks - // is very taxing on the global memory usage. - // In this case, the occupancy must increase a lot to make it worth the cost. - select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); - } - if (select_it) { - n_threads = n_threads_tmp; - smem_size = smem_size_tmp; - cur = tmp; - } - n_threads_tmp /= 2; - } - } - } - - { - if (selected_perf.occupancy <= 0.0 // no candidate yet - || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && - selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy - ) { - selected_perf = cur; - if (lut_is_in_shmem) { - selected_config = { - kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; - } else { - // When the global memory is used for the lookup table, we need to minimize the grid - // size; otherwise, the kernel may quickly run out of memory. - auto n_blocks_min = - std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); - selected_config = {kernel, - dim3(n_blocks_min, 1, 1), - dim3(n_threads, 1, 1), - smem_size, - size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; - } - // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here - // a rather conservative bar; most likely, the kernel gets more shared memory than this, - // and the occupancy doesn't get hurt. - auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); - if (cur.occupancy >= kTargetOccupancy) { break; } - } else if (selected_perf.occupancy > 0.0) { - // If we found a reasonable candidate on a previous iteration, and this one is not better, - // then don't try any more candidates because they are much slower anyway. - break; - } - } - } - - RAFT_EXPECTS(selected_perf.occupancy > 0.0, - "Couldn't determine a working kernel launch configuration."); - - return selected_config; -} - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh deleted file mode 100644 index d987c0d4ed..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh +++ /dev/null @@ -1,25 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) -#include "ivf_pq_compute_similarity-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_pq_compute_similarity-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh deleted file mode 100644 index 83dd994bd6..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity_template.cuh +++ /dev/null @@ -1,71 +0,0 @@ - -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -/* - * NOTE: this file is to be used in source files generated by - * src/neighbors/detailivf_pq_compute_similarity_00_generate.py - */ - -#pragma once - -#include -#include -#include - -#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( \ - OutT, LutT, IvfSampleFilterT) \ - template auto \ - raft::neighbors::ivf_pq::detail::compute_similarity_select( \ - const cudaDeviceProp& dev_props, \ - bool manage_local_topk, \ - int locality_hint, \ - double preferred_shmem_carveout, \ - uint32_t pq_bits, \ - uint32_t pq_dim, \ - uint32_t precomp_data_count, \ - uint32_t n_queries, \ - uint32_t n_probes, \ - uint32_t topk) \ - ->raft::neighbors::ivf_pq::detail::selected; \ - \ - template void \ - raft::neighbors::ivf_pq::detail::compute_similarity_run( \ - raft::neighbors::ivf_pq::detail::selected s, \ - rmm::cuda_stream_view stream, \ - uint32_t dim, \ - uint32_t n_probes, \ - uint32_t pq_dim, \ - uint32_t n_queries, \ - uint32_t queries_offset, \ - raft::distance::DistanceType metric, \ - raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ - uint32_t topk, \ - uint32_t max_samples, \ - const float* cluster_centers, \ - const float* pq_centers, \ - const uint8_t* const* pq_dataset, \ - const uint32_t* cluster_labels, \ - const uint32_t* _chunk_indices, \ - const float* queries, \ - const uint32_t* index_list, \ - float* query_kths, \ - IvfSampleFilterT sample_filter, \ - LutT* lut_scores, \ - OutT* _out_scores, \ - uint32_t* _out_indices); - -#define COMMA , diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh deleted file mode 100644 index 1dea998f9b..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -/** 8-bit floating-point storage type. - * - * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined - * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. - */ -template -struct fp_8bit { - static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); - constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT - constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT - - public: - uint8_t bitstring; - - HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} - HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} - HDI auto operator=(float fp) -> fp_8bit& - { - bitstring = float2fp_8bit(fp).bitstring; - return *this; - } - HDI explicit operator float() const { return fp_8bit2float(*this); } - HDI explicit operator half() const { return fp_8bit2half(*this); } - - private: - static constexpr float kMin = 1.0f / float(1u << ExpMask); - static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); - - static HDI auto float2fp_8bit(float v) -> fp_8bit - { - if constexpr (Signed) { - auto u = fp_8bit(std::abs(v)).bitstring; - u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit - return fp_8bit(u); - } else { - // sic! all small and negative numbers are truncated to zero. - if (v < kMin) { return fp_8bit{static_cast(0)}; } - // protect from overflow - if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } - // the rest of possible float values should be within the normalized range - return fp_8bit{static_cast( - (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; - } - } - - static HDI auto fp_8bit2float(const fp_8bit& v) -> float - { - uint32_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - float r; - constexpr uint32_t kBase32 = (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23); - *reinterpret_cast(&r) = kBase32 + (u << (15u + ExpBits)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } - - static HDI auto fp_8bit2half(const fp_8bit& v) -> half - { - uint16_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - half r; - constexpr uint16_t kBase16 = (0x3c00u | (0x0200u >> ValBits)) - (ExpMask << 10); - *reinterpret_cast(&r) = kBase16 + (u << (2u + ExpBits)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } -}; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh deleted file mode 100644 index 87e6d0a774..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ /dev/null @@ -1,718 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -using namespace raft::spatial::knn::detail; // NOLINT - -/** - * Select the clusters to probe and, as a side-effect, translate the queries type `T -> float` - * - * Assuming the number of clusters is not that big (a few thousands), we do a plain GEMM - * followed by select_k to select the clusters to probe. There's no need to return the similarity - * scores here. - */ -template -void select_clusters(raft::resources const& handle, - uint32_t* clusters_to_probe, // [n_queries, n_probes] - float* float_queries, // [n_queries, dim_ext] - uint32_t n_queries, - uint32_t n_probes, - uint32_t n_lists, - uint32_t dim, - uint32_t dim_ext, - raft::distance::DistanceType metric, - const T* queries, // [n_queries, dim] - const float* cluster_centers, // [n_lists, dim_ext] - rmm::device_async_resource_ref mr) -{ - common::nvtx::range fun_scope( - "ivf_pq::search::select_clusters(n_probes = %u, n_queries = %u, n_lists = %u, dim = %u)", - n_probes, - n_queries, - n_lists, - dim); - auto stream = resource::get_cuda_stream(handle); - /* NOTE[qc_distances] - - We compute query-center distances to choose the clusters to probe. - We accomplish that with just one GEMM operation thanks to some preprocessing: - - L2 distance: - cluster_centers[i, dim()] contains the squared norm of the center vector i; - we extend the dimension K of the GEMM to compute it together with all the dot products: - - `qc_distances[i, j] = |cluster_centers[j]|^2 - 2 * (queries[i], cluster_centers[j])` - - This is a monotonous mapping of the proper L2 distance. - - IP distance: - `qc_distances[i, j] = - (queries[i], cluster_centers[j])` - - This is a negative inner-product distance. We minimize it to find the similar clusters. - - NB: qc_distances is NOT used further in ivfpq_search. - */ - float norm_factor; - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: norm_factor = 1.0 / -2.0; break; - case raft::distance::DistanceType::InnerProduct: norm_factor = 0.0; break; - default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); - } - auto float_queries_view = - raft::make_device_vector_view(float_queries, dim_ext * n_queries); - linalg::map_offset( - handle, float_queries_view, [queries, dim, dim_ext, norm_factor] __device__(uint32_t ix) { - uint32_t col = ix % dim_ext; - uint32_t row = ix / dim_ext; - return col < dim ? utils::mapping{}(queries[col + dim * row]) : norm_factor; - }); - - float alpha; - float beta; - uint32_t gemm_k = dim; - switch (metric) { - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Expanded: { - alpha = -2.0; - beta = 0.0; - gemm_k = dim + 1; - RAFT_EXPECTS(gemm_k <= dim_ext, "unexpected gemm_k or dim_ext"); - } break; - case raft::distance::DistanceType::InnerProduct: { - alpha = -1.0; - beta = 0.0; - } break; - default: RAFT_FAIL("Unsupported distance type %d.", int(metric)); - } - rmm::device_uvector qc_distances(n_queries * n_lists, stream, mr); - linalg::gemm(handle, - true, - false, - n_lists, - n_queries, - gemm_k, - &alpha, - cluster_centers, - dim_ext, - float_queries, - dim_ext, - &beta, - qc_distances.data(), - n_lists, - stream); - - // Select neighbor clusters for each query. - rmm::device_uvector cluster_dists(n_queries * n_probes, stream, mr); - matrix::detail::select_k(handle, - qc_distances.data(), - nullptr, - n_queries, - n_lists, - n_probes, - cluster_dists.data(), - clusters_to_probe, - true); -} - -/** - * An approximation to the number of times each cluster appears in a batched sample. - * - * If the pairs (probe_ix, query_ix) are sorted by the probe_ix, there is a good chance that - * the same probe_ix (cluster) is processed by several blocks on a single SM. This greatly - * increases the L1 cache hit rate (i.e. increases the data locality). - * - * This function gives an estimate of how many times a specific cluster may appear in the - * batch. Thus, it gives a practical limit to how many blocks should be active on the same SM - * to improve the L1 cache hit rate. - */ -constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, - uint32_t n_probes, - uint32_t n_queries) -> uint32_t -{ - /* - Let say: - n = n_clusters - k = n_probes - m = n_queries - r = # of times a specific block appears in the batched sample. - - Then, r has the Binomial distribution (p = k / n): - P(r) = C(m,r) * k^r * (n - k)^(m - r) / n^m - E[r] = m * k / n - E[r | r > 0] = m * k / n / (1 - (1 - k/n)^m) - - The latter can be approximated by a much simpler formula, assuming (k / n) -> 0: - E[r | r > 0] = 1 + (m - 1) * k / (2 * n) + O( (k/n)^2 ) - */ - return 1 + (n_queries - 1) * n_probes / (2 * n_clusters); -} - -struct search_kernel_key { - bool manage_local_topk; - uint32_t locality_hint; - double preferred_shmem_carveout; - uint32_t pq_bits; - uint32_t pq_dim; - uint32_t precomp_data_count; - uint32_t n_queries; - uint32_t n_probes; - uint32_t topk; -}; - -inline auto operator==(const search_kernel_key& a, const search_kernel_key& b) -> bool -{ - return a.manage_local_topk == b.manage_local_topk && a.locality_hint == b.locality_hint && - a.preferred_shmem_carveout == b.preferred_shmem_carveout && a.pq_bits == b.pq_bits && - a.pq_dim == b.pq_dim && a.precomp_data_count == b.precomp_data_count && - a.n_queries == b.n_queries && a.n_probes == b.n_probes && a.topk == b.topk; -} - -struct search_kernel_key_hash { - inline auto operator()(const search_kernel_key& x) const noexcept -> std::size_t - { - return (size_t{x.manage_local_topk} << 63) + - size_t{x.topk} * size_t{x.n_probes} * size_t{x.n_queries} + - size_t{x.precomp_data_count} * size_t{x.pq_dim} * size_t{x.pq_bits}; - } -}; - -template -struct search_kernel_cache { - /** Number of matmul invocations to cache. */ - static constexpr size_t kDefaultSize = 100; - cache::lru, - selected> - value{kDefaultSize}; -}; - -/** - * The "main part" of the search, which assumes that outer-level `search` has already: - * - * 1. computed the closest clusters to probe (`clusters_to_probe`); - * 2. transformed input queries into the rotated space (rot_dim); - * 3. split the query batch into smaller chunks, so that the device workspace - * is guaranteed to fit into GPU memory. - */ -template -void ivfpq_search_worker(raft::resources const& handle, - const index& index, - uint32_t max_samples, - uint32_t n_probes, - uint32_t topK, - uint32_t n_queries, - uint32_t queries_offset, // needed for filtering - const uint32_t* clusters_to_probe, // [n_queries, n_probes] - const float* query, // [n_queries, rot_dim] - IdxT* neighbors, // [n_queries, topK] - float* distances, // [n_queries, topK] - float scaling_factor, - double preferred_shmem_carveout, - IvfSampleFilterT sample_filter) -{ - common::nvtx::range fun_scope( - "ivf_pq::search-worker(n_queries = %u, n_probes = %u, k = %u, dim = %zu)", - n_queries, - n_probes, - topK, - index.dim()); - auto stream = resource::get_cuda_stream(handle); - auto mr = resource::get_workspace_resource(handle); - - bool manage_local_topk = is_local_topk_feasible(topK, n_probes, n_queries); - auto topk_len = manage_local_topk ? n_probes * topK : max_samples; - std::size_t n_queries_probes = std::size_t(n_queries) * std::size_t(n_probes); - std::size_t n_queries_topk_len = std::size_t(n_queries) * std::size_t(topk_len); - if (manage_local_topk) { - RAFT_LOG_DEBUG("Fused version of the search kernel is selected (manage_local_topk == true)"); - } else { - RAFT_LOG_DEBUG( - "Non-fused version of the search kernel is selected (manage_local_topk == false)"); - } - - rmm::device_uvector index_list_sorted_buf(0, stream, mr); - uint32_t* index_list_sorted = nullptr; - rmm::device_uvector num_samples(n_queries, stream, mr); - rmm::device_uvector chunk_index(n_queries_probes, stream, mr); - // [maxBatchSize, max_samples] or [maxBatchSize, n_probes, topk] - rmm::device_uvector distances_buf(n_queries_topk_len, stream, mr); - rmm::device_uvector neighbors_buf(0, stream, mr); - uint32_t* neighbors_ptr = nullptr; - if (manage_local_topk) { - neighbors_buf.resize(n_queries_topk_len, stream); - neighbors_ptr = neighbors_buf.data(); - } - rmm::device_uvector neighbors_uint32_buf(0, stream, mr); - uint32_t* neighbors_uint32 = nullptr; - if constexpr (sizeof(IdxT) == sizeof(uint32_t)) { - neighbors_uint32 = reinterpret_cast(neighbors); - } else { - neighbors_uint32_buf.resize(n_queries * topK, stream); - neighbors_uint32 = neighbors_uint32_buf.data(); - } - - ivf::detail::calc_chunk_indices::configure(n_probes, n_queries)(index.list_sizes().data_handle(), - clusters_to_probe, - chunk_index.data(), - num_samples.data(), - stream); - - auto coresidency = expected_probe_coresidency(index.n_lists(), n_probes, n_queries); - - if (coresidency > 1) { - // Sorting index by cluster number (label). - // The goal is to incrase the L2 cache hit rate to read the vectors - // of a cluster by processing the cluster at the same time as much as - // possible. - index_list_sorted_buf.resize(n_queries_probes, stream); - auto index_list_buf = - make_device_mdarray(handle, mr, make_extents(n_queries_probes)); - rmm::device_uvector cluster_labels_out(n_queries_probes, stream, mr); - auto index_list = index_list_buf.data_handle(); - index_list_sorted = index_list_sorted_buf.data(); - - linalg::map_offset(handle, index_list_buf.view(), identity_op{}); - - int begin_bit = 0; - int end_bit = sizeof(uint32_t) * 8; - size_t cub_workspace_size = 0; - cub::DeviceRadixSort::SortPairs(nullptr, - cub_workspace_size, - clusters_to_probe, - cluster_labels_out.data(), - index_list, - index_list_sorted, - n_queries_probes, - begin_bit, - end_bit, - stream); - rmm::device_buffer cub_workspace(cub_workspace_size, stream, mr); - cub::DeviceRadixSort::SortPairs(cub_workspace.data(), - cub_workspace_size, - clusters_to_probe, - cluster_labels_out.data(), - index_list, - index_list_sorted, - n_queries_probes, - begin_bit, - end_bit, - stream); - } - - // select and run the main search kernel - uint32_t precomp_data_count = 0; - switch (index.metric()) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2SqrtUnexpanded: - case distance::DistanceType::L2Unexpanded: - case distance::DistanceType::L2Expanded: { - // stores basediff (query[i] - center[i]) - precomp_data_count = index.rot_dim(); - } break; - case distance::DistanceType::InnerProduct: { - // stores two components (query[i] * center[i], query[i] * center[i]) - precomp_data_count = index.rot_dim() * 2; - } break; - default: { - RAFT_FAIL("Unsupported metric"); - } break; - } - - selected search_instance; - search_kernel_key search_key{manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK}; - auto& cache = - resource::get_custom_resource>(handle) - ->value; - if (!cache.get(search_key, &search_instance)) { - search_instance = compute_similarity_select( - resource::get_device_properties(handle), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); - cache.set(search_key, search_instance); - } - - rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); - std::optional> query_kths_buf{std::nullopt}; - float* query_kths = nullptr; - if (manage_local_topk) { - query_kths_buf.emplace( - make_device_mdarray(handle, mr, make_extents(n_queries))); - linalg::map( - handle, - query_kths_buf->view(), - raft::const_op{ivf::detail::dummy_block_sort_t::queue_t::kDummy}); - query_kths = query_kths_buf->data_handle(); - } - compute_similarity_run(search_instance, - stream, - index.rot_dim(), - n_probes, - index.pq_dim(), - n_queries, - queries_offset, - index.metric(), - index.codebook_kind(), - topK, - max_samples, - index.centers_rot().data_handle(), - index.pq_centers().data_handle(), - index.data_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - query, - index_list_sorted, - query_kths, - sample_filter, - device_lut.data(), - distances_buf.data(), - neighbors_ptr); - - // Select topk vectors for each query - rmm::device_uvector topk_dists(n_queries * topK, stream, mr); - matrix::detail::select_k(handle, - distances_buf.data(), - neighbors_ptr, - n_queries, - topk_len, - topK, - topk_dists.data(), - neighbors_uint32, - true, - false, - matrix::SelectAlgo::kAuto, - manage_local_topk ? nullptr : num_samples.data()); - - // Postprocessing - ivf::detail::postprocess_distances( - distances, topk_dists.data(), index.metric(), n_queries, topK, scaling_factor, true, stream); - ivf::detail::postprocess_neighbors(neighbors, - neighbors_uint32, - index.inds_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - n_queries, - n_probes, - topK, - stream); -} - -/** - * This structure helps selecting a proper instance of the worker search function, - * which contains a few template parameters. - */ -template -struct ivfpq_search { - public: - using fun_t = decltype(&ivfpq_search_worker); - - /** - * Select an instance of the ivf-pq search function based on search tuning parameters, - * such as the look-up data type or the internal score type. - */ - static auto fun(const search_params& params, distance::DistanceType metric) -> fun_t - { - return fun_try_score_t(params, metric); - } - - private: - template - static auto filter_reasonable_instances(const search_params& params) -> fun_t - { - if constexpr (sizeof(ScoreT) >= sizeof(LutT)) { - return ivfpq_search_worker; - } else { - RAFT_FAIL( - "Unexpected lut_dtype / internal_distance_dtype combination (%d, %d). " - "Size of the internal_distance_dtype should be not smaller than the size of the lut_dtype.", - int(params.lut_dtype), - int(params.internal_distance_dtype)); - } - } - - template - static auto fun_try_lut_t(const search_params& params, distance::DistanceType metric) -> fun_t - { - bool signed_metric = false; - switch (metric) { - case raft::distance::DistanceType::InnerProduct: signed_metric = true; break; - default: break; - } - - switch (params.lut_dtype) { - case CUDA_R_32F: return filter_reasonable_instances(params); - case CUDA_R_16F: return filter_reasonable_instances(params); - case CUDA_R_8U: - case CUDA_R_8I: - if (signed_metric) { - return filter_reasonable_instances>(params); - } else { - return filter_reasonable_instances>(params); - } - default: RAFT_FAIL("Unexpected lut_dtype (%d)", int(params.lut_dtype)); - } - } - - static auto fun_try_score_t(const search_params& params, distance::DistanceType metric) -> fun_t - { - switch (params.internal_distance_dtype) { - case CUDA_R_32F: return fun_try_lut_t(params, metric); - case CUDA_R_16F: return fun_try_lut_t(params, metric); - default: - RAFT_FAIL("Unexpected internal_distance_dtype (%d)", int(params.internal_distance_dtype)); - } - } -}; - -/** - * A heuristic for bounding the number of queries per batch, to improve GPU utilization. - * (based on the number of SMs and the work size). - * - * @param res is used to query the workspace size - * @param k top-k - * @param n_probes number of selected clusters per query - * @param n_queries number of queries hoped to be processed at once. - * (maximum value for the returned batch size) - * @param max_samples maximum possible number of samples to be processed for the given `n_probes` - * - * @return maximum recommended batch size. - */ -inline auto get_max_batch_size(raft::resources const& res, - uint32_t k, - uint32_t n_probes, - uint32_t n_queries, - uint32_t max_samples) -> uint32_t -{ - uint32_t max_batch_size = n_queries; - uint32_t n_ctas_total = resource::get_device_properties(res).multiProcessorCount * 2; - uint32_t n_ctas_total_per_batch = n_ctas_total / max_batch_size; - float utilization = float(n_ctas_total_per_batch * max_batch_size) / n_ctas_total; - if (n_ctas_total_per_batch > 1 || (n_ctas_total_per_batch == 1 && utilization < 0.6)) { - uint32_t n_ctas_total_per_batch_1 = n_ctas_total_per_batch + 1; - uint32_t max_batch_size_1 = n_ctas_total / n_ctas_total_per_batch_1; - float utilization_1 = float(n_ctas_total_per_batch_1 * max_batch_size_1) / n_ctas_total; - if (utilization < utilization_1) { max_batch_size = max_batch_size_1; } - } - // Check in the tmp distance buffer is not too big - auto ws_size = [k, n_probes, max_samples](uint32_t bs) -> uint64_t { - const uint64_t buffers_fused = 12ull * k * n_probes; - const uint64_t buffers_non_fused = 4ull * max_samples; - const uint64_t other = 32ull * n_probes; - return static_cast(bs) * - (other + (is_local_topk_feasible(k, n_probes, bs) ? buffers_fused : buffers_non_fused)); - }; - auto max_ws_size = resource::get_workspace_free_bytes(res); - if (ws_size(max_batch_size) > max_ws_size) { - uint32_t smaller_batch_size = bound_by_power_of_two(max_batch_size); - // gradually reduce the batch size until we fit into the max size limit. - while (smaller_batch_size > 1 && ws_size(smaller_batch_size) > max_ws_size) { - smaller_batch_size >>= 1; - } - return smaller_batch_size; - } - return max_batch_size; -} - -/** See raft::spatial::knn::ivf_pq::search docs */ -template -inline void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - static_assert(std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v, - "Unsupported element type."); - common::nvtx::range fun_scope( - "ivf_pq::search(n_queries = %u, n_probes = %u, k = %u, dim = %zu)", - n_queries, - params.n_probes, - k, - index.dim()); - - RAFT_EXPECTS( - params.internal_distance_dtype == CUDA_R_16F || params.internal_distance_dtype == CUDA_R_32F, - "internal_distance_dtype must be either CUDA_R_16F or CUDA_R_32F"); - RAFT_EXPECTS(params.lut_dtype == CUDA_R_16F || params.lut_dtype == CUDA_R_32F || - params.lut_dtype == CUDA_R_8U, - "lut_dtype must be CUDA_R_16F, CUDA_R_32F or CUDA_R_8U"); - RAFT_EXPECTS(k > 0, "parameter `k` in top-k must be positive."); - RAFT_EXPECTS( - k <= index.size(), - "parameter `k` (%u) in top-k must not be larger that the total size of the index (%zu)", - k, - static_cast(index.size())); - RAFT_EXPECTS(params.n_probes > 0, - "n_probes (number of clusters to probe in the search) must be positive."); - - switch (utils::check_pointer_residency(queries, neighbors, distances)) { - case utils::pointer_residency::device_only: - case utils::pointer_residency::host_and_device: break; - default: RAFT_FAIL("all pointers must be accessible from the device."); - } - - auto stream = resource::get_cuda_stream(handle); - - auto dim = index.dim(); - auto dim_ext = index.dim_ext(); - auto n_probes = std::min(params.n_probes, index.n_lists()); - - uint32_t max_samples = 0; - { - IdxT ms = Pow2<128>::roundUp(index.accum_sorted_sizes()(n_probes)); - RAFT_EXPECTS(ms <= IdxT(std::numeric_limits::max()), - "The maximum sample size is too big."); - max_samples = ms; - } - - auto mr = resource::get_workspace_resource(handle); - - // Maximum number of query vectors to search at the same time. - const auto max_queries = std::min(std::max(n_queries, 1), 4096); - auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); - - rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr); - rmm::device_uvector rot_queries(max_queries * index.rot_dim(), stream, mr); - rmm::device_uvector clusters_to_probe(max_queries * n_probes, stream, mr); - - auto filter_adapter = raft::neighbors::filtering::ivf_to_sample_filter( - index.inds_ptrs().data_handle(), sample_filter); - auto search_instance = ivfpq_search::fun(params, index.metric()); - - for (uint32_t offset_q = 0; offset_q < n_queries; offset_q += max_queries) { - uint32_t queries_batch = min(max_queries, n_queries - offset_q); - common::nvtx::range batch_scope( - "ivf_pq::search-batch(queries: %u - %u)", offset_q, offset_q + queries_batch); - - select_clusters(handle, - clusters_to_probe.data(), - float_queries.data(), - queries_batch, - n_probes, - index.n_lists(), - dim, - dim_ext, - index.metric(), - queries + static_cast(dim) * offset_q, - index.centers().data_handle(), - mr); - - // Rotate queries - float alpha = 1.0; - float beta = 0.0; - linalg::gemm(handle, - true, - false, - index.rot_dim(), - queries_batch, - dim, - &alpha, - index.rotation_matrix().data_handle(), - dim, - float_queries.data(), - dim_ext, - &beta, - rot_queries.data(), - index.rot_dim(), - stream); - - for (uint32_t offset_b = 0; offset_b < queries_batch; offset_b += max_batch_size) { - uint32_t batch_size = min(max_batch_size, queries_batch - offset_b); - /* The distance calculation is done in the rotated/transformed space; - as long as `index.rotation_matrix()` is orthogonal, the distances and thus results are - preserved. - */ - search_instance(handle, - index, - max_samples, - n_probes, - k, - batch_size, - offset_q + offset_b, - clusters_to_probe.data() + uint64_t(n_probes) * offset_b, - rot_queries.data() + uint64_t(index.rot_dim()) * offset_b, - neighbors + uint64_t(k) * (offset_q + offset_b), - distances + uint64_t(k) * (offset_q + offset_b), - utils::config::kDivisor / utils::config::kDivisor, - params.preferred_shmem_carveout, - filter_adapter); - } - } -} - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh deleted file mode 100644 index 4428fa370b..0000000000 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ /dev/null @@ -1,190 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_pq::detail { - -// Serialization version -// No backward compatibility yet; that is, can't add additional fields without breaking -// backward compatibility. -// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward -// compatible fashion. -constexpr int kSerializationVersion = 3; - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index IVF-PQ index - * - */ -template -void serialize(raft::resources const& handle_, std::ostream& os, const index& index) -{ - RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", - static_cast(index.size()), - static_cast(index.dim()), - static_cast(index.pq_dim()), - static_cast(index.pq_bits())); - - serialize_scalar(handle_, os, kSerializationVersion); - serialize_scalar(handle_, os, index.size()); - serialize_scalar(handle_, os, index.dim()); - serialize_scalar(handle_, os, index.pq_bits()); - serialize_scalar(handle_, os, index.pq_dim()); - serialize_scalar(handle_, os, index.conservative_memory_allocation()); - - serialize_scalar(handle_, os, index.metric()); - serialize_scalar(handle_, os, index.codebook_kind()); - serialize_scalar(handle_, os, index.n_lists()); - - serialize_mdspan(handle_, os, index.pq_centers()); - serialize_mdspan(handle_, os, index.centers()); - serialize_mdspan(handle_, os, index.centers_rot()); - serialize_mdspan(handle_, os, index.rotation_matrix()); - - auto sizes_host = make_host_mdarray(index.list_sizes().extents()); - copy(sizes_host.data_handle(), - index.list_sizes().data_handle(), - sizes_host.size(), - resource::get_cuda_stream(handle_)); - resource::sync_stream(handle_); - serialize_mdspan(handle_, os, sizes_host.view()); - auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; - for (uint32_t label = 0; label < index.n_lists(); label++) { - ivf::serialize_list(handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); - } -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index IVF-PQ index - * - */ -template -void serialize(raft::resources const& handle_, - const std::string& filename, - const index& index) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - detail::serialize(handle_, of, index); - - of.close(); - if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } - return; -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - */ -template -auto deserialize(raft::resources const& handle_, std::istream& is) -> index -{ - auto ver = deserialize_scalar(handle_, is); - if (ver != kSerializationVersion) { - RAFT_FAIL("serialization version mismatch %d vs. %d", ver, kSerializationVersion); - } - auto n_rows = deserialize_scalar(handle_, is); - auto dim = deserialize_scalar(handle_, is); - auto pq_bits = deserialize_scalar(handle_, is); - auto pq_dim = deserialize_scalar(handle_, is); - auto cma = deserialize_scalar(handle_, is); - - auto metric = deserialize_scalar(handle_, is); - auto codebook_kind = deserialize_scalar(handle_, is); - auto n_lists = deserialize_scalar(handle_, is); - - RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", - static_cast(n_rows), - static_cast(dim), - static_cast(pq_dim), - static_cast(pq_bits), - static_cast(n_lists)); - - auto index = raft::neighbors::ivf_pq::index( - handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); - - deserialize_mdspan(handle_, is, index.pq_centers()); - deserialize_mdspan(handle_, is, index.centers()); - deserialize_mdspan(handle_, is, index.centers_rot()); - deserialize_mdspan(handle_, is, index.rotation_matrix()); - deserialize_mdspan(handle_, is, index.list_sizes()); - auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; - auto list_store_spec = list_spec{pq_bits, pq_dim, true}; - for (auto& list : index.lists()) { - ivf::deserialize_list(handle_, is, list, list_store_spec, list_device_spec); - } - - resource::sync_stream(handle_); - - ivf::detail::recompute_internal_state(handle_, index); - - return index; -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - */ -template -auto deserialize(raft::resources const& handle_, const std::string& filename) -> index -{ - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto index = detail::deserialize(handle_, infile); - - infile.close(); - - return index; -} - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh deleted file mode 100644 index daa2798b00..0000000000 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ /dev/null @@ -1,551 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::detail { -using namespace raft::spatial::knn::detail; -using namespace raft::spatial::knn; - -/** - * Calculates brute force knn, using a fixed memory budget - * by tiling over both the rows and columns of pairwise_distances - */ -template -void tiled_brute_force_knn(const raft::resources& handle, - const ElementType* search, // size (m ,d) - const ElementType* index, // size (n ,d) - size_t m, - size_t n, - size_t d, - size_t k, - ElementType* distances, // size (m, k) - IndexType* indices, // size (m, k) - raft::distance::DistanceType metric, - float metric_arg = 2.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - const ElementType* precomputed_index_norms = nullptr, - const ElementType* precomputed_search_norms = nullptr) -{ - // Figure out the number of rows/cols to tile for - size_t tile_rows = 0; - size_t tile_cols = 0; - auto stream = resource::get_cuda_stream(handle); - auto device_memory = resource::get_workspace_resource(handle); - auto total_mem = rmm::available_device_memory().second; - faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); - - // for unittesting, its convenient to be able to put a max size on the tiles - // so we can test the tiling logic without having to use huge inputs. - if (max_row_tile_size && (tile_rows > max_row_tile_size)) { tile_rows = max_row_tile_size; } - if (max_col_tile_size && (tile_cols > max_col_tile_size)) { tile_cols = max_col_tile_size; } - - // tile_cols must be at least k items - tile_cols = std::max(tile_cols, k); - - // stores pairwise distances for the current tile - rmm::device_uvector temp_distances(tile_rows * tile_cols, stream); - - // calculate norms for L2 expanded distances - this lets us avoid calculating - // norms repeatedly per-tile, and just do once for the entire input - auto pairwise_metric = metric; - rmm::device_uvector search_norms(0, stream); - rmm::device_uvector index_norms(0, stream); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded) { - if (!precomputed_search_norms) { search_norms.resize(m, stream); } - if (!precomputed_index_norms) { index_norms.resize(n, stream); } - // cosine needs the l2norm, where as l2 distances needs the squared norm - if (metric == raft::distance::DistanceType::CosineExpanded) { - if (!precomputed_search_norms) { - raft::linalg::rowNorm(search_norms.data(), - search, - d, - m, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); - } - if (!precomputed_index_norms) { - raft::linalg::rowNorm(index_norms.data(), - index, - d, - n, - raft::linalg::NormType::L2Norm, - true, - stream, - raft::sqrt_op{}); - } - } else { - if (!precomputed_search_norms) { - raft::linalg::rowNorm( - search_norms.data(), search, d, m, raft::linalg::NormType::L2Norm, true, stream); - } - if (!precomputed_index_norms) { - raft::linalg::rowNorm( - index_norms.data(), index, d, n, raft::linalg::NormType::L2Norm, true, stream); - } - } - pairwise_metric = raft::distance::DistanceType::InnerProduct; - } - - // if we're tiling over columns, we need additional buffers for temporary output - // distances/indices - size_t num_col_tiles = raft::ceildiv(n, tile_cols); - size_t temp_out_cols = k * num_col_tiles; - - // the final column tile could have less than 'k' items in it - // in which case the number of columns here is too high in the temp output. - // adjust if necessary - auto last_col_tile_size = n % tile_cols; - if (last_col_tile_size && (last_col_tile_size < k)) { temp_out_cols -= k - last_col_tile_size; } - - // if we have less than k items in the index, we should fill out the result - // to indicate that we are missing items (and match behaviour in faiss) - if (n < k) { - raft::matrix::fill(handle, - raft::make_device_matrix_view(distances, m, k), - std::numeric_limits::lowest()); - - if constexpr (std::is_signed_v) { - raft::matrix::fill(handle, raft::make_device_matrix_view(indices, m, k), IndexType{-1}); - } - } - - rmm::device_uvector temp_out_distances(tile_rows * temp_out_cols, stream); - rmm::device_uvector temp_out_indices(tile_rows * temp_out_cols, stream); - - bool select_min = raft::distance::is_min_close(metric); - - for (size_t i = 0; i < m; i += tile_rows) { - size_t current_query_size = std::min(tile_rows, m - i); - - for (size_t j = 0; j < n; j += tile_cols) { - size_t current_centroid_size = std::min(tile_cols, n - j); - size_t current_k = std::min(current_centroid_size, k); - - // calculate the top-k elements for the current tile, by calculating the - // full pairwise distance for the tile - and then selecting the top-k from that - // note: we're using a int32 IndexType here on purpose in order to - // use the pairwise_distance instantiations. Since the tile size will ensure - // that the total memory is < 1GB per tile, this will not cause any issues - distance::pairwise_distance(handle, - search + i * d, - index + j * d, - temp_distances.data(), - current_query_size, - current_centroid_size, - d, - pairwise_metric, - true, - metric_arg); - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); - auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); - auto dist = temp_distances.data(); - bool sqrt = metric == raft::distance::DistanceType::L2SqrtExpanded; - - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - - raft::distance::detail::ops::l2_exp_cutlass_op l2_op(sqrt); - auto val = l2_op(row_norms[row], col_norms[col], dist[idx]); - return distance_epilogue(val, row, col); - }); - } else if (metric == raft::distance::DistanceType::CosineExpanded) { - auto row_norms = precomputed_search_norms ? precomputed_search_norms : search_norms.data(); - auto col_norms = precomputed_index_norms ? precomputed_index_norms : index_norms.data(); - auto dist = temp_distances.data(); - - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - auto val = 1.0 - dist[idx] / (row_norms[row] * col_norms[col]); - val = distance_epilogue(val, row, col); - return val; - }); - } else { - // if we're not l2 distance, and we have a distance epilogue - run it now - if constexpr (!std::is_same_v) { - auto distances_ptr = temp_distances.data(); - raft::linalg::map_offset( - handle, - raft::make_device_vector_view(temp_distances.data(), - current_query_size * current_centroid_size), - [=] __device__(size_t idx) { - IndexType row = i + (idx / current_centroid_size); - IndexType col = j + (idx % current_centroid_size); - return distance_epilogue(distances_ptr[idx], row, col); - }); - } - } - - matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_distances.data(), current_query_size, current_centroid_size), - std::nullopt, - raft::make_device_matrix_view( - distances + i * k, current_query_size, current_k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, current_k), - select_min, - true); - - // if we're tiling over columns, we need to do a couple things to fix up - // the output of select_k - // 1. The column id's in the output are relative to the tile, so we need - // to adjust the column ids by adding the column the tile starts at (j) - // 2. select_k writes out output in a row-major format, which means we - // can't just concat the output of all the tiles and do a select_k on the - // concatenation. - // Fix both of these problems in a single pass here - if (tile_cols != n) { - const ElementType* in_distances = distances + i * k; - const IndexType* in_indices = indices + i * k; - ElementType* out_distances = temp_out_distances.data(); - IndexType* out_indices = temp_out_indices.data(); - - auto count = thrust::make_counting_iterator(0); - thrust::for_each(resource::get_thrust_policy(handle), - count, - count + current_query_size * current_k, - [=] __device__(IndexType i) { - IndexType row = i / current_k, col = i % current_k; - IndexType out_index = row * temp_out_cols + j * k / tile_cols + col; - - out_distances[out_index] = in_distances[i]; - out_indices[out_index] = in_indices[i] + j; - }); - } - } - - if (tile_cols != n) { - // select the actual top-k items here from the temporary output - matrix::select_k( - handle, - raft::make_device_matrix_view( - temp_out_distances.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - temp_out_indices.data(), current_query_size, temp_out_cols), - raft::make_device_matrix_view( - distances + i * k, current_query_size, k), - raft::make_device_matrix_view( - indices + i * k, current_query_size, k), - select_min, - true); - } - } -} - -/** - * Search the kNN for the k-nearest neighbors of a set of query vectors - * @param[in] input vector of device device memory array pointers to search - * @param[in] sizes vector of memory sizes for each device array pointer in input - * @param[in] D number of cols in input and search_items - * @param[in] search_items set of vectors to query for neighbors - * @param[in] n number of items in search_items - * @param[out] res_I pointer to device memory for returning k nearest indices - * @param[out] res_D pointer to device memory for returning k nearest distances - * @param[in] k number of neighbors to query - * @param[in] userStream the main cuda stream to use - * @param[in] internalStreams optional when n_params > 0, the index partitions can be - * queried in parallel using these streams. Note that n_int_streams also - * has to be > 0 for these to be used and their cardinality does not need - * to correspond to n_parts. - * @param[in] n_int_streams size of internalStreams. When this is <= 0, only the - * user stream will be used. - * @param[in] rowMajorIndex are the index arrays in row-major layout? - * @param[in] rowMajorQuery are the query array in row-major layout? - * @param[in] translations translation ids for indices when index rows represent - * non-contiguous partitions - * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) - * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm - */ -template -void brute_force_knn_impl( - raft::resources const& handle, - std::vector& input, - std::vector& sizes, - IntType D, - value_t* search_items, - IntType n, - IdxType* res_I, - value_t* res_D, - IntType k, - bool rowMajorIndex = true, - bool rowMajorQuery = true, - std::vector* translations = nullptr, - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0, - DistanceEpilogue distance_epilogue = raft::identity_op(), - std::vector* input_norms = nullptr, - const value_t* search_norms = nullptr) -{ - auto userStream = resource::get_cuda_stream(handle); - - ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); - - std::vector* id_ranges; - if (translations == nullptr) { - // If we don't have explicit translations - // for offsets of the indices, build them - // from the local partitions - id_ranges = new std::vector(); - IdxType total_n = 0; - for (size_t i = 0; i < input.size(); i++) { - id_ranges->push_back(total_n); - total_n += sizes[i]; - } - } else { - // otherwise, use the given translations - id_ranges = translations; - } - - int device; - RAFT_CUDA_TRY(cudaGetDevice(&device)); - - rmm::device_uvector trans(id_ranges->size(), userStream); - raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); - - rmm::device_uvector all_D(0, userStream); - rmm::device_uvector all_I(0, userStream); - - value_t* out_D = res_D; - IdxType* out_I = res_I; - - if (input.size() > 1) { - all_D.resize(input.size() * k * n, userStream); - all_I.resize(input.size() * k * n, userStream); - - out_D = all_D.data(); - out_I = all_I.data(); - } - - // currently we don't support col_major inside tiled_brute_force_knn, because - // of limitations of the pairwise_distance API: - // 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have - // multiple options here (like rowMajorQuery/rowMajorIndex) - // 2) because of tiling, we need to be able to set a custom stride in the PW - // api, which isn't supported - // Instead, transpose the input matrices if they are passed as col-major. - auto search = search_items; - rmm::device_uvector search_row_major(0, userStream); - if (!rowMajorQuery) { - search_row_major.resize(n * D, userStream); - raft::linalg::transpose(handle, search, search_row_major.data(), n, D, userStream); - search = search_row_major.data(); - } - - // transpose into a temporary buffer if necessary - rmm::device_uvector index_row_major(0, userStream); - if (!rowMajorIndex) { - size_t total_size = 0; - for (auto size : sizes) { - total_size += size; - } - index_row_major.resize(total_size * D, userStream); - } - - // Make other streams from pool wait on main stream - resource::wait_stream_pool_on_stream(handle); - - size_t total_rows_processed = 0; - for (size_t i = 0; i < input.size(); i++) { - value_t* out_d_ptr = out_D + (i * k * n); - IdxType* out_i_ptr = out_I + (i * k * n); - - auto stream = resource::get_next_usable_stream(handle, i); - - if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && - std::is_same_v && - (metric == raft::distance::DistanceType::L2Unexpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded)) { - fusedL2Knn(D, - out_i_ptr, - out_d_ptr, - input[i], - search_items, - sizes[i], - n, - k, - rowMajorIndex, - rowMajorQuery, - stream, - metric, - input_norms ? (*input_norms)[i] : nullptr, - search_norms); - - // Perform necessary post-processing - if (metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::L2SqrtUnexpanded || - metric == raft::distance::DistanceType::LpUnexpanded) { - value_t p = 0.5; // standard l2 - if (metric == raft::distance::DistanceType::LpUnexpanded) p = 1.0 / metricArg; - raft::linalg::unaryOp( - res_D, - res_D, - n * k, - [p] __device__(value_t input) { return powf(fabsf(input), p); }, - stream); - } - } else { - switch (metric) { - case raft::distance::DistanceType::Haversine: - ASSERT(D == 2, - "Haversine distance requires 2 dimensions " - "(latitude / longitude)."); - - haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k, stream); - break; - default: - // Create a new handle with the current stream from the stream pool - raft::resources stream_pool_handle(handle); - raft::resource::set_cuda_stream(stream_pool_handle, stream); - - auto index = input[i]; - if (!rowMajorIndex) { - index = index_row_major.data() + total_rows_processed * D; - total_rows_processed += sizes[i]; - raft::linalg::transpose(handle, input[i], index, sizes[i], D, stream); - } - - tiled_brute_force_knn(stream_pool_handle, - search, - index, - n, - sizes[i], - D, - k, - out_d_ptr, - out_i_ptr, - metric, - metricArg, - 0, - 0, - distance_epilogue, - input_norms ? (*input_norms)[i] : nullptr, - search_norms); - break; - } - } - - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - // Sync internal streams if used. We don't need to - // sync the user stream because we'll already have - // fully serial execution. - resource::sync_stream_pool(handle); - - if (input.size() > 1 || translations != nullptr) { - // This is necessary for proper index translations. If there are - // no translations or partitions to combine, it can be skipped. - knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, trans.data()); - } - - if (translations == nullptr) delete id_ranges; -}; - -template -void brute_force_search( - raft::resources const& res, - const raft::neighbors::brute_force::index& idx, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - std::optional> query_norms = std::nullopt) -{ - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs"); - RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1), - "Number of columns in queries must match brute force index"); - - auto k = neighbors.extent(1); - auto d = idx.dataset().extent(1); - - std::vector dataset = {const_cast(idx.dataset().data_handle())}; - std::vector sizes = {idx.dataset().extent(0)}; - std::vector norms; - if (idx.has_norms()) { norms.push_back(const_cast(idx.norms().data_handle())); } - - brute_force_knn_impl(res, - dataset, - sizes, - d, - const_cast(queries.data_handle()), - queries.extent(0), - neighbors.data_handle(), - distances.data_handle(), - k, - true, - true, - nullptr, - idx.metric(), - idx.metric_arg(), - raft::identity_op(), - norms.size() ? &norms : nullptr, - query_norms ? query_norms->data_handle() : nullptr); -} -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh deleted file mode 100644 index 384eacae79..0000000000 --- a/cpp/include/raft/neighbors/detail/knn_brute_force_batch_k_query.cuh +++ /dev/null @@ -1,98 +0,0 @@ -/* - * Copyright (c) 2020-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include -#include -#include -#include - -namespace raft::neighbors::brute_force::detail { -template -class gpu_batch_k_query : public batch_k_query { - public: - gpu_batch_k_query(const raft::resources& res, - const raft::neighbors::brute_force::index& index, - raft::device_matrix_view query, - int64_t batch_size) - : batch_k_query(res, index.size(), query.extent(0), batch_size), - index(index), - query(query) - { - auto metric = index.metric(); - - // precompute query norms, and re-use across batches - if (metric == raft::distance::DistanceType::L2Expanded || - metric == raft::distance::DistanceType::L2SqrtExpanded || - metric == raft::distance::DistanceType::CosineExpanded) { - query_norms = make_device_vector(res, query.extent(0)); - - if (metric == raft::distance::DistanceType::CosineExpanded) { - raft::linalg::norm(res, - query, - query_norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS, - raft::sqrt_op{}); - } else { - raft::linalg::norm(res, - query, - query_norms->view(), - raft::linalg::NormType::L2Norm, - raft::linalg::Apply::ALONG_ROWS); - } - } - } - - protected: - void load_batch(int64_t offset, int64_t next_batch_size, batch* output) const override - { - if (offset >= index.size()) { return; } - - // we're aiming to load multiple batches here - since we don't know the max iteration - // grow the size we're loading exponentially - int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), this->index_size); - output->resize(this->res, this->query_size, batch_size); - - std::optional> query_norms_view; - if (query_norms) { query_norms_view = query_norms->view(); } - - raft::neighbors::detail::brute_force_search( - this->res, index, query, output->indices(), output->distances(), query_norms_view); - }; - - void slice_batch(const batch& input, - int64_t offset, - int64_t batch_size, - batch* output) const override - { - auto num_queries = input.indices().extent(0); - batch_size = std::min(batch_size, index.size() - offset); - - output->resize(this->res, num_queries, batch_size); - - if (!num_queries || !batch_size) { return; } - - matrix::slice_coordinates coords{0, offset, num_queries, offset + batch_size}; - matrix::slice(this->res, input.indices(), output->indices(), coords); - matrix::slice(this->res, input.distances(), output->distances(), coords); - } - - const raft::neighbors::brute_force::index& index; - raft::device_matrix_view query; - std::optional> query_norms; -}; -} // namespace raft::neighbors::brute_force::detail diff --git a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh b/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh deleted file mode 100644 index 33324714fd..0000000000 --- a/cpp/include/raft/neighbors/detail/knn_merge_parts.cuh +++ /dev/null @@ -1,175 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::detail { - -template -RAFT_KERNEL knn_merge_parts_kernel(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - value_t initK, - value_idx initV, - int k, - value_idx* translations) -{ - constexpr int kNumWarps = tpb / WarpSize; - - __shared__ value_t smemK[kNumWarps * warp_q]; - __shared__ value_idx smemV[kNumWarps * warp_q]; - - /** - * Uses shared memory - */ - faiss_select:: - BlockSelect, warp_q, thread_q, tpb> - heap(initK, initV, smemK, smemV, k); - - // Grid is exactly sized to rows available - int row = blockIdx.x; - int total_k = k * n_parts; - - int i = threadIdx.x; - - // Get starting pointers for cols in current thread - int part = i / k; - size_t row_idx = (row * k) + (part * n_samples * k); - - int col = i % k; - - const value_t* inKStart = inK + (row_idx + col); - const value_idx* inVStart = inV + (row_idx + col); - - int limit = Pow2::roundDown(total_k); - value_idx translation = 0; - - for (; i < limit; i += tpb) { - translation = translations[part]; - heap.add(*inKStart, (*inVStart) + translation); - - part = (i + tpb) / k; - row_idx = (row * k) + (part * n_samples * k); - - col = (i + tpb) % k; - - inKStart = inK + (row_idx + col); - inVStart = inV + (row_idx + col); - } - - // Handle last remainder fraction of a warp of elements - if (i < total_k) { - translation = translations[part]; - heap.addThreadQ(*inKStart, (*inVStart) + translation); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - outK[row * k + i] = smemK[i]; - outV[row * k + i] = smemV[i]; - } -} - -template -inline void knn_merge_parts_impl(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - auto grid = dim3(n_samples); - - constexpr int n_threads = (warp_q < 1024) ? 128 : 64; - auto block = dim3(n_threads); - - auto kInit = std::numeric_limits::max(); - auto vInit = -1; - knn_merge_parts_kernel - <<>>( - inK, inV, outK, outV, n_samples, n_parts, kInit, vInit, k, translations); - RAFT_CUDA_TRY(cudaPeekAtLastError()); -} - -/** - * @brief Merge knn distances and index matrix, which have been partitioned - * by row, into a single matrix with only the k-nearest neighbors. - * - * @param inK partitioned knn distance matrix - * @param inV partitioned knn index matrix - * @param outK merged knn distance matrix - * @param outV merged knn index matrix - * @param n_samples number of samples per partition - * @param n_parts number of partitions - * @param k number of neighbors per partition (also number of merged neighbors) - * @param stream CUDA stream to use - * @param translations mapping of index offsets for each partition - */ -template -inline void knn_merge_parts(const value_t* inK, - const value_idx* inV, - value_t* outK, - value_idx* outV, - size_t n_samples, - int n_parts, - int k, - cudaStream_t stream, - value_idx* translations) -{ - if (k == 1) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 32) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 64) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 128) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 256) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 512) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else if (k <= 1024) - knn_merge_parts_impl( - inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); - else - THROW("Unimplemented for k=%d, knn_merge_parts works for k<=1024", k); -} -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/nn_descent.cuh b/cpp/include/raft/neighbors/detail/nn_descent.cuh deleted file mode 100644 index 02610f9afb..0000000000 --- a/cpp/include/raft/neighbors/detail/nn_descent.cuh +++ /dev/null @@ -1,1534 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "../nn_descent_types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include // raft::util::arch::SM_* -#include -#include -#include -#include - -#include - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include -#include -#include -#include - -namespace raft::neighbors::experimental::nn_descent::detail { - -using pinned_memory_resource = thrust::universal_host_pinned_memory_resource; -template -using pinned_memory_allocator = thrust::mr::stateless_resource_allocator; - -using DistData_t = float; -constexpr int DEGREE_ON_DEVICE{32}; -constexpr int SEGMENT_SIZE{32}; -constexpr int counter_interval{100}; -template -struct InternalID_t; - -// InternalID_t uses 1 bit for marking (new or old). -template <> -class InternalID_t { - private: - using Index_t = int; - Index_t id_{std::numeric_limits::max()}; - - public: - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const - { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ void mark_old() - { - if (id_ >= 0) id_ = -id_ - 1; - } - __host__ __device__ bool operator==(const InternalID_t& other) const - { - return id() == other.id(); - } -}; - -template -struct ResultItem; - -template <> -class ResultItem { - private: - using Index_t = int; - Index_t id_; - DistData_t dist_; - - public: - __host__ __device__ ResultItem() - : id_(std::numeric_limits::max()), dist_(std::numeric_limits::max()){}; - __host__ __device__ ResultItem(const Index_t id_with_flag, const DistData_t dist) - : id_(id_with_flag), dist_(dist){}; - __host__ __device__ bool is_new() const { return id_ >= 0; } - __host__ __device__ Index_t& id_with_flag() { return id_; } - __host__ __device__ Index_t id() const - { - if (is_new()) return id_; - return -id_ - 1; - } - __host__ __device__ DistData_t& dist() { return dist_; } - - __host__ __device__ void mark_old() - { - if (id_ >= 0) id_ = -id_ - 1; - } - - __host__ __device__ bool operator<(const ResultItem& other) const - { - if (dist_ == other.dist_) return id() < other.id(); - return dist_ < other.dist_; - } - __host__ __device__ bool operator==(const ResultItem& other) const - { - return id() == other.id(); - } - __host__ __device__ bool operator>=(const ResultItem& other) const - { - return !(*this < other); - } - __host__ __device__ bool operator<=(const ResultItem& other) const - { - return (*this == other) || (*this < other); - } - __host__ __device__ bool operator>(const ResultItem& other) const - { - return !(*this <= other); - } - __host__ __device__ bool operator!=(const ResultItem& other) const - { - return !(*this == other); - } -}; - -using align32 = raft::Pow2<32>; - -template -int get_batch_size(const int it_now, const T nrow, const int batch_size) -{ - int it_total = ceildiv(nrow, batch_size); - return (it_now == it_total - 1) ? nrow - it_now * batch_size : batch_size; -} - -// for avoiding bank conflict -template -constexpr __host__ __device__ __forceinline__ int skew_dim(int ndim) -{ - // all "4"s are for alignment - if constexpr (std::is_same::value) { - ndim = ceildiv(ndim, 4) * 4; - return ndim + (ndim % 32 == 0) * 4; - } -} - -template -__device__ __forceinline__ ResultItem xor_swap(ResultItem x, int mask, int dir) -{ - ResultItem y; - y.dist() = __shfl_xor_sync(raft::warp_full_mask(), x.dist(), mask, raft::warp_size()); - y.id_with_flag() = - __shfl_xor_sync(raft::warp_full_mask(), x.id_with_flag(), mask, raft::warp_size()); - return x < y == dir ? y : x; -} - -__device__ __forceinline__ int xor_swap(int x, int mask, int dir) -{ - int y = __shfl_xor_sync(raft::warp_full_mask(), x, mask, raft::warp_size()); - return x < y == dir ? y : x; -} - -// TODO: Move to RAFT utils https://github.com/rapidsai/raft/issues/1827 -__device__ __forceinline__ uint bfe(uint lane_id, uint pos) -{ - uint res; - asm("bfe.u32 %0,%1,%2,%3;" : "=r"(res) : "r"(lane_id), "r"(pos), "r"(1)); - return res; -} - -template -__device__ __forceinline__ void warp_bitonic_sort(T* element_ptr, const int lane_id) -{ - static_assert(raft::warp_size() == 32); - auto& element = *element_ptr; - element = xor_swap(element, 0x01, bfe(lane_id, 1) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x02, bfe(lane_id, 2) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 2) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x04, bfe(lane_id, 3) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 3) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 3) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x08, bfe(lane_id, 4) ^ bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 4) ^ bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 4) ^ bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 4) ^ bfe(lane_id, 0)); - element = xor_swap(element, 0x10, bfe(lane_id, 4)); - element = xor_swap(element, 0x08, bfe(lane_id, 3)); - element = xor_swap(element, 0x04, bfe(lane_id, 2)); - element = xor_swap(element, 0x02, bfe(lane_id, 1)); - element = xor_swap(element, 0x01, bfe(lane_id, 0)); - return; -} - -struct BuildConfig { - size_t max_dataset_size; - size_t dataset_dim; - size_t node_degree{64}; - size_t internal_node_degree{0}; - // If internal_node_degree == 0, the value of node_degree will be assigned to it - size_t max_iterations{50}; - float termination_threshold{0.0001}; - size_t output_graph_degree{32}; -}; - -template -class BloomFilter { - public: - BloomFilter(size_t nrow, size_t num_sets_per_list, size_t num_hashs) - : nrow_(nrow), - num_sets_per_list_(num_sets_per_list), - num_hashs_(num_hashs), - bitsets_(nrow * num_bits_per_set_ * num_sets_per_list) - { - } - - void add(size_t list_id, Index_t key) - { - if (is_cleared) { is_cleared = false; } - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - bitsets_[global_set_idx + hash % num_bits_per_set_] = 1; - } - } - - bool check(size_t list_id, Index_t key) - { - bool is_present = true; - uint32_t hash = hash_0(key); - size_t global_set_idx = list_id * num_bits_per_set_ * num_sets_per_list_ + - key % num_sets_per_list_ * num_bits_per_set_; - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - - if (!is_present) return false; - for (size_t i = 1; i < num_hashs_; i++) { - hash = hash + hash_1(key); - is_present &= bitsets_[global_set_idx + hash % num_bits_per_set_]; - if (!is_present) return false; - } - return true; - } - - void clear() - { - if (is_cleared) return; -#pragma omp parallel for - for (size_t i = 0; i < nrow_ * num_bits_per_set_ * num_sets_per_list_; i++) { - bitsets_[i] = 0; - } - is_cleared = true; - } - - private: - uint32_t hash_0(uint32_t value) - { - value *= 1103515245; - value += 12345; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } - - uint32_t hash_1(uint32_t value) - { - value *= 1664525; - value += 1013904223; - value ^= value << 13; - value ^= value >> 17; - value ^= value << 5; - return value; - } - - static constexpr int num_bits_per_set_ = 512; - bool is_cleared{true}; - std::vector bitsets_; - size_t nrow_; - size_t num_sets_per_list_; - size_t num_hashs_; -}; - -template -struct GnndGraph { - static constexpr int segment_size = 32; - InternalID_t* h_graph; - - size_t nrow; - size_t node_degree; - int num_samples; - int num_segments; - - raft::host_matrix h_dists; - - thrust::host_vector> h_graph_new; - thrust::host_vector> h_list_sizes_new; - - thrust::host_vector> h_graph_old; - thrust::host_vector> h_list_sizes_old; - BloomFilter bloom_filter; - - GnndGraph(const GnndGraph&) = delete; - GnndGraph& operator=(const GnndGraph&) = delete; - GnndGraph(const size_t nrow, - const size_t node_degree, - const size_t internal_node_degree, - const size_t num_samples); - void init_random_graph(); - // TODO: Create a generic bloom filter utility https://github.com/rapidsai/raft/issues/1827 - // Use Bloom filter to sample "new" neighbors for local joining - void sample_graph_new(InternalID_t* new_neighbors, const size_t width); - void sample_graph(bool sample_new); - void update_graph(const InternalID_t* new_neighbors, - const DistData_t* new_dists, - const size_t width, - std::atomic& update_counter); - void sort_lists(); - void clear(); - ~GnndGraph(); -}; - -template > -class GNND { - public: - GNND(raft::resources const& res, const BuildConfig& build_config); - GNND(const GNND&) = delete; - GNND& operator=(const GNND&) = delete; - - void build(Data_t* data, - const Index_t nrow, - Index_t* output_graph, - bool return_distances, - DistData_t* output_distances, - epilogue_op distance_epilogue = DistEpilogue()); - ~GNND() = default; - using ID_t = InternalID_t; - void reset(raft::resources const& res); - - private: - void add_reverse_edges(Index_t* graph_ptr, - Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, - int2* list_sizes, - cudaStream_t stream = 0); - void local_join(cudaStream_t stream = 0, - epilogue_op distance_epilogue = DistEpilogue()); - - raft::resources const& res; - - BuildConfig build_config_; - GnndGraph graph_; - std::atomic update_counter_; - - size_t nrow_; - size_t ndim_; - - raft::device_matrix<__half, size_t, raft::row_major> d_data_; - raft::device_vector l2_norms_; - - raft::device_matrix graph_buffer_; - raft::device_matrix dists_buffer_; - - // TODO: Investigate using RMM/RAFT types https://github.com/rapidsai/raft/issues/1827 - thrust::host_vector> graph_host_buffer_; - thrust::host_vector> dists_host_buffer_; - - raft::device_vector d_locks_; - - thrust::host_vector> h_rev_graph_new_; - thrust::host_vector> h_graph_old_; - thrust::host_vector> h_rev_graph_old_; - // int2.x is the number of forward edges, int2.y is the number of reverse edges - - raft::device_vector d_list_sizes_new_; - raft::device_vector d_list_sizes_old_; -}; - -constexpr int TILE_ROW_WIDTH = 64; -constexpr int TILE_COL_WIDTH = 128; - -constexpr int NUM_SAMPLES = 32; -// For now, the max. number of samples is 32, so the sample cache size is fixed -// to 64 (32 * 2). -constexpr int MAX_NUM_BI_SAMPLES = 64; -constexpr int SKEWED_MAX_NUM_BI_SAMPLES = skew_dim(MAX_NUM_BI_SAMPLES); -constexpr int BLOCK_SIZE = 512; -constexpr int WMMA_M = 16; -constexpr int WMMA_N = 16; -constexpr int WMMA_K = 16; - -template -__device__ __forceinline__ void load_vec(Data_t* vec_buffer, - const Data_t* d_vec, - const int load_dims, - const int padding_dims, - const int lane_id) -{ - if constexpr (std::is_same_v or std::is_same_v or - std::is_same_v) { - constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } - } - if constexpr (std::is_same_v) { - if ((size_t)d_vec % sizeof(float2) == 0 && (size_t)vec_buffer % sizeof(float2) == 0 && - load_dims % 4 == 0 && padding_dims % 4 == 0) { - constexpr int num_load_elems_per_warp = raft::warp_size() * 4; -#pragma unroll - for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx_in_vec = step * num_load_elems_per_warp + lane_id * 4; - if (idx_in_vec + 4 <= load_dims) { - *(float2*)(vec_buffer + idx_in_vec) = *(float2*)(d_vec + idx_in_vec); - } else if (idx_in_vec + 4 <= padding_dims) { - *(float2*)(vec_buffer + idx_in_vec) = float2({0.0f, 0.0f}); - } - } - } else { - constexpr int num_load_elems_per_warp = raft::warp_size(); - for (int step = 0; step < ceildiv(padding_dims, num_load_elems_per_warp); step++) { - int idx = step * num_load_elems_per_warp + lane_id; - if (idx < load_dims) { - vec_buffer[idx] = d_vec[idx]; - } else if (idx < padding_dims) { - vec_buffer[idx] = 0.0f; - } - } - } - } -} - -// TODO: Replace with RAFT utilities https://github.com/rapidsai/raft/issues/1827 -/** Calculate L2 norm, and cast data to __half */ -template -RAFT_KERNEL preprocess_data_kernel(const Data_t* input_data, - __half* output_data, - int dim, - DistData_t* l2_norms, - size_t list_offset = 0) -{ - extern __shared__ char buffer[]; - __shared__ float l2_norm; - Data_t* s_vec = (Data_t*)buffer; - size_t list_id = list_offset + blockIdx.x; - - load_vec(s_vec, input_data + blockIdx.x * dim, dim, dim, threadIdx.x % raft::warp_size()); - if (threadIdx.x == 0) { l2_norm = 0; } - __syncthreads(); - int lane_id = threadIdx.x % raft::warp_size(); - for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + lane_id; - float part_dist = 0; - if (idx < dim) { - part_dist = s_vec[idx]; - part_dist = part_dist * part_dist; - } - __syncwarp(); - for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { - part_dist += __shfl_down_sync(raft::warp_full_mask(), part_dist, offset); - } - if (lane_id == 0) { l2_norm += part_dist; } - __syncwarp(); - } - - for (int step = 0; step < ceildiv(dim, raft::warp_size()); step++) { - int idx = step * raft::warp_size() + threadIdx.x; - if (idx < dim) { - if (l2_norms == nullptr) { - output_data[list_id * dim + idx] = - (float)input_data[(size_t)blockIdx.x * dim + idx] / sqrt(l2_norm); - } else { - output_data[list_id * dim + idx] = input_data[(size_t)blockIdx.x * dim + idx]; - if (idx == 0) { l2_norms[list_id] = l2_norm; } - } - } - } -} - -template -RAFT_KERNEL add_rev_edges_kernel(const Index_t* graph, - Index_t* rev_graph, - int num_samples, - int2* list_sizes) -{ - size_t list_id = blockIdx.x; - int2 list_size = list_sizes[list_id]; - - for (int idx = threadIdx.x; idx < list_size.x; idx += blockDim.x) { - // each node has same number (num_samples) of forward and reverse edges - size_t rev_list_id = graph[list_id * num_samples + idx]; - // there are already num_samples forward edges - int idx_in_rev_list = atomicAdd(&list_sizes[rev_list_id].y, 1); - if (idx_in_rev_list >= num_samples) { - atomicExch(&list_sizes[rev_list_id].y, num_samples); - } else { - rev_graph[rev_list_id * num_samples + idx_in_rev_list] = list_id; - } - } -} - -template > -__device__ void insert_to_global_graph(ResultItem elem, - size_t list_id, - ID_t* graph, - DistData_t* dists, - int node_degree, - int* locks) -{ - int tx = threadIdx.x; - int lane_id = tx % raft::warp_size(); - size_t global_idx_base = list_id * node_degree; - if (elem.id() == list_id) return; - - const int num_segments = ceildiv(node_degree, raft::warp_size()); - - int loop_flag = 0; - do { - int segment_id = elem.id() % num_segments; - if (lane_id == 0) { - loop_flag = atomicCAS(&locks[list_id * num_segments + segment_id], 0, 1) == 0; - } - - loop_flag = __shfl_sync(raft::warp_full_mask(), loop_flag, 0); - - if (loop_flag == 1) { - ResultItem knn_list_frag; - int local_idx = segment_id * raft::warp_size() + lane_id; - size_t global_idx = global_idx_base + local_idx; - if (local_idx < node_degree) { - knn_list_frag.id_with_flag() = graph[global_idx].id_with_flag(); - knn_list_frag.dist() = dists[global_idx]; - } - - int pos_to_insert = -1; - ResultItem prev_elem; - - prev_elem.id_with_flag() = - __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.id_with_flag(), 1); - prev_elem.dist() = __shfl_up_sync(raft::warp_full_mask(), knn_list_frag.dist(), 1); - - if (lane_id == 0) { - prev_elem = ResultItem{std::numeric_limits::min(), - std::numeric_limits::lowest()}; - } - if (elem > prev_elem && elem < knn_list_frag) { - pos_to_insert = segment_id * raft::warp_size() + lane_id; - } else if (elem == prev_elem || elem == knn_list_frag) { - pos_to_insert = -2; - } - uint mask = __ballot_sync(raft::warp_full_mask(), pos_to_insert >= 0); - if (mask) { - uint set_lane_id = __fns(mask, 0, 1); - pos_to_insert = __shfl_sync(raft::warp_full_mask(), pos_to_insert, set_lane_id); - } - - if (pos_to_insert >= 0) { - int local_idx = segment_id * raft::warp_size() + lane_id; - if (local_idx > pos_to_insert) { - local_idx++; - } else if (local_idx == pos_to_insert) { - graph[global_idx_base + local_idx].id_with_flag() = elem.id_with_flag(); - dists[global_idx_base + local_idx] = elem.dist(); - local_idx++; - } - size_t global_pos = global_idx_base + local_idx; - if (local_idx < (segment_id + 1) * raft::warp_size() && local_idx < node_degree) { - graph[global_pos].id_with_flag() = knn_list_frag.id_with_flag(); - dists[global_pos] = knn_list_frag.dist(); - } - } - __threadfence(); - if (loop_flag && lane_id == 0) { atomicExch(&locks[list_id * num_segments + segment_id], 0); } - } - } while (!loop_flag); -} - -template -__device__ ResultItem get_min_item(const Index_t id, - const int idx_in_list, - const Index_t* neighbs, - const DistData_t* distances, - const bool find_in_row = true) -{ - int lane_id = threadIdx.x % raft::warp_size(); - - static_assert(MAX_NUM_BI_SAMPLES == 64); - int idx[MAX_NUM_BI_SAMPLES / raft::warp_size()]; - float dist[MAX_NUM_BI_SAMPLES / raft::warp_size()] = {std::numeric_limits::max(), - std::numeric_limits::max()}; - idx[0] = lane_id; - idx[1] = raft::warp_size() + lane_id; - - if (neighbs[idx[0]] != id) { - dist[0] = find_in_row ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + lane_id] - : distances[idx_in_list + lane_id * SKEWED_MAX_NUM_BI_SAMPLES]; - } - - if (neighbs[idx[1]] != id) { - dist[1] = - find_in_row - ? distances[idx_in_list * SKEWED_MAX_NUM_BI_SAMPLES + raft::warp_size() + lane_id] - : distances[idx_in_list + (raft::warp_size() + lane_id) * SKEWED_MAX_NUM_BI_SAMPLES]; - } - - if (dist[1] < dist[0]) { - dist[0] = dist[1]; - idx[0] = idx[1]; - } - __syncwarp(); - for (int offset = raft::warp_size() >> 1; offset >= 1; offset >>= 1) { - float other_idx = __shfl_down_sync(raft::warp_full_mask(), idx[0], offset); - float other_dist = __shfl_down_sync(raft::warp_full_mask(), dist[0], offset); - if (other_dist < dist[0]) { - dist[0] = other_dist; - idx[0] = other_idx; - } - } - - ResultItem result; - result.dist() = __shfl_sync(raft::warp_full_mask(), dist[0], 0); - result.id_with_flag() = neighbs[__shfl_sync(raft::warp_full_mask(), idx[0], 0)]; - return result; -} - -template -__device__ __forceinline__ void remove_duplicates( - T* list_a, int list_a_size, T* list_b, int list_b_size, int& unique_counter, int execute_warp_id) -{ - static_assert(raft::warp_size() == 32); - if (!(threadIdx.x >= execute_warp_id * raft::warp_size() && - threadIdx.x < execute_warp_id * raft::warp_size() + raft::warp_size())) { - return; - } - int lane_id = threadIdx.x % raft::warp_size(); - T elem = std::numeric_limits::max(); - if (lane_id < list_a_size) { elem = list_a[lane_id]; } - warp_bitonic_sort(&elem, lane_id); - - if (elem != std::numeric_limits::max()) { list_a[lane_id] = elem; } - - T elem_b = std::numeric_limits::max(); - - if (lane_id < list_b_size) { elem_b = list_b[lane_id]; } - __syncwarp(); - - int idx_l = 0; - int idx_r = list_a_size; - bool existed = false; - while (idx_l < idx_r) { - int idx = (idx_l + idx_r) / 2; - int elem = list_a[idx]; - if (elem == elem_b) { - existed = true; - break; - } - if (elem_b > elem) { - idx_l = idx + 1; - } else { - idx_r = idx; - } - } - if (!existed && elem_b != std::numeric_limits::max()) { - int idx = atomicAdd(&unique_counter, 1); - list_a[list_a_size + idx] = elem_b; - } -} - -// launch_bounds here denote BLOCK_SIZE = 512 and MIN_BLOCKS_PER_SM = 4 -// Per -// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#features-and-technical-specifications, -// MAX_RESIDENT_THREAD_PER_SM = BLOCK_SIZE * BLOCKS_PER_SM = 2048 -// For architectures 750 and 860 (890), the values for MAX_RESIDENT_THREAD_PER_SM -// is 1024 and 1536 respectively, which means the bounds don't work anymore -template , - typename epilogue_op = DistEpilogue> -RAFT_KERNEL -#ifdef __CUDA_ARCH__ -#if (__CUDA_ARCH__) == 750 || ((__CUDA_ARCH__) >= 860 && (__CUDA_ARCH__) <= 890) -__launch_bounds__(BLOCK_SIZE) -#else -__launch_bounds__(BLOCK_SIZE, 4) -#endif -#endif - local_join_kernel(const Index_t* graph_new, - const Index_t* rev_graph_new, - const int2* sizes_new, - const Index_t* graph_old, - const Index_t* rev_graph_old, - const int2* sizes_old, - const int width, - const __half* data, - const int data_dim, - ID_t* graph, - DistData_t* dists, - int graph_width, - int* locks, - DistData_t* l2_norms, - epilogue_op distance_epilogue) -{ -#if (__CUDA_ARCH__ >= 700) - using namespace nvcuda; - __shared__ int s_list[MAX_NUM_BI_SAMPLES * 2]; - - constexpr int APAD = 8; - constexpr int BPAD = 8; - __shared__ __half s_nv[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + APAD]; // New vectors - __shared__ __half s_ov[MAX_NUM_BI_SAMPLES][TILE_COL_WIDTH + BPAD]; // Old vectors - static_assert(sizeof(float) * MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES <= - sizeof(__half) * MAX_NUM_BI_SAMPLES * (TILE_COL_WIDTH + BPAD)); - // s_distances: MAX_NUM_BI_SAMPLES x SKEWED_MAX_NUM_BI_SAMPLES, reuse the space of s_ov - float* s_distances = (float*)&s_ov[0][0]; - int* s_unique_counter = (int*)&s_ov[0][0]; - - if (threadIdx.x == 0) { - s_unique_counter[0] = 0; - s_unique_counter[1] = 0; - } - - Index_t* new_neighbors = s_list; - Index_t* old_neighbors = s_list + MAX_NUM_BI_SAMPLES; - - size_t list_id = blockIdx.x; - int2 list_new_size2 = sizes_new[list_id]; - int list_new_size = list_new_size2.x + list_new_size2.y; - int2 list_old_size2 = sizes_old[list_id]; - int list_old_size = list_old_size2.x + list_old_size2.y; - - if (!list_new_size) return; - int tx = threadIdx.x; - - if (tx < list_new_size2.x) { - new_neighbors[tx] = graph_new[list_id * width + tx]; - } else if (tx >= list_new_size2.x && tx < list_new_size) { - new_neighbors[tx] = rev_graph_new[list_id * width + tx - list_new_size2.x]; - } - - if (tx < list_old_size2.x) { - old_neighbors[tx] = graph_old[list_id * width + tx]; - } else if (tx >= list_old_size2.x && tx < list_old_size) { - old_neighbors[tx] = rev_graph_old[list_id * width + tx - list_old_size2.x]; - } - - __syncthreads(); - - remove_duplicates(new_neighbors, - list_new_size2.x, - new_neighbors + list_new_size2.x, - list_new_size2.y, - s_unique_counter[0], - 0); - - remove_duplicates(old_neighbors, - list_old_size2.x, - old_neighbors + list_old_size2.x, - list_old_size2.y, - s_unique_counter[1], - 1); - __syncthreads(); - list_new_size = list_new_size2.x + s_unique_counter[0]; - list_old_size = list_old_size2.x + s_unique_counter[1]; - - int warp_id = threadIdx.x / raft::warp_size(); - int lane_id = threadIdx.x % raft::warp_size(); - constexpr int num_warps = BLOCK_SIZE / raft::warp_size(); - - int warp_id_y = warp_id / 4; - int warp_id_x = warp_id % 4; - - wmma::fragment a_frag; - wmma::fragment b_frag; - wmma::fragment c_frag; - wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) - ? data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_nv[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } - } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, - c_frag, - SKEWED_MAX_NUM_BI_SAMPLES, - wmma::mem_row_major); - __syncthreads(); - - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES; - auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES; - if (row_idx < list_new_size && col_idx < list_new_size) { - auto r = new_neighbors[row_idx]; - auto c = new_neighbors[col_idx]; - if (l2_norms == nullptr) { - auto dist_val = -s_distances[i]; - s_distances[i] = distance_epilogue(dist_val, r, c); - } else { - auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i]; - s_distances[i] = distance_epilogue(dist_val, r, c); - } - } else { - s_distances[i] = std::numeric_limits::max(); - } - } - __syncthreads(); - - for (int step = 0; step < ceildiv(list_new_size, num_warps); step++) { - int idx_in_list = step * num_warps + tx / raft::warp_size(); - if (idx_in_list >= list_new_size) continue; - auto min_elem = get_min_item(s_list[idx_in_list], idx_in_list, new_neighbors, s_distances); - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } - } - - if (!list_old_size) return; - - __syncthreads(); - - wmma::fill_fragment(c_frag, 0.0); - for (int step = 0; step < ceildiv(data_dim, TILE_COL_WIDTH); step++) { - int num_load_elems = (step == ceildiv(data_dim, TILE_COL_WIDTH) - 1) - ? data_dim - step * TILE_COL_WIDTH - : TILE_COL_WIDTH; - if (TILE_COL_WIDTH < data_dim) { -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_new_size) { - size_t neighbor_id = new_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_nv[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - } -#pragma unroll - for (int i = 0; i < MAX_NUM_BI_SAMPLES / num_warps; i++) { - int idx = i * num_warps + warp_id; - if (idx < list_old_size) { - size_t neighbor_id = old_neighbors[idx]; - size_t idx_in_data = neighbor_id * data_dim; - load_vec(s_ov[idx], - data + idx_in_data + step * TILE_COL_WIDTH, - num_load_elems, - TILE_COL_WIDTH, - lane_id); - } - } - __syncthreads(); - - for (int i = 0; i < TILE_COL_WIDTH / WMMA_K; i++) { - wmma::load_matrix_sync(a_frag, s_nv[warp_id_y * WMMA_M] + i * WMMA_K, TILE_COL_WIDTH + APAD); - wmma::load_matrix_sync(b_frag, s_ov[warp_id_x * WMMA_N] + i * WMMA_K, TILE_COL_WIDTH + BPAD); - wmma::mma_sync(c_frag, a_frag, b_frag, c_frag); - __syncthreads(); - } - } - - wmma::store_matrix_sync( - s_distances + warp_id_y * WMMA_M * SKEWED_MAX_NUM_BI_SAMPLES + warp_id_x * WMMA_N, - c_frag, - SKEWED_MAX_NUM_BI_SAMPLES, - wmma::mem_row_major); - __syncthreads(); - - for (int i = threadIdx.x; i < MAX_NUM_BI_SAMPLES * SKEWED_MAX_NUM_BI_SAMPLES; i += blockDim.x) { - auto row_idx = i % SKEWED_MAX_NUM_BI_SAMPLES; - auto col_idx = i / SKEWED_MAX_NUM_BI_SAMPLES; - if (row_idx < list_old_size && col_idx < list_new_size) { - auto r = old_neighbors[row_idx]; - auto c = new_neighbors[col_idx]; - if (l2_norms == nullptr) { - auto dist_val = -s_distances[i]; - s_distances[i] = distance_epilogue(dist_val, r, c); - } else { - auto dist_val = l2_norms[r] + l2_norms[c] - 2.0 * s_distances[i]; - s_distances[i] = distance_epilogue(dist_val, r, c); - } - } else { - s_distances[i] = std::numeric_limits::max(); - } - } - __syncthreads(); - - for (int step = 0; step < ceildiv(MAX_NUM_BI_SAMPLES * 2, num_warps); step++) { - int idx_in_list = step * num_warps + tx / raft::warp_size(); - if (idx_in_list >= list_new_size && idx_in_list < MAX_NUM_BI_SAMPLES) continue; - if (idx_in_list >= MAX_NUM_BI_SAMPLES + list_old_size && idx_in_list < MAX_NUM_BI_SAMPLES * 2) - continue; - ResultItem min_elem{std::numeric_limits::max(), - std::numeric_limits::max()}; - if (idx_in_list < MAX_NUM_BI_SAMPLES) { - auto temp_min_item = - get_min_item(s_list[idx_in_list], idx_in_list, old_neighbors, s_distances); - if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } - } else { - auto temp_min_item = get_min_item( - s_list[idx_in_list], idx_in_list - MAX_NUM_BI_SAMPLES, new_neighbors, s_distances, false); - if (temp_min_item.dist() < min_elem.dist()) { min_elem = temp_min_item; } - } - - if (min_elem.id() < gridDim.x) { - insert_to_global_graph(min_elem, s_list[idx_in_list], graph, dists, graph_width, locks); - } - } -#endif -} - -namespace { -template -int insert_to_ordered_list(InternalID_t* list, - DistData_t* dist_list, - const int width, - const InternalID_t neighb_id, - const DistData_t dist) -{ - if (dist > dist_list[width - 1]) { return width; } - - int idx_insert = width; - bool position_found = false; - for (int i = 0; i < width; i++) { - if (list[i].id() == neighb_id.id()) { return width; } - if (!position_found && dist_list[i] > dist) { - idx_insert = i; - position_found = true; - } - } - if (idx_insert == width) return idx_insert; - - memmove(list + idx_insert + 1, list + idx_insert, sizeof(*list) * (width - idx_insert - 1)); - memmove(dist_list + idx_insert + 1, - dist_list + idx_insert, - sizeof(*dist_list) * (width - idx_insert - 1)); - - list[idx_insert] = neighb_id; - dist_list[idx_insert] = dist; - return idx_insert; -}; - -} // namespace - -template -GnndGraph::GnndGraph(const size_t nrow, - const size_t node_degree, - const size_t internal_node_degree, - const size_t num_samples) - : nrow(nrow), - node_degree(node_degree), - num_samples(num_samples), - bloom_filter(nrow, internal_node_degree / segment_size, 3), - h_dists{raft::make_host_matrix(nrow, node_degree)}, - h_graph_new(nrow * num_samples), - h_list_sizes_new(nrow), - h_graph_old(nrow * num_samples), - h_list_sizes_old{nrow} -{ - // node_degree must be a multiple of segment_size; - assert(node_degree % segment_size == 0); - assert(internal_node_degree % segment_size == 0); - - num_segments = node_degree / segment_size; - // To save the CPU memory, graph should be allocated by external function - h_graph = nullptr; -} - -// This is the only operation on the CPU that cannot be overlapped. -// So it should be as fast as possible. -template -void GnndGraph::sample_graph_new(InternalID_t* new_neighbors, const size_t width) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - auto list_new = h_graph_new.data() + i * num_samples; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j].id(); - if ((size_t)new_neighb_id >= nrow) break; - if (bloom_filter.check(i, new_neighb_id)) { continue; } - bloom_filter.add(i, new_neighb_id); - new_neighbors[i * width + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = new_neighb_id; - if (h_list_sizes_new[i].x == num_samples) break; - } - } -} - -template -void GnndGraph::init_random_graph() -{ - for (size_t seg_idx = 0; seg_idx < static_cast(num_segments); seg_idx++) { - // random sequence (range: 0~nrow) - // segment_x stores neighbors which id % num_segments == x - std::vector rand_seq(nrow / num_segments); - std::iota(rand_seq.begin(), rand_seq.end(), 0); - auto gen = std::default_random_engine{seg_idx}; - std::shuffle(rand_seq.begin(), rand_seq.end(), gen); - -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - size_t base_idx = i * node_degree + seg_idx * segment_size; - auto h_neighbor_list = h_graph + base_idx; - auto h_dist_list = h_dists.data_handle() + base_idx; - for (size_t j = 0; j < static_cast(segment_size); j++) { - size_t idx = base_idx + j; - Index_t id = rand_seq[idx % rand_seq.size()] * num_segments + seg_idx; - if ((size_t)id == i) { - id = rand_seq[(idx + segment_size) % rand_seq.size()] * num_segments + seg_idx; - } - h_neighbor_list[j].id_with_flag() = id; - h_dist_list[j] = std::numeric_limits::max(); - } - } - } -} - -template -void GnndGraph::sample_graph(bool sample_new) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - h_list_sizes_old[i].x = 0; - h_list_sizes_old[i].y = 0; - h_list_sizes_new[i].x = 0; - h_list_sizes_new[i].y = 0; - - auto list = h_graph + i * node_degree; - auto list_old = h_graph_old.data() + i * num_samples; - auto list_new = h_graph_new.data() + i * num_samples; - for (int j = 0; j < segment_size; j++) { - for (int k = 0; k < num_segments; k++) { - auto neighbor = list[k * segment_size + j]; - if ((size_t)neighbor.id() >= nrow) continue; - if (!neighbor.is_new()) { - if (h_list_sizes_old[i].x < num_samples) { - list_old[h_list_sizes_old[i].x++] = neighbor.id(); - } - } else if (sample_new) { - if (h_list_sizes_new[i].x < num_samples) { - list[k * segment_size + j].mark_old(); - list_new[h_list_sizes_new[i].x++] = neighbor.id(); - } - } - if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } - } - if (h_list_sizes_old[i].x == num_samples && h_list_sizes_new[i].x == num_samples) { break; } - } - } -} - -template -void GnndGraph::update_graph(const InternalID_t* new_neighbors, - const DistData_t* new_dists, - const size_t width, - std::atomic& update_counter) -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - for (size_t j = 0; j < width; j++) { - auto new_neighb_id = new_neighbors[i * width + j]; - auto new_dist = new_dists[i * width + j]; - if (new_dist == std::numeric_limits::max()) break; - if ((size_t)new_neighb_id.id() == i) continue; - int seg_idx = new_neighb_id.id() % num_segments; - auto list = h_graph + i * node_degree + seg_idx * segment_size; - auto dist_list = h_dists.data_handle() + i * node_degree + seg_idx * segment_size; - int insert_pos = - insert_to_ordered_list(list, dist_list, segment_size, new_neighb_id, new_dist); - if (i % counter_interval == 0 && insert_pos != segment_size) { update_counter++; } - } - } -} - -template -void GnndGraph::sort_lists() -{ -#pragma omp parallel for - for (size_t i = 0; i < nrow; i++) { - std::vector> new_list; - for (size_t j = 0; j < node_degree; j++) { - new_list.emplace_back(h_dists.data_handle()[i * node_degree + j], - h_graph[i * node_degree + j].id()); - } - std::sort(new_list.begin(), new_list.end()); - for (size_t j = 0; j < node_degree; j++) { - h_graph[i * node_degree + j].id_with_flag() = new_list[j].second; - h_dists.data_handle()[i * node_degree + j] = new_list[j].first; - } - } -} - -template -void GnndGraph::clear() -{ - bloom_filter.clear(); -} - -template -GnndGraph::~GnndGraph() -{ - assert(h_graph == nullptr); -} - -template -GNND::GNND(raft::resources const& res, - const BuildConfig& build_config) - : res(res), - build_config_(build_config), - graph_(build_config.max_dataset_size, - align32::roundUp(build_config.node_degree), - align32::roundUp(build_config.internal_node_degree ? build_config.internal_node_degree - : build_config.node_degree), - NUM_SAMPLES), - nrow_(build_config.max_dataset_size), - ndim_(build_config.dataset_dim), - d_data_{raft::make_device_matrix<__half, size_t, raft::row_major>( - res, nrow_, build_config.dataset_dim)}, - l2_norms_{raft::make_device_vector(res, nrow_)}, - graph_buffer_{ - raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, - dists_buffer_{ - raft::make_device_matrix(res, nrow_, DEGREE_ON_DEVICE)}, - graph_host_buffer_(nrow_ * DEGREE_ON_DEVICE), - dists_host_buffer_(nrow_ * DEGREE_ON_DEVICE), - d_locks_{raft::make_device_vector(res, nrow_)}, - h_rev_graph_new_(nrow_ * NUM_SAMPLES), - h_graph_old_(nrow_ * NUM_SAMPLES), - h_rev_graph_old_(nrow_ * NUM_SAMPLES), - d_list_sizes_new_{raft::make_device_vector(res, nrow_)}, - d_list_sizes_old_{raft::make_device_vector(res, nrow_)} -{ - static_assert(NUM_SAMPLES <= 32); - raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); - auto graph_buffer_view = raft::make_device_matrix_view( - reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); - raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); - raft::matrix::fill(res, d_locks_.view(), 0); -}; - -template -void GNND::reset(raft::resources const& res) -{ - raft::matrix::fill(res, dists_buffer_.view(), std::numeric_limits::max()); - auto graph_buffer_view = raft::make_device_matrix_view( - reinterpret_cast(graph_buffer_.data_handle()), nrow_, DEGREE_ON_DEVICE); - raft::matrix::fill(res, graph_buffer_view, std::numeric_limits::max()); - raft::matrix::fill(res, d_locks_.view(), 0); -} - -template -void GNND::add_reverse_edges(Index_t* graph_ptr, - Index_t* h_rev_graph_ptr, - Index_t* d_rev_graph_ptr, - int2* list_sizes, - cudaStream_t stream) -{ - add_rev_edges_kernel<<>>( - graph_ptr, d_rev_graph_ptr, NUM_SAMPLES, list_sizes); - raft::copy( - h_rev_graph_ptr, d_rev_graph_ptr, nrow_ * NUM_SAMPLES, raft::resource::get_cuda_stream(res)); -} - -template -void GNND::local_join(cudaStream_t stream, - epilogue_op distance_epilogue) -{ - thrust::fill(thrust::device.on(stream), - dists_buffer_.data_handle(), - dists_buffer_.data_handle() + dists_buffer_.size(), - std::numeric_limits::max()); - local_join_kernel<<>>( - thrust::raw_pointer_cast(graph_.h_graph_new.data()), - thrust::raw_pointer_cast(h_rev_graph_new_.data()), - d_list_sizes_new_.data_handle(), - thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(h_rev_graph_old_.data()), - d_list_sizes_old_.data_handle(), - NUM_SAMPLES, - d_data_.data_handle(), - ndim_, - graph_buffer_.data_handle(), - dists_buffer_.data_handle(), - DEGREE_ON_DEVICE, - d_locks_.data_handle(), - l2_norms_.data_handle(), - distance_epilogue); -} - -template -void GNND::build(Data_t* data, - const Index_t nrow, - Index_t* output_graph, - bool return_distances, - DistData_t* output_distances, - epilogue_op distance_epilogue) -{ - using input_t = typename std::remove_const::type; - - cudaStream_t stream = raft::resource::get_cuda_stream(res); - nrow_ = nrow; - graph_.nrow = nrow; - graph_.h_graph = (InternalID_t*)output_graph; - - cudaPointerAttributes data_ptr_attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&data_ptr_attr, data)); - size_t batch_size = (data_ptr_attr.devicePointer == nullptr) ? 100000 : nrow_; - - raft::spatial::knn::detail::utils::batch_load_iterator vec_batches{ - data, static_cast(nrow_), build_config_.dataset_dim, batch_size, stream}; - for (auto const& batch : vec_batches) { - preprocess_data_kernel<<< - batch.size(), - raft::warp_size(), - sizeof(Data_t) * ceildiv(build_config_.dataset_dim, static_cast(raft::warp_size())) * - raft::warp_size(), - stream>>>(batch.data(), - d_data_.data_handle(), - build_config_.dataset_dim, - l2_norms_.data_handle(), - batch.offset()); - } - - thrust::fill(thrust::device.on(stream), - (Index_t*)graph_buffer_.data_handle(), - (Index_t*)graph_buffer_.data_handle() + graph_buffer_.size(), - std::numeric_limits::max()); - - graph_.clear(); - graph_.init_random_graph(); - graph_.sample_graph(true); - - auto update_and_sample = [&](bool update_graph) { - if (update_graph) { - update_counter_ = 0; - graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), - thrust::raw_pointer_cast(dists_host_buffer_.data()), - DEGREE_ON_DEVICE, - update_counter_); - if (update_counter_ < build_config_.termination_threshold * nrow_ * - build_config_.dataset_dim / counter_interval) { - update_counter_ = -1; - } - } - graph_.sample_graph(false); - }; - - for (size_t it = 0; it < build_config_.max_iterations; it++) { - raft::copy(d_list_sizes_new_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_new.data()), - nrow_, - raft::resource::get_cuda_stream(res)); - raft::copy(thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(graph_.h_graph_old.data()), - nrow_ * NUM_SAMPLES, - raft::resource::get_cuda_stream(res)); - raft::copy(d_list_sizes_old_.data_handle(), - thrust::raw_pointer_cast(graph_.h_list_sizes_old.data()), - nrow_, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - - std::thread update_and_sample_thread(update_and_sample, it); - - RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations); - - // Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it - // contains some information for local_join. - static_assert(DEGREE_ON_DEVICE * sizeof(*(dists_buffer_.data_handle())) >= - NUM_SAMPLES * sizeof(*(graph_buffer_.data_handle()))); - add_reverse_edges(thrust::raw_pointer_cast(graph_.h_graph_new.data()), - thrust::raw_pointer_cast(h_rev_graph_new_.data()), - (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_new_.data_handle(), - stream); - add_reverse_edges(thrust::raw_pointer_cast(h_graph_old_.data()), - thrust::raw_pointer_cast(h_rev_graph_old_.data()), - (Index_t*)dists_buffer_.data_handle(), - d_list_sizes_old_.data_handle(), - stream); - - // Tensor operations from `mma.h` are guarded with archicteture - // __CUDA_ARCH__ >= 700. Since RAFT supports compilation for ARCH 600, - // we need to ensure that `local_join_kernel` (which uses tensor) operations - // is not only not compiled, but also a runtime error is presented to the user - auto kernel = preprocess_data_kernel; - void* kernel_ptr = reinterpret_cast(kernel); - auto runtime_arch = raft::util::arch::kernel_virtual_arch(kernel_ptr); - auto wmma_range = - raft::util::arch::SM_range(raft::util::arch::SM_70(), raft::util::arch::SM_future()); - - if (wmma_range.contains(runtime_arch)) { - local_join(stream, distance_epilogue); - } else { - THROW("NN_DESCENT cannot be run for __CUDA_ARCH__ < 700"); - } - - update_and_sample_thread.join(); - - if (update_counter_ == -1) { break; } - raft::copy(thrust::raw_pointer_cast(graph_host_buffer_.data()), - graph_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); - raft::resource::sync_stream(res); - raft::copy(thrust::raw_pointer_cast(dists_host_buffer_.data()), - dists_buffer_.data_handle(), - nrow_ * DEGREE_ON_DEVICE, - raft::resource::get_cuda_stream(res)); - - graph_.sample_graph_new(thrust::raw_pointer_cast(graph_host_buffer_.data()), DEGREE_ON_DEVICE); - } - - graph_.update_graph(thrust::raw_pointer_cast(graph_host_buffer_.data()), - thrust::raw_pointer_cast(dists_host_buffer_.data()), - DEGREE_ON_DEVICE, - update_counter_); - raft::resource::sync_stream(res); - graph_.sort_lists(); - - // Reuse graph_.h_dists as the buffer for shrink the lists in graph - static_assert(sizeof(decltype(*(graph_.h_dists.data_handle()))) >= sizeof(Index_t)); - - if (return_distances) { - auto graph_d_dists = raft::make_device_matrix( - res, nrow_, build_config_.node_degree); - raft::copy(graph_d_dists.data_handle(), - graph_.h_dists.data_handle(), - nrow_ * build_config_.node_degree, - raft::resource::get_cuda_stream(res)); - - auto output_dist_view = raft::make_device_matrix_view( - output_distances, nrow_, build_config_.output_graph_degree); - - raft::matrix::slice_coordinates coords{static_cast(0), - static_cast(0), - static_cast(nrow_), - static_cast(build_config_.output_graph_degree)}; - raft::matrix::slice( - res, raft::make_const_mdspan(graph_d_dists.view()), output_dist_view, coords); - raft::resource::sync_stream(res); - } - - Index_t* graph_shrink_buffer = (Index_t*)graph_.h_dists.data_handle(); - -#pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - size_t idx = i * graph_.node_degree + j; - int id = graph_.h_graph[idx].id(); - if (id < static_cast(nrow_)) { - graph_shrink_buffer[i * build_config_.node_degree + j] = id; - } else { - graph_shrink_buffer[i * build_config_.node_degree + j] = - raft::neighbors::cagra::detail::device::xorshift64(idx) % nrow_; - } - } - } - graph_.h_graph = nullptr; - -#pragma omp parallel for - for (size_t i = 0; i < (size_t)nrow_; i++) { - for (size_t j = 0; j < build_config_.node_degree; j++) { - output_graph[i * build_config_.node_degree + j] = - graph_shrink_buffer[i * build_config_.node_degree + j]; - } - } -} - -template , - typename Accessor = - host_device_accessor, memory_type::host>> -void build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset, - index& idx, - epilogue_op distance_epilogue = DistEpilogue()) -{ - RAFT_EXPECTS(dataset.extent(0) < std::numeric_limits::max() - 1, - "The dataset size for GNND should be less than %d", - std::numeric_limits::max() - 1); - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - - if (intermediate_degree >= static_cast(dataset.extent(0))) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = dataset.extent(0) - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - // The elements in each knn-list are partitioned into different buckets, and we need more buckets - // to mitigate bucket collisions. `intermediate_degree` is OK to larger than - // extended_graph_degree. - size_t extended_graph_degree = - align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); - size_t extended_intermediate_degree = align32::roundUp( - static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); - - auto int_graph = raft::make_host_matrix( - dataset.extent(0), static_cast(extended_graph_degree)); - - BuildConfig build_config{.max_dataset_size = static_cast(dataset.extent(0)), - .dataset_dim = static_cast(dataset.extent(1)), - .node_degree = extended_graph_degree, - .internal_node_degree = extended_intermediate_degree, - .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold, - .output_graph_degree = params.graph_degree}; - - GNND nnd(res, build_config); - - if (idx.distances().has_value() || !params.return_distances) { - nnd.build(dataset.data_handle(), - dataset.extent(0), - int_graph.data_handle(), - params.return_distances, - idx.distances() - .value_or(raft::make_device_matrix(res, 0, 0).view()) - .data_handle(), - distance_epilogue); - } else { - RAFT_EXPECTS(!params.return_distances, - "Distance view not allocated. Using return_distances set to true requires " - "distance view to be allocated."); - } - -#pragma omp parallel for - for (size_t i = 0; i < static_cast(dataset.extent(0)); i++) { - for (size_t j = 0; j < graph_degree; j++) { - auto graph = idx.graph().data_handle(); - graph[i * graph_degree + j] = int_graph.data_handle()[i * extended_graph_degree + j]; - } - } -} - -template , - typename Accessor = - host_device_accessor, memory_type::host>> -index build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset, - epilogue_op distance_epilogue = DistEpilogue()) -{ - size_t intermediate_degree = params.intermediate_graph_degree; - size_t graph_degree = params.graph_degree; - - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - index idx{ - res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; - - build(res, params, dataset, idx, distance_epilogue); - - return idx; -} - -} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh b/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh deleted file mode 100644 index 78467c9741..0000000000 --- a/cpp/include/raft/neighbors/detail/nn_descent_batch.cuh +++ /dev/null @@ -1,701 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - -#include "../nn_descent_types.hpp" -#include "nn_descent.cuh" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include - -#include -#include -#include -#include -#include -#include - -namespace raft::neighbors::experimental::nn_descent::detail { - -// -// Run balanced kmeans on a subsample of the dataset to get centroids -// -template , memory_type::host>> -void get_balanced_kmeans_centroids( - raft::resources const& res, - raft::distance::DistanceType metric, - mdspan, row_major, Accessor> dataset, - raft::device_matrix_view centroids) -{ - size_t num_rows = static_cast(dataset.extent(0)); - size_t num_cols = static_cast(dataset.extent(1)); - size_t n_clusters = centroids.extent(0); - size_t num_subsamples = - std::min(static_cast(num_rows / n_clusters), static_cast(num_rows * 0.1)); - - auto d_subsample_dataset = - raft::make_device_matrix(res, num_subsamples, num_cols); - raft::matrix::sample_rows( - res, raft::random::RngState{0}, dataset, d_subsample_dataset.view()); - - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = metric; - - auto d_subsample_dataset_const_view = - raft::make_device_matrix_view( - d_subsample_dataset.data_handle(), num_subsamples, num_cols); - raft::cluster::kmeans_balanced::fit( - res, kmeans_params, d_subsample_dataset_const_view, centroids); -} - -// -// Get the top k closest centroid indices for each data point -// Loads the data in batches onto device if data is on host for memory efficiency -// -template -void get_global_nearest_k( - raft::resources const& res, - size_t k, - size_t num_rows, - size_t n_clusters, - const T* dataset, - raft::host_matrix_view global_nearest_cluster, - raft::device_matrix_view centroids, - raft::distance::DistanceType metric) -{ - size_t num_cols = centroids.extent(1); - - cudaPointerAttributes attr; - RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, dataset)); - float* ptr = reinterpret_cast(attr.devicePointer); - - if (ptr == nullptr) { // data on host - size_t num_batches = n_clusters; - size_t batch_size = (num_rows + n_clusters) / n_clusters; - - auto d_dataset_batch = - raft::make_device_matrix(res, batch_size, num_cols); - - auto nearest_clusters_idx = - raft::make_device_matrix(res, batch_size, k); - auto nearest_clusters_dist = - raft::make_device_matrix(res, batch_size, k); - - for (size_t i = 0; i < num_batches; i++) { - size_t batch_size_ = batch_size; - - if (i == num_batches - 1) { batch_size_ = num_rows - batch_size * i; } - raft::copy(d_dataset_batch.data_handle(), - dataset + i * batch_size * num_cols, - batch_size_ * num_cols, - resource::get_cuda_stream(res)); - - raft::neighbors::brute_force::fused_l2_knn( - res, - raft::make_const_mdspan(centroids), - raft::make_const_mdspan(d_dataset_batch.view()), - nearest_clusters_idx.view(), - nearest_clusters_dist.view(), - metric); - raft::copy(global_nearest_cluster.data_handle() + i * batch_size * k, - nearest_clusters_idx.data_handle(), - batch_size_ * k, - resource::get_cuda_stream(res)); - } - } else { // data on device - auto nearest_clusters_idx = - raft::make_device_matrix(res, num_rows, k); - auto nearest_clusters_dist = - raft::make_device_matrix(res, num_rows, k); - - raft::neighbors::brute_force::fused_l2_knn( - res, - raft::make_const_mdspan(centroids), - raft::make_device_matrix_view(dataset, num_rows, num_cols), - nearest_clusters_idx.view(), - nearest_clusters_dist.view(), - metric); - - raft::copy(global_nearest_cluster.data_handle(), - nearest_clusters_idx.data_handle(), - num_rows * k, - resource::get_cuda_stream(res)); - } -} - -// -// global_nearest_cluster [num_rows X k=2] : top 2 closest clusters for each data point -// inverted_indices [num_rows x k vector] : sparse vector for data indices for each cluster -// cluster_size [n_cluster] : cluster size for each cluster -// offset [n_cluster] : offset in inverted_indices for each cluster -// Loads the data in batches onto device if data is on host for memory efficiency -// -template -void get_inverted_indices(raft::resources const& res, - size_t n_clusters, - size_t& max_cluster_size, - size_t& min_cluster_size, - raft::host_matrix_view global_nearest_cluster, - raft::host_vector_view inverted_indices, - raft::host_vector_view cluster_size, - raft::host_vector_view offset) -{ - // build sparse inverted indices and get number of data points for each cluster - size_t num_rows = global_nearest_cluster.extent(0); - size_t k = global_nearest_cluster.extent(1); - - auto local_offset = raft::make_host_vector(n_clusters); - - max_cluster_size = 0; - min_cluster_size = std::numeric_limits::max(); - - thrust::fill( - thrust::host, cluster_size.data_handle(), cluster_size.data_handle() + n_clusters, 0); - thrust::fill( - thrust::host, local_offset.data_handle(), local_offset.data_handle() + n_clusters, 0); - - // TODO: this part isn't really a bottleneck but maybe worth trying omp parallel - // for with atomic add - for (size_t i = 0; i < num_rows; i++) { - for (size_t j = 0; j < k; j++) { - IdxT cluster_id = global_nearest_cluster(i, j); - cluster_size(cluster_id) += 1; - } - } - - offset(0) = 0; - for (size_t i = 1; i < n_clusters; i++) { - offset(i) = offset(i - 1) + cluster_size(i - 1); - } - for (size_t i = 0; i < num_rows; i++) { - for (size_t j = 0; j < k; j++) { - IdxT cluster_id = global_nearest_cluster(i, j); - inverted_indices(offset(cluster_id) + local_offset(cluster_id)) = i; - local_offset(cluster_id) += 1; - } - } - - max_cluster_size = static_cast( - *std::max_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); - min_cluster_size = static_cast( - *std::min_element(cluster_size.data_handle(), cluster_size.data_handle() + n_clusters)); -} - -template -struct KeyValuePair { - KeyType key; - ValueType value; -}; - -template -struct CustomKeyComparator { - __device__ bool operator()(const KeyValuePair& a, - const KeyValuePair& b) const - { - if (a.key == b.key) { return a.value < b.value; } - return a.key < b.key; - } -}; - -template -RAFT_KERNEL merge_subgraphs(IdxT* cluster_data_indices, - size_t graph_degree, - size_t num_cluster_in_batch, - float* global_distances, - float* batch_distances, - IdxT* global_indices, - IdxT* batch_indices) -{ - size_t batch_row = blockIdx.x; - typedef cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD> - BlockMergeSortType; - __shared__ typename cub::BlockMergeSort, BLOCK_SIZE, ITEMS_PER_THREAD>:: - TempStorage tmpSmem; - - extern __shared__ char sharedMem[]; - float* blockKeys = reinterpret_cast(sharedMem); - IdxT* blockValues = reinterpret_cast(&sharedMem[graph_degree * 2 * sizeof(float)]); - int16_t* uniqueMask = - reinterpret_cast(&sharedMem[graph_degree * 2 * (sizeof(float) + sizeof(IdxT))]); - - if (batch_row < num_cluster_in_batch) { - // load batch or global depending on threadIdx - size_t global_row = cluster_data_indices[batch_row]; - - KeyValuePair threadKeyValuePair[ITEMS_PER_THREAD]; - - size_t halfway = BLOCK_SIZE / 2; - size_t do_global = threadIdx.x < halfway; - - float* distances; - IdxT* indices; - - if (do_global) { - distances = global_distances; - indices = global_indices; - } else { - distances = batch_distances; - indices = batch_indices; - } - - size_t idxBase = (threadIdx.x * do_global + (threadIdx.x - halfway) * (1lu - do_global)) * - static_cast(ITEMS_PER_THREAD); - size_t arrIdxBase = (global_row * do_global + batch_row * (1lu - do_global)) * graph_degree; - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - size_t colId = idxBase + i; - if (colId < graph_degree) { - threadKeyValuePair[i].key = distances[arrIdxBase + colId]; - threadKeyValuePair[i].value = indices[arrIdxBase + colId]; - } else { - threadKeyValuePair[i].key = std::numeric_limits::max(); - threadKeyValuePair[i].value = std::numeric_limits::max(); - } - } - - __syncthreads(); - - BlockMergeSortType(tmpSmem).Sort(threadKeyValuePair, CustomKeyComparator{}); - - // load sorted result into shared memory to get unique values - idxBase = threadIdx.x * ITEMS_PER_THREAD; - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - size_t colId = idxBase + i; - if (colId < 2 * graph_degree) { - blockKeys[colId] = threadKeyValuePair[i].key; - blockValues[colId] = threadKeyValuePair[i].value; - } - } - - __syncthreads(); - - // get unique mask - if (threadIdx.x == 0) { uniqueMask[0] = 1; } - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - size_t colId = idxBase + i; - if (colId > 0 && colId < 2 * graph_degree) { - uniqueMask[colId] = static_cast(blockValues[colId] != blockValues[colId - 1]); - } - } - - __syncthreads(); - - // prefix sum - if (threadIdx.x == 0) { - for (int i = 1; i < 2 * graph_degree; i++) { - uniqueMask[i] += uniqueMask[i - 1]; - } - } - - __syncthreads(); - // load unique values to global memory - if (threadIdx.x == 0) { - global_distances[global_row * graph_degree] = blockKeys[0]; - global_indices[global_row * graph_degree] = blockValues[0]; - } - - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - size_t colId = idxBase + i; - if (colId > 0 && colId < 2 * graph_degree) { - bool is_unique = uniqueMask[colId] != uniqueMask[colId - 1]; - int16_t global_colId = uniqueMask[colId] - 1; - if (is_unique && static_cast(global_colId) < graph_degree) { - global_distances[global_row * graph_degree + global_colId] = blockKeys[colId]; - global_indices[global_row * graph_degree + global_colId] = blockValues[colId]; - } - } - } - } -} - -// -// builds knn graph using NN Descent and merge with global graph -// -template , - typename Accessor = - host_device_accessor, memory_type::host>> -void build_and_merge(raft::resources const& res, - const index_params& params, - size_t num_data_in_cluster, - size_t graph_degree, - size_t int_graph_node_degree, - T* cluster_data, - IdxT* cluster_data_indices, - int* int_graph, - IdxT* inverted_indices, - IdxT* global_indices_d, - float* global_distances_d, - IdxT* batch_indices_h, - IdxT* batch_indices_d, - float* batch_distances_d, - GNND& nnd, - epilogue_op distance_epilogue) -{ - nnd.build( - cluster_data, num_data_in_cluster, int_graph, true, batch_distances_d, distance_epilogue); - - // remap indices -#pragma omp parallel for - for (size_t i = 0; i < num_data_in_cluster; i++) { - for (size_t j = 0; j < graph_degree; j++) { - size_t local_idx = int_graph[i * int_graph_node_degree + j]; - batch_indices_h[i * graph_degree + j] = inverted_indices[local_idx]; - } - } - - raft::copy(batch_indices_d, - batch_indices_h, - num_data_in_cluster * graph_degree, - raft::resource::get_cuda_stream(res)); - - size_t num_elems = graph_degree * 2; - size_t sharedMemSize = num_elems * (sizeof(float) + sizeof(IdxT) + sizeof(int16_t)); - - if (num_elems <= 128) { - merge_subgraphs - <<>>( - cluster_data_indices, - graph_degree, - num_data_in_cluster, - global_distances_d, - batch_distances_d, - global_indices_d, - batch_indices_d); - } else if (num_elems <= 512) { - merge_subgraphs - <<>>( - cluster_data_indices, - graph_degree, - num_data_in_cluster, - global_distances_d, - batch_distances_d, - global_indices_d, - batch_indices_d); - } else if (num_elems <= 1024) { - merge_subgraphs - <<>>( - cluster_data_indices, - graph_degree, - num_data_in_cluster, - global_distances_d, - batch_distances_d, - global_indices_d, - batch_indices_d); - } else if (num_elems <= 2048) { - merge_subgraphs - <<>>( - cluster_data_indices, - graph_degree, - num_data_in_cluster, - global_distances_d, - batch_distances_d, - global_indices_d, - batch_indices_d); - } else { - // this is as far as we can get due to the shared mem usage of cub::BlockMergeSort - RAFT_FAIL("The degree of knn is too large (%lu). It must be smaller than 1024", graph_degree); - } - raft::resource::sync_stream(res); -} - -// -// For each cluster, gather the data samples that belong to that cluster, and -// call build_and_merge -// -template > -void cluster_nnd(raft::resources const& res, - const index_params& params, - size_t graph_degree, - size_t extended_graph_degree, - size_t max_cluster_size, - raft::host_matrix_view dataset, - IdxT* offsets, - IdxT* cluster_size, - IdxT* cluster_data_indices, - int* int_graph, - IdxT* inverted_indices, - IdxT* global_indices_h, - float* global_distances_h, - IdxT* batch_indices_h, - IdxT* batch_indices_d, - float* batch_distances_d, - const BuildConfig& build_config, - epilogue_op distance_epilogue) -{ - size_t num_rows = dataset.extent(0); - size_t num_cols = dataset.extent(1); - - GNND nnd(res, build_config); - - auto cluster_data_matrix = - raft::make_host_matrix(max_cluster_size, num_cols); - - for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { - RAFT_LOG_DEBUG( - "# Data on host. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); - size_t num_data_in_cluster = cluster_size[cluster_id]; - size_t offset = offsets[cluster_id]; - -#pragma omp parallel for - for (size_t i = 0; i < num_data_in_cluster; i++) { - for (size_t j = 0; j < num_cols; j++) { - size_t global_row = (inverted_indices + offset)[i]; - cluster_data_matrix(i, j) = dataset(global_row, j); - } - } - - distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); - - build_and_merge(res, - params, - num_data_in_cluster, - graph_degree, - extended_graph_degree, - cluster_data_matrix.data_handle(), - cluster_data_indices + offset, - int_graph, - inverted_indices + offset, - global_indices_h, - global_distances_h, - batch_indices_h, - batch_indices_d, - batch_distances_d, - nnd, - distance_epilogue); - nnd.reset(res); - } -} - -template > -void cluster_nnd(raft::resources const& res, - const index_params& params, - size_t graph_degree, - size_t extended_graph_degree, - size_t max_cluster_size, - raft::device_matrix_view dataset, - IdxT* offsets, - IdxT* cluster_size, - IdxT* cluster_data_indices, - int* int_graph, - IdxT* inverted_indices, - IdxT* global_indices_h, - float* global_distances_h, - IdxT* batch_indices_h, - IdxT* batch_indices_d, - float* batch_distances_d, - const BuildConfig& build_config, - epilogue_op distance_epilogue) -{ - size_t num_rows = dataset.extent(0); - size_t num_cols = dataset.extent(1); - - GNND nnd(res, build_config); - - auto cluster_data_matrix = - raft::make_device_matrix(res, max_cluster_size, num_cols); - - for (size_t cluster_id = 0; cluster_id < params.n_clusters; cluster_id++) { - RAFT_LOG_DEBUG( - "# Data on device. Running clusters: %lu / %lu", cluster_id + 1, params.n_clusters); - size_t num_data_in_cluster = cluster_size[cluster_id]; - size_t offset = offsets[cluster_id]; - - auto cluster_data_view = raft::make_device_matrix_view( - cluster_data_matrix.data_handle(), num_data_in_cluster, num_cols); - auto cluster_data_indices_view = raft::make_device_vector_view( - cluster_data_indices + offset, num_data_in_cluster); - distance_epilogue.preprocess_for_batch(cluster_data_indices + offset, num_data_in_cluster); - - auto dataset_IdxT = - raft::make_device_matrix_view(dataset.data_handle(), num_rows, num_cols); - raft::matrix::gather(res, dataset_IdxT, cluster_data_indices_view, cluster_data_view); - - build_and_merge(res, - params, - num_data_in_cluster, - graph_degree, - extended_graph_degree, - cluster_data_view.data_handle(), - cluster_data_indices + offset, - int_graph, - inverted_indices + offset, - global_indices_h, - global_distances_h, - batch_indices_h, - batch_indices_d, - batch_distances_d, - nnd, - distance_epilogue); - nnd.reset(res); - } -} - -template , - typename Accessor = - host_device_accessor, memory_type::host>> -index batch_build(raft::resources const& res, - const index_params& params, - mdspan, row_major, Accessor> dataset, - epilogue_op distance_epilogue = DistEpilogue()) -{ - size_t graph_degree = params.graph_degree; - size_t intermediate_degree = params.intermediate_graph_degree; - - size_t num_rows = static_cast(dataset.extent(0)); - size_t num_cols = static_cast(dataset.extent(1)); - - auto centroids = - raft::make_device_matrix(res, params.n_clusters, num_cols); - get_balanced_kmeans_centroids(res, params.metric, dataset, centroids.view()); - - size_t k = 2; - auto global_nearest_cluster = raft::make_host_matrix(num_rows, k); - get_global_nearest_k(res, - k, - num_rows, - params.n_clusters, - dataset.data_handle(), - global_nearest_cluster.view(), - centroids.view(), - params.metric); - - auto inverted_indices = raft::make_host_vector(num_rows * k); - auto cluster_size = raft::make_host_vector(params.n_clusters); - auto offset = raft::make_host_vector(params.n_clusters); - - size_t max_cluster_size, min_cluster_size; - get_inverted_indices(res, - params.n_clusters, - max_cluster_size, - min_cluster_size, - global_nearest_cluster.view(), - inverted_indices.view(), - cluster_size.view(), - offset.view()); - - if (intermediate_degree >= min_cluster_size) { - RAFT_LOG_WARN( - "Intermediate graph degree cannot be larger than minimum cluster size, reducing it to %lu", - dataset.extent(0)); - intermediate_degree = min_cluster_size - 1; - } - if (intermediate_degree < graph_degree) { - RAFT_LOG_WARN( - "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " - "graph_degree.", - graph_degree, - intermediate_degree); - graph_degree = intermediate_degree; - } - - size_t extended_graph_degree = - align32::roundUp(static_cast(graph_degree * (graph_degree <= 32 ? 1.0 : 1.3))); - size_t extended_intermediate_degree = align32::roundUp( - static_cast(intermediate_degree * (intermediate_degree <= 32 ? 1.0 : 1.3))); - - auto int_graph = raft::make_host_matrix( - max_cluster_size, static_cast(extended_graph_degree)); - - BuildConfig build_config{.max_dataset_size = max_cluster_size, - .dataset_dim = num_cols, - .node_degree = extended_graph_degree, - .internal_node_degree = extended_intermediate_degree, - .max_iterations = params.max_iterations, - .termination_threshold = params.termination_threshold, - .output_graph_degree = graph_degree}; - - auto global_indices_h = raft::make_managed_matrix(res, num_rows, graph_degree); - auto global_distances_h = raft::make_managed_matrix(res, num_rows, graph_degree); - - thrust::fill(thrust::host, - global_indices_h.data_handle(), - global_indices_h.data_handle() + num_rows * graph_degree, - std::numeric_limits::max()); - thrust::fill(thrust::host, - global_distances_h.data_handle(), - global_distances_h.data_handle() + num_rows * graph_degree, - std::numeric_limits::max()); - - auto batch_indices_h = - raft::make_host_matrix(max_cluster_size, graph_degree); - auto batch_indices_d = - raft::make_device_matrix(res, max_cluster_size, graph_degree); - auto batch_distances_d = - raft::make_device_matrix(res, max_cluster_size, graph_degree); - - auto cluster_data_indices = raft::make_device_vector(res, num_rows * k); - raft::copy(cluster_data_indices.data_handle(), - inverted_indices.data_handle(), - num_rows * k, - resource::get_cuda_stream(res)); - - cluster_nnd(res, - params, - graph_degree, - extended_graph_degree, - max_cluster_size, - dataset, - offset.data_handle(), - cluster_size.data_handle(), - cluster_data_indices.data_handle(), - int_graph.data_handle(), - inverted_indices.data_handle(), - global_indices_h.data_handle(), - global_distances_h.data_handle(), - batch_indices_h.data_handle(), - batch_indices_d.data_handle(), - batch_distances_d.data_handle(), - build_config, - distance_epilogue); - - index global_idx{ - res, dataset.extent(0), static_cast(graph_degree), params.return_distances}; - - raft::copy(global_idx.graph().data_handle(), - global_indices_h.data_handle(), - num_rows * graph_degree, - raft::resource::get_cuda_stream(res)); - if (params.return_distances && global_idx.distances().has_value()) { - raft::copy(global_idx.distances().value().data_handle(), - global_distances_h.data_handle(), - num_rows * graph_degree, - raft::resource::get_cuda_stream(res)); - } - return global_idx; -} - -} // namespace raft::neighbors::experimental::nn_descent::detail diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh deleted file mode 100644 index 170f973984..0000000000 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ /dev/null @@ -1,19 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "refine_device.cuh" -#include "refine_host.hpp" diff --git a/cpp/include/raft/neighbors/detail/refine_common.hpp b/cpp/include/raft/neighbors/detail/refine_common.hpp deleted file mode 100644 index bfd3341ee9..0000000000 --- a/cpp/include/raft/neighbors/detail/refine_common.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include - -namespace raft::neighbors::detail { - -/** Checks whether the input data extents are compatible. */ -template -void refine_check_input(ExtentsT dataset, - ExtentsT queries, - ExtentsT candidates, - ExtentsT indices, - ExtentsT distances, - distance::DistanceType metric) -{ - auto n_queries = queries.extent(0); - auto k = distances.extent(1); - - RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && - candidates.extent(0) == n_queries, - "Number of rows in output indices, distances and candidates matrices must be equal" - " with the number of rows in search matrix. Expected %d, got %d, %d, and %d", - static_cast(n_queries), - static_cast(indices.extent(0)), - static_cast(distances.extent(0)), - static_cast(candidates.extent(0))); - - RAFT_EXPECTS(indices.extent(1) == k, - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), - "Number of columns must be equal for dataset and queries"); - - RAFT_EXPECTS(candidates.extent(1) >= k, - "Number of neighbor candidates must not be smaller than k (%d vs %d)", - static_cast(candidates.extent(1)), - static_cast(k)); -} - -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh deleted file mode 100644 index bdc29ca121..0000000000 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ /dev/null @@ -1,144 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -namespace raft::neighbors::detail { - -/** - * See raft::neighbors::refine for docs. - */ -template -void refine_device(raft::resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - matrix_idx n_candidates = neighbor_candidates.extent(1); - matrix_idx n_queries = queries.extent(0); - matrix_idx dim = dataset.extent(1); - uint32_t k = static_cast(indices.extent(1)); - - // TODO: this restriction could be lifted with some effort - RAFT_EXPECTS(k <= raft::matrix::detail::select::warpsort::kMaxCapacity, - "k must be less than topk::kMaxCapacity (%d).", - raft::matrix::detail::select::warpsort::kMaxCapacity); - - common::nvtx::range fun_scope( - "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); - - refine_check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - // The refinement search can be mapped to an IVF flat search: - // - We consider that the candidate vectors form a cluster, separately for each query. - // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with - // n_candidates elements. - // - We consider that the coarse level search is already performed and assigned a single cluster - // to search for each query (the cluster formed from the corresponding candidates). - // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. - rmm::device_uvector fake_coarse_idx(n_queries, resource::get_cuda_stream(handle)); - - thrust::sequence(resource::get_thrust_policy(handle), - fake_coarse_idx.data(), - fake_coarse_idx.data() + n_queries); - - raft::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, true, dim); - - raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); - uint32_t grid_dim_x = 1; - - // the neighbor ids will be computed in uint32_t as offset - rmm::device_uvector neighbors_uint32_buf(0, resource::get_cuda_stream(handle)); - // Offsets per probe for each query [n_queries] as n_probes = 1 - rmm::device_uvector chunk_index(n_queries, resource::get_cuda_stream(handle)); - - // we know that each cluster has exactly n_candidates entries - thrust::fill(resource::get_thrust_policy(handle), - chunk_index.data(), - chunk_index.data() + n_queries, - uint32_t(n_candidates)); - - uint32_t* neighbors_uint32 = nullptr; - if constexpr (sizeof(idx_t) == sizeof(uint32_t)) { - neighbors_uint32 = reinterpret_cast(indices.data_handle()); - } else { - neighbors_uint32_buf.resize(std::size_t(n_queries) * std::size_t(k), - resource::get_cuda_stream(handle)); - neighbors_uint32 = neighbors_uint32_buf.data(); - } - - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< - data_t, - typename raft::spatial::knn::detail::utils::config::value_t, - idx_t>(refinement_index, - queries.data_handle(), - fake_coarse_idx.data(), - static_cast(n_queries), - 0, - refinement_index.metric(), - 1, - k, - 0, - chunk_index.data(), - raft::distance::is_min_close(metric), - raft::neighbors::filtering::none_ivf_sample_filter(), - neighbors_uint32, - distances.data_handle(), - grid_dim_x, - resource::get_cuda_stream(handle)); - - // postprocessing -- neighbors from position to actual id - ivf::detail::postprocess_neighbors(indices.data_handle(), - neighbors_uint32, - refinement_index.inds_ptrs().data_handle(), - fake_coarse_idx.data(), - chunk_index.data(), - n_queries, - 1, - k, - resource::get_cuda_stream(handle)); -} - -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_host-ext.hpp b/cpp/include/raft/neighbors/detail/refine_host-ext.hpp deleted file mode 100644 index f5c8c73bb9..0000000000 --- a/cpp/include/raft/neighbors/detail/refine_host-ext.hpp +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // _RAFT_HAS_CUDA -#include // raft::host_matrix_view -#include // raft::distance::DistanceType -#include // RAFT_EXPLICIT - -#include // int64_t - -#if defined(_RAFT_HAS_CUDA) -#include -#endif - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::detail { - -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) RAFT_EXPLICIT; - -} - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ - extern template void raft::neighbors::detail::refine_host( \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric); - -instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); -instantiate_raft_neighbors_refine(uint32_t, float, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); -instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); - -#if defined(_RAFT_HAS_CUDA) -instantiate_raft_neighbors_refine(int64_t, half, float, int64_t); -#endif - -#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp deleted file mode 100644 index 9aff451dfc..0000000000 --- a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include - -#include - -#include - -namespace raft::neighbors::detail { - -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host_impl( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances) -{ - size_t n_queries = queries.extent(0); - size_t n_rows = dataset.extent(0); - size_t dim = dataset.extent(1); - size_t orig_k = neighbor_candidates.extent(1); - size_t refined_k = indices.extent(1); - - common::nvtx::range fun_scope( - "neighbors::refine_host(%zu, %zu -> %zu)", n_queries, orig_k, refined_k); - - auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); - - // If the number of queries is small, separate the distance calculation and - // the top-k calculation into separate loops, and apply finer-grained thread - // parallelism to the distance calculation loop. - if (n_queries < size_t(suggested_n_threads)) { - std::vector>> refined_pairs( - n_queries, std::vector>(orig_k)); - - // For efficiency, each thread should read a certain amount of array - // elements. The number of threads for distance computation is determined - // taking this into account. - auto n_elements = std::max(size_t(512), dim); - auto max_n_threads = raft::div_rounding_up_safe(n_queries * orig_k * dim, n_elements); - auto suggested_n_threads_for_distance = std::min(size_t(suggested_n_threads), max_n_threads); - - // The max number of threads for topk computation is the number of queries. - auto suggested_n_threads_for_topk = std::min(size_t(suggested_n_threads), n_queries); - - // Compute the refined distance using original dataset vectors -#pragma omp parallel for collapse(2) num_threads(suggested_n_threads_for_distance) - for (size_t i = 0; i < n_queries; i++) { - for (size_t j = 0; j < orig_k; j++) { - const DataT* query = queries.data_handle() + dim * i; - IdxT id = neighbor_candidates(i, j); - DistanceT distance = 0.0; - if (static_cast(id) >= n_rows) { - distance = std::numeric_limits::max(); - } else { - const DataT* row = dataset.data_handle() + dim * id; - for (size_t k = 0; k < dim; k++) { - distance += DC::template eval(query[k], row[k]); - } - } - refined_pairs[i][j] = std::make_tuple(distance, id); - } - } - - // Sort the query neighbors by their refined distances -#pragma omp parallel for num_threads(suggested_n_threads_for_topk) - for (size_t i = 0; i < n_queries; i++) { - std::sort(refined_pairs[i].begin(), refined_pairs[i].end()); - // Store first refined_k neighbors - for (size_t j = 0; j < refined_k; j++) { - indices(i, j) = std::get<1>(refined_pairs[i][j]); - if (distances.data_handle() != nullptr) { - distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[i][j])); - } - } - } - return; - } - - if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } - -#pragma omp parallel num_threads(suggested_n_threads) - { - std::vector> refined_pairs(orig_k); - for (size_t i = omp_get_thread_num(); i < n_queries; i += omp_get_num_threads()) { - // Compute the refined distance using original dataset vectors - const DataT* query = queries.data_handle() + dim * i; - for (size_t j = 0; j < orig_k; j++) { - IdxT id = neighbor_candidates(i, j); - DistanceT distance = 0.0; - if (static_cast(id) >= n_rows) { - distance = std::numeric_limits::max(); - } else { - const DataT* row = dataset.data_handle() + dim * id; - for (size_t k = 0; k < dim; k++) { - distance += DC::template eval(query[k], row[k]); - } - } - refined_pairs[j] = std::make_tuple(distance, id); - } - // Sort the query neighbors by their refined distances - std::sort(refined_pairs.begin(), refined_pairs.end()); - // Store first refined_k neighbors - for (size_t j = 0; j < refined_k; j++) { - indices(i, j) = std::get<1>(refined_pairs[j]); - if (distances.data_handle() != nullptr) { - distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[j])); - } - } - } - } -} - -struct distance_comp_l2 { - template - static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT - { - auto d = a - b; - return d * d; - } - template - static inline auto postprocess(const DistanceT& a) -> DistanceT - { - return a; - } -}; - -struct distance_comp_inner { - template - static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT - { - return -a * b; - } - template - static inline auto postprocess(const DistanceT& a) -> DistanceT - { - return -a; - } -}; - -/** - * Naive CPU implementation of refine operation - * - * All pointers are expected to be accessible on the host. - */ -template -[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - refine_check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - return refine_host_impl( - dataset, queries, neighbor_candidates, indices, distances); - case raft::distance::DistanceType::InnerProduct: - return refine_host_impl( - dataset, queries, neighbor_candidates, indices, distances); - default: throw raft::logic_error("Unsupported metric"); - } -} - -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_host.hpp b/cpp/include/raft/neighbors/detail/refine_host.hpp deleted file mode 100644 index ff0de75660..0000000000 --- a/cpp/include/raft/neighbors/detail/refine_host.hpp +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "refine_host-inl.hpp" -#endif - -#ifdef RAFT_COMPILED -#include "refine_host-ext.hpp" -#endif diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh deleted file mode 100644 index f6cd2a1ceb..0000000000 --- a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh +++ /dev/null @@ -1,427 +0,0 @@ -/* - * Copyright (c) 2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "../dataset.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include // pq_bits-bitfield -#include // utils::mapping etc -#include -#include - -// A temporary stub till https://github.com/rapidsai/raft/pull/2077 is re-merged -namespace raft::util { - -/** - * Subsample the dataset to create a training set. - * - * @tparam DatasetT a row-major mdspan or mdarray (device or host) - * - * @param res raft handle - * @param dataset input row-major mdspan or mdarray (device or host) - * @param n_samples the size of the output mdarray - * - * @return a newly allocated subset of the dataset. - */ -template -auto subsample(raft::resources const& res, - const DatasetT& dataset, - typename DatasetT::index_type n_samples) - -> raft::device_matrix -{ - using value_type = typename DatasetT::value_type; - using index_type = typename DatasetT::index_type; - static_assert(std::is_same_v, - "Only row-major layout is supported at the moment"); - RAFT_EXPECTS(n_samples <= dataset.extent(0), - "The number of samples must be smaller than the number of input rows in the current " - "implementation."); - size_t dim = dataset.extent(1); - size_t trainset_ratio = dataset.extent(0) / n_samples; - auto result = raft::make_device_matrix(res, n_samples, dataset.extent(1)); - - RAFT_CUDA_TRY(cudaMemcpy2DAsync(result.data_handle(), - sizeof(value_type) * dim, - dataset.data_handle(), - sizeof(value_type) * dim * trainset_ratio, - sizeof(value_type) * dim, - n_samples, - cudaMemcpyDefault, - raft::resource::get_cuda_stream(res))); - return result; -} - -} // namespace raft::util - -namespace raft::neighbors::detail { - -template -auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& dataset) -> vpq_params -{ - vpq_params r = params; - double n_rows = dataset.extent(0); - size_t dim = dataset.extent(1); - if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } - if (r.pq_bits == 0) { r.pq_bits = 8; } - if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } - if (r.vq_kmeans_trainset_fraction == 0) { - double vq_trainset_size = 100.0 * r.vq_n_centers; - r.vq_kmeans_trainset_fraction = std::min(1.0, vq_trainset_size / n_rows); - } - if (r.pq_kmeans_trainset_fraction == 0) { - // NB: we'll have actually `pq_dim` times more samples than this - // (because the dataset is reinterpreted as `[n_rows * pq_dim, pq_len]`) - double pq_trainset_size = 1000.0 * (1u << r.pq_bits); - r.pq_kmeans_trainset_fraction = std::min(1.0, pq_trainset_size / n_rows); - } - return r; -} - -template -auto transform_data(const raft::resources& res, DatasetT dataset) - -> device_mdarray -{ - using index_type = typename DatasetT::index_type; - using extents_type = typename DatasetT::extents_type; - using layout_type = typename DatasetT::layout_type; - using out_mdarray_type = device_mdarray; - if constexpr (std::is_same_v>) { return dataset; } - - auto result = raft::make_device_mdarray(res, dataset.extents()); - - linalg::map(res, - result.view(), - spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(dataset.view())); - - return result; -} - -/** Fix the internal indexing type to avoid integer underflows/overflows */ -using ix_t = int64_t; - -template -auto train_vq(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> device_matrix -{ - const ix_t n_rows = dataset.extent(0); - const ix_t vq_n_centers = params.vq_n_centers; - const ix_t dim = dataset.extent(1); - const ix_t n_rows_train = n_rows * params.vq_kmeans_trainset_fraction; - - // Subsample the dataset and transform into the required type if necessary - auto vq_trainset = raft::util::subsample(res, dataset, n_rows_train); - auto vq_centers = raft::make_device_matrix(res, vq_n_centers, dim); - - using kmeans_in_type = typename DatasetT::value_type; - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = raft::distance::DistanceType::L2Expanded; - auto vq_centers_view = - raft::make_device_matrix_view(vq_centers.data_handle(), vq_n_centers, dim); - auto vq_trainset_view = raft::make_device_matrix_view( - vq_trainset.data_handle(), n_rows_train, dim); - raft::cluster::kmeans_balanced::fit( - res, - kmeans_params, - vq_trainset_view, - vq_centers_view, - spatial::knn::detail::utils::mapping{}); - - return vq_centers; -} - -template -auto predict_vq(const raft::resources& res, const DatasetT& dataset, const VqCentersT& vq_centers) - -> device_vector -{ - using kmeans_data_type = typename DatasetT::value_type; - using kmeans_math_type = typename VqCentersT::value_type; - using index_type = typename DatasetT::index_type; - using label_type = LabelT; - - auto vq_labels = raft::make_device_vector(res, dataset.extent(0)); - - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.metric = raft::distance::DistanceType::L2Expanded; - - auto vq_centers_view = raft::make_device_matrix_view( - vq_centers.data_handle(), vq_centers.extent(0), vq_centers.extent(1)); - - auto vq_dataset_view = raft::make_device_matrix_view( - dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - - raft::cluster::kmeans_balanced:: - predict( - res, - kmeans_params, - vq_dataset_view, - vq_centers_view, - vq_labels.view(), - spatial::knn::detail::utils::mapping{}); - - return vq_labels; -} - -template -auto train_pq(const raft::resources& res, - const vpq_params& params, - const DatasetT& dataset, - const device_matrix_view& vq_centers) - -> device_matrix -{ - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = params.pq_dim; - const ix_t pq_bits = params.pq_bits; - const ix_t pq_n_centers = ix_t{1} << pq_bits; - const ix_t pq_len = raft::div_rounding_up_safe(dim, pq_dim); - const ix_t n_rows_train = n_rows * params.pq_kmeans_trainset_fraction; - - // Subsample the dataset and transform into the required type if necessary - auto pq_trainset = transform_data(res, raft::util::subsample(res, dataset, n_rows_train)); - - // Subtract VQ centers - { - auto vq_labels = predict_vq(res, pq_trainset, vq_centers); - using index_type = typename DatasetT::index_type; - linalg::map_offset( - res, - pq_trainset.view(), - [labels = vq_labels.view(), centers = vq_centers, dim] __device__(index_type off, MathT x) { - index_type i = off / dim; - index_type j = off % dim; - return x - centers(labels(i), j); - }, - raft::make_const_mdspan(pq_trainset.view())); - } - - auto pq_centers = raft::make_device_matrix(res, pq_n_centers, pq_len); - - // Train PQ centers - { - raft::cluster::kmeans_balanced_params kmeans_params; - kmeans_params.n_iters = params.kmeans_n_iters; - kmeans_params.metric = raft::distance::DistanceType::L2Expanded; - - auto pq_centers_view = - raft::make_device_matrix_view(pq_centers.data_handle(), pq_n_centers, pq_len); - - auto pq_trainset_view = raft::make_device_matrix_view( - pq_trainset.data_handle(), n_rows_train * pq_dim, pq_len); - - raft::cluster::kmeans_balanced::fit( - res, kmeans_params, pq_trainset_view, pq_centers_view); - } - - return pq_centers; -} - -template -__device__ auto compute_code(device_matrix_view dataset, - device_matrix_view vq_centers, - device_matrix_view pq_centers, - IdxT i, - uint32_t j, - LabelT vq_label) -> uint8_t -{ - auto data_mapping = spatial::knn::detail::utils::mapping{}; - uint32_t lane_id = Pow2::mod(laneId()); - - const uint32_t pq_book_size = pq_centers.extent(0); - const uint32_t pq_len = pq_centers.extent(1); - float min_dist = std::numeric_limits::infinity(); - uint8_t code = 0; - // calculate the distance for each PQ cluster, find the minimum for each thread - for (uint32_t l = lane_id; l < pq_book_size; l += SubWarpSize) { - // NB: the L2 quantifiers on residuals are always trained on L2 metric. - float d = 0.0f; - for (uint32_t k = 0; k < pq_len; k++) { - auto jk = j * pq_len + k; - auto x = data_mapping(dataset(i, jk)) - vq_centers(vq_label, jk); - auto t = x - pq_centers(l, k); - d += t * t; - } - if (d < min_dist) { - min_dist = d; - code = uint8_t(l); - } - } - // reduce among threads -#pragma unroll - for (uint32_t stride = SubWarpSize >> 1; stride > 0; stride >>= 1) { - const auto other_dist = shfl_xor(min_dist, stride, SubWarpSize); - const auto other_code = shfl_xor(code, stride, SubWarpSize); - if (other_dist < min_dist) { - min_dist = other_dist; - code = other_code; - } - } - return code; -} - -template -__launch_bounds__(BlockSize) RAFT_KERNEL - process_and_fill_codes_kernel(device_matrix_view out_codes, - device_matrix_view dataset, - device_matrix_view vq_centers, - device_vector_view vq_labels, - device_matrix_view pq_centers) -{ - constexpr uint32_t kSubWarpSize = std::min(WarpSize, 1u << PqBits); - using subwarp_align = Pow2; - const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x}); - if (row_ix >= out_codes.extent(0)) { return; } - - const uint32_t pq_dim = raft::div_rounding_up_unsafe(vq_centers.extent(1), pq_centers.extent(1)); - - const uint32_t lane_id = Pow2::mod(threadIdx.x); - const LabelT vq_label = vq_labels(row_ix); - - // write label - auto* out_label_ptr = reinterpret_cast(&out_codes(row_ix, 0)); - if (lane_id == 0) { *out_label_ptr = vq_label; } - - auto* out_codes_ptr = reinterpret_cast(out_label_ptr + 1); - ivf_pq::detail::bitfield_view_t code_view{out_codes_ptr}; - for (uint32_t j = 0; j < pq_dim; j++) { - // find PQ label - uint8_t code = compute_code(dataset, vq_centers, pq_centers, row_ix, j, vq_label); - // TODO: this writes in global memory one byte per warp, which is very slow. - // It's better to keep the codes in the shared memory or registers and dump them at once. - if (lane_id == 0) { code_view[j] = code; } - } -} - -template -auto process_and_fill_codes(const raft::resources& res, - const vpq_params& params, - const DatasetT& dataset, - device_matrix_view vq_centers, - device_matrix_view pq_centers) - -> device_matrix -{ - using data_t = typename DatasetT::value_type; - using cdataset_t = vpq_dataset; - using label_t = uint32_t; - - const ix_t n_rows = dataset.extent(0); - const ix_t dim = dataset.extent(1); - const ix_t pq_dim = params.pq_dim; - const ix_t pq_bits = params.pq_bits; - const ix_t pq_n_centers = ix_t{1} << pq_bits; - // NB: codes must be aligned at least to sizeof(label_t) to be able to read labels. - const ix_t codes_rowlen = - sizeof(label_t) * (1 + raft::div_rounding_up_safe(pq_dim * pq_bits, 8 * sizeof(label_t))); - - auto codes = raft::make_device_matrix(res, n_rows, codes_rowlen); - - auto stream = raft::resource::get_cuda_stream(res); - - // TODO: with scaling workspace we could choose the batch size dynamically - constexpr ix_t kReasonableMaxBatchSize = 65536; - constexpr ix_t kBlockSize = 256; - const ix_t threads_per_vec = std::min(WarpSize, pq_n_centers); - dim3 threads(kBlockSize, 1, 1); - ix_t max_batch_size = std::min(n_rows, kReasonableMaxBatchSize); - auto kernel = [](uint32_t pq_bits) { - switch (pq_bits) { - case 4: return process_and_fill_codes_kernel; - case 5: return process_and_fill_codes_kernel; - case 6: return process_and_fill_codes_kernel; - case 7: return process_and_fill_codes_kernel; - case 8: return process_and_fill_codes_kernel; - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - }(pq_bits); - for (const auto& batch : - spatial::knn::detail::utils::batch_load_iterator(dataset.data_handle(), - n_rows, - dim, - max_batch_size, - stream, - rmm::mr::get_current_device_resource())) { - auto batch_view = raft::make_device_matrix_view(batch.data(), ix_t(batch.size()), dim); - auto labels = predict_vq(res, batch_view, vq_centers); - dim3 blocks(div_rounding_up_safe(n_rows, kBlockSize / threads_per_vec), 1, 1); - kernel<<>>( - make_device_matrix_view( - codes.data_handle() + batch.offset() * codes_rowlen, batch.size(), codes_rowlen), - batch_view, - vq_centers, - make_const_mdspan(labels.view()), - pq_centers); - RAFT_CUDA_TRY(cudaPeekAtLastError()); - } - - return codes; -} - -template -auto vpq_convert_math_type(const raft::resources& res, vpq_dataset&& src) - -> vpq_dataset -{ - auto vq_code_book = make_device_mdarray(res, src.vq_code_book.extents()); - auto pq_code_book = make_device_mdarray(res, src.pq_code_book.extents()); - - linalg::map(res, - vq_code_book.view(), - spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.vq_code_book.view())); - linalg::map(res, - pq_code_book.view(), - spatial::knn::detail::utils::mapping{}, - raft::make_const_mdspan(src.pq_code_book.view())); - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(src.data)}; -} - -template -auto vpq_build(const raft::resources& res, const vpq_params& params, const DatasetT& dataset) - -> vpq_dataset -{ - // Use a heuristic to impute missing parameters. - auto ps = fill_missing_params_heuristics(params, dataset); - - // Train codes - auto vq_code_book = train_vq(res, ps, dataset); - auto pq_code_book = - train_pq(res, ps, dataset, raft::make_const_mdspan(vq_code_book.view())); - - // Encode dataset - auto codes = process_and_fill_codes(res, - ps, - dataset, - raft::make_const_mdspan(vq_code_book.view()), - raft::make_const_mdspan(pq_code_book.view())); - - return vpq_dataset{ - std::move(vq_code_book), std::move(pq_code_book), std::move(codes)}; -} - -} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/hnsw.hpp b/cpp/include/raft/neighbors/hnsw.hpp deleted file mode 100644 index ee3f61e550..0000000000 --- a/cpp/include/raft/neighbors/hnsw.hpp +++ /dev/null @@ -1,142 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/hnsw.hpp" -#include "hnsw.hpp" - -#include -#include -#include - -#include -#include - -namespace raft::neighbors::hnsw { - -/** - * @addtogroup hnsw Build CAGRA index and search with hnswlib - * @{ - */ - -/** - * @brief Construct an hnswlib base-layer-only index from a CAGRA index - * NOTE: 1. This method uses the filesystem to write the CAGRA index in `/tmp/.bin` - * before reading it as an hnswlib index, then deleting the temporary file. - * 2. This function is only offered as a compiled symbol in `libraft.so` - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] cagra_index cagra index - * - * Usage example: - * @code{.cpp} - * // Build a CAGRA index - * using namespace raft::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * - * // Load CAGRA index as base-layer-only hnswlib index - * auto hnsw_index = hnsw::from_cagra(res, index); - * @endcode - */ -template -std::unique_ptr> from_cagra(raft::resources const& res, - raft::neighbors::cagra::index cagra_index); - -template <> -std::unique_ptr> from_cagra( - raft::resources const& res, raft::neighbors::cagra::index cagra_index); - -template <> -std::unique_ptr> from_cagra( - raft::resources const& res, raft::neighbors::cagra::index cagra_index); - -template <> -std::unique_ptr> from_cagra( - raft::resources const& res, raft::neighbors::cagra::index cagra_index); - -/** - * @brief Search hnswlib base-layer-only index constructed from a CAGRA index - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] res raft resources - * @param[in] params configure the search - * @param[in] idx cagra index - * @param[in] queries a host matrix view to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a host matrix view to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a host matrix view to the distances to the selected neighbors [n_queries, - * k] - * - * Usage example: - * @code{.cpp} - * // Build a CAGRA index - * using namespace raft::neighbors; - * // use default index parameters - * cagra::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = cagra::build(res, index_params, dataset); - * - * // Save CAGRA index as base layer only hnswlib index - * hnsw::serialize(res, "my_index.bin", index); - * - * // Load CAGRA index as base layer only hnswlib index - * raft::neighbors::hnsw::index* hnsw_index; - * auto hnsw_index = hnsw::deserialize(res, "my_index.bin", D, raft::distance::L2Expanded); - * - * // Search K nearest neighbors as an hnswlib index - * // using host threads for concurrency - * hnsw::search_params search_params; - * search_params.ef = 50 // ef >= K; - * search_params.num_threads = 10; - * auto neighbors = raft::make_host_matrix(res, n_queries, k); - * auto distances = raft::make_host_matrix(res, n_queries, k); - * hnsw::search(res, search_params, *index, queries, neighbors, distances); - * // de-allocate hnsw_index - * delete hnsw_index; - * @endcode - */ -template -void search(raft::resources const& res, - const search_params& params, - const index& idx, - raft::host_matrix_view queries, - raft::host_matrix_view neighbors, - raft::host_matrix_view distances) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - detail::search(res, params, idx, queries, neighbors, distances); -} - -/**@}*/ - -} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/hnsw_serialize.hpp b/cpp/include/raft/neighbors/hnsw_serialize.hpp deleted file mode 100644 index f7450bb3d6..0000000000 --- a/cpp/include/raft/neighbors/hnsw_serialize.hpp +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/hnsw_serialize.hpp" -#include "hnsw_types.hpp" - -#include -#include - -namespace raft::neighbors::hnsw { - -/** - * @defgroup hnsw_serialize HNSW Serialize - * @{ - */ - -/** - * Load an hnswlib index which was serialized from a CAGRA index - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an an unallocated pointer - * int dim = 10; - * raft::distance::DistanceType = raft::distance::L2Expanded - * auto index = raft::deserialize(handle, filename, dim, metric); - * @endcode - * - * @tparam T data element type - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] dim dimensionality of the index - * @param[in] metric metric used to build the index - * - * @return std::unique_ptr> - * - */ -template -std::unique_ptr> deserialize(raft::resources const& handle, - const std::string& filename, - int dim, - raft::distance::DistanceType metric) -{ - return detail::deserialize(handle, filename, dim, metric); -} - -/**@}*/ - -} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/hnsw_types.hpp b/cpp/include/raft/neighbors/hnsw_types.hpp deleted file mode 100644 index f78571f491..0000000000 --- a/cpp/include/raft/neighbors/hnsw_types.hpp +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" - -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::hnsw { - -/** - * @defgroup hnsw Build CAGRA index and search with hnswlib - * @{ - */ - -struct search_params : ann::search_params { - int ef; // size of the candidate list - int num_threads = 0; // number of host threads to use for concurrent searches. Value of 0 - // automatically maximizes parallelism -}; -template -struct index : ann::index { - public: - /** - * @brief load a base-layer-only hnswlib index originally saved from a built CAGRA index. - * This is a virtual class and it cannot be used directly. To create an index, use the factory - * function `raft::neighbors::hnsw::from_cagra` from the header - * `raft/neighbors/hnsw.hpp` - * - * @param[in] dim dimensions of the training dataset - * @param[in] metric distance metric to search. Supported metrics ("L2Expanded", "InnerProduct") - */ - [[deprecated("Use cuVS instead")]] index(int dim, raft::distance::DistanceType metric) - : dim_{dim}, metric_{metric} - { - } - - /** - @brief Get underlying index - */ - virtual auto get_index() const -> void const* = 0; - - auto dim() const -> int const { return dim_; } - - auto metric() const -> raft::distance::DistanceType { return metric_; } - - /** - @brief Set ef for search - */ - virtual void set_ef(int ef) const; - - private: - int dim_; - raft::distance::DistanceType metric_; -}; - -/**@}*/ - -} // namespace raft::neighbors::hnsw diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh deleted file mode 100644 index 12ab0dc3a6..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat-ext.cuh +++ /dev/null @@ -1,257 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include // raft::device_matrix_view -#include // raft::resources -#include -#include // raft::neighbors::ivf_flat::index -#include // RAFT_EXPLICIT - -#include - -#include // int64_t - -#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY - -namespace raft::neighbors::ivf_flat { - -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index RAFT_EXPLICIT; - -template -auto build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) - -> index RAFT_EXPLICIT; - -template -void build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; - -template -auto build(raft::resources const& handle, - const index_params& params, - raft::host_matrix_view dataset) - -> index RAFT_EXPLICIT; - -template -void build(raft::resources const& handle, - const index_params& params, - raft::host_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; - -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index RAFT_EXPLICIT; - -template -auto extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index RAFT_EXPLICIT; - -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) RAFT_EXPLICIT; - -template -void extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) RAFT_EXPLICIT; - -template -auto extend(raft::resources const& handle, - raft::host_matrix_view new_vectors, - std::optional> new_indices, - const raft::neighbors::ivf_flat::index& orig_index) - -> raft::neighbors::ivf_flat::index RAFT_EXPLICIT; - -template -void extend(raft::resources const& handle, - raft::host_matrix_view new_vectors, - std::optional> new_indices, - index* index) RAFT_EXPLICIT; - -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr) RAFT_EXPLICIT; - -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) RAFT_EXPLICIT; - -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) RAFT_EXPLICIT; - -} // namespace raft::neighbors::ivf_flat - -#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY - -#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ - extern template auto raft::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index_params& params, \ - const T* dataset, \ - IdxT n_rows, \ - uint32_t dim) \ - ->raft::neighbors::ivf_flat::index; \ - \ - extern template auto raft::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset) \ - ->raft::neighbors::ivf_flat::index; \ - \ - extern template void raft::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index_params& params, \ - raft::device_matrix_view dataset, \ - raft::neighbors::ivf_flat::index& idx); \ - \ - extern template auto raft::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset) \ - ->raft::neighbors::ivf_flat::index; \ - \ - extern template void raft::neighbors::ivf_flat::build( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index_params& params, \ - raft::host_matrix_view dataset, \ - raft::neighbors::ivf_flat::index& idx); - -instantiate_raft_neighbors_ivf_flat_build(float, int64_t); -instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); -#undef instantiate_raft_neighbors_ivf_flat_build - -#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ - extern template auto raft::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::index& orig_index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows) \ - ->raft::neighbors::ivf_flat::index; \ - \ - extern template auto raft::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - extern template void raft::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::neighbors::ivf_flat::index* index, \ - const T* new_vectors, \ - const IdxT* new_indices, \ - IdxT n_rows); \ - \ - extern template void raft::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* index); \ - \ - extern template void raft::neighbors::ivf_flat::extend( \ - raft::resources const& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* index); \ - \ - extern template auto raft::neighbors::ivf_flat::extend( \ - const raft::resources& handle, \ - raft::host_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& idx) \ - ->raft::neighbors::ivf_flat::index; - -instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); -instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); - -#undef instantiate_raft_neighbors_ivf_flat_extend - -#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ - extern template void raft::neighbors::ivf_flat::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - uint32_t n_queries, \ - uint32_t k, \ - IdxT* neighbors, \ - float* distances, \ - rmm::device_async_resource_ref mr); \ - \ - extern template void raft::neighbors::ivf_flat::search( \ - raft::resources const& handle, \ - const raft::neighbors::ivf_flat::search_params& params, \ - const raft::neighbors::ivf_flat::index& index, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -instantiate_raft_neighbors_ivf_flat_search(float, int64_t); -instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); -instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); - -#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh deleted file mode 100644 index ea7cff7060..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ /dev/null @@ -1,664 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include - -#include -#include - -namespace raft::neighbors::ivf_flat { - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, index_params, dataset, N, D); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, N, K, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a host or device pointer to a row-major matrix [n_rows, dim] - * @param[in] n_rows the number of samples - * @param[in] dim the dimensionality of the data - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::resources const& handle, - const index_params& params, - const T* dataset, - IdxT n_rows, - uint32_t dim) -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, params, dataset, n_rows, dim); -} - -/** - * @defgroup ivf_flat IVF Flat Algorithm - * @{ - */ - -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * auto index = ivf_flat::build(handle, dataset, index_params); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset a device matrix [n_rows, dim] - * - * @return the constructed ivf-flat index - */ -template -auto build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset) -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** - * @brief Build the index from a dataset in host memory. - */ -template -auto build(raft::resources const& handle, - const index_params& params, - raft::host_matrix_view dataset) -> index -{ - return raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} -/** - * @brief Build the index from the dataset for efficient search. - * - * NB: Currently, the following distance metrics are supported: - * - L2Expanded - * - L2Unexpanded - * - InnerProduct - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // create and fill the index from a [N, D] dataset - * ivf_flat::index index; - * ivf_flat::build(handle, dataset, index_params, index); - * // use default search parameters - * ivf_flat::search_params search_params; - * // search K nearest neighbours for each of the N queries - * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] params configure the index building - * @param[in] dataset raft::device_matrix_view to a row-major matrix [n_rows, dim] - * @param[out] idx reference to ivf_flat::index - * - */ -template -void build(raft::resources const& handle, - const index_params& params, - raft::device_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) -{ - idx = raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} - -/** - * @brief Build the index from a dataset in host memory. - */ -template -void build(raft::resources const& handle, - const index_params& params, - raft::host_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) -{ - idx = raft::neighbors::ivf_flat::detail::build(handle, - params, - dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); -} -/** @} */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * auto index = ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] orig_index original index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows number of rows in `new_vectors` - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::resources const& handle, - const index& orig_index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -> index -{ - return raft::neighbors::ivf_flat::detail::extend( - handle, orig_index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Build a new index containing the data of the original plus new extra vectors. - * - * Implementation note: - * The new data is clustered according to existing kmeans clusters, then the cluster - * centers are adjusted to match the newly labeled data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] orig_index original index - * - * @return the constructed extended ivf-flat index - */ -template -auto extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index -{ - return extend(handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); -} - -/** - * @brief Extend the index with additional vectors. - * - * This overloads takes input data in host memory. - */ -template -auto extend(raft::resources const& handle, - raft::host_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index -{ - return extend(handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); -} -/** @} */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset, N, D); - * // fill the index with the data - * ivf_flat::extend(handle, index_empty, dataset, nullptr, N); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param handle - * @param[inout] index - * @param[in] new_vectors a device/host pointer to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices a device/host pointer to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `nullptr` - * here to imply a continuous range `[0...n_rows)`. - * @param[in] n_rows the number of samples - */ -template -void extend(raft::resources const& handle, - index* index, - const T* new_vectors, - const IdxT* new_indices, - IdxT n_rows) -{ - raft::neighbors::ivf_flat::detail::extend(handle, index, new_vectors, new_indices, n_rows); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Extend the index in-place with the new data. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * ivf_flat::index_params index_params; - * index_params.add_data_on_build = false; // don't populate index on build - * index_params.kmeans_trainset_fraction = 1.0; // use whole dataset for kmeans training - * // train the index from a [N, D] dataset - * auto index_empty = ivf_flat::build(handle, index_params, dataset); - * // fill the index with the data - * std::optional> no_op = std::nullopt; - * ivf_flat::extend(handle, dataset, no_opt, &index_empty); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - * @param[in] handle - * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] - * @param[in] new_indices optional raft::device_vector_view to a vector of indices [n_rows]. - * If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt` - * here to imply a continuous range `[0...n_rows)`. - * @param[inout] index pointer to index, to be overwritten in-place - */ -template -void extend(raft::resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) -{ - extend(handle, - index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); -} - -/** - * @brief Extend the index with additional vectors. - * - * This overloads takes input data in host memory. - */ -template -void extend(raft::resources const& handle, - raft::host_matrix_view new_vectors, - std::optional> new_indices, - index* index) -{ - extend(handle, - index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); -} -/** @} */ - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * filtering::none_ivf_sample_filter filter; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr, filter); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or - * `(uint32_t query_ix, uint32 sample_ix) -> bool` - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - raft::neighbors::ivf_flat::detail::search( - handle, params, index, queries, n_queries, k, neighbors, distances, mr, sample_filter); -} - -/** - * @brief Search ANN using the constructed index using the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // Create a pooling memory resource with a pre-defined initial size. - * rmm::mr::pool_memory_resource mr( - * rmm::mr::get_current_device_resource(), 1024 * 1024); - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, N1, K, out_inds1, out_dists1, &mr); - * ivf_flat::search(handle, search_params, index, queries2, N2, K, out_inds2, out_dists2, &mr); - * ivf_flat::search(handle, search_params, index, queries3, N3, K, out_inds3, out_dists3, &mr); - * ... - * @endcode - * The exact size of the temporary buffer depends on multiple factors and is an implementation - * detail. However, you can safely specify a small initial size for the memory pool, so that only a - * few allocations happen to grow it during the first invocations of the `search`. - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[in] n_queries the batch size - * @param[in] k the number of neighbors to find for each query. - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] mr an optional memory resource to use across the searches (you can provide a large - * enough memory pool here to avoid memory allocations within search). - */ -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr) -{ - raft::neighbors::ivf_flat::detail::search(handle, - params, - index, - queries, - n_queries, - k, - neighbors, - distances, - mr, - raft::neighbors::filtering::none_ivf_sample_filter()); -} - -/** - * @ingroup ivf_flat - * @{ - */ - -/** - * @brief Search ANN using the constructed index with the given filter. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * filtering::none_ivf_sample_filter filter; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries1, out_inds1, out_dists1, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries2, out_inds2, out_dists2, filter); - * ivf_flat::search_with_filtering( - * handle, search_params, index, queries3, out_inds3, out_dists3, filter); - * ... - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * @tparam IvfSampleFilterT Device filter function, with the signature - * `(uint32_t query_ix, uint32 cluster_ix, uint32_t sample_ix) -> bool` or - * `(uint32_t query_ix, uint32 sample_ix) -> bool` - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - * @param[in] sample_filter a device filter function that greenlights samples for a given query - */ -template -void search_with_filtering(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances, - IvfSampleFilterT sample_filter = IvfSampleFilterT()) -{ - RAFT_EXPECTS( - queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), - "Number of rows in output neighbors and distances matrices must equal the number of queries."); - - RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), - "Number of columns in output neighbors and distances matrices must be equal"); - - RAFT_EXPECTS(queries.extent(1) == index.dim(), - "Number of query dimensions should equal number of dimensions in the index."); - - search_with_filtering(handle, - params, - index, - queries.data_handle(), - static_cast(queries.extent(0)), - static_cast(neighbors.extent(1)), - neighbors.data_handle(), - distances.data_handle(), - resource::get_workspace_resource(handle), - sample_filter); -} - -/** - * @brief Search ANN using the constructed index. - * - * See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example. - * - * Note, this function requires a temporary buffer to store intermediate results between cuda kernel - * calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can - * pass a pool memory resource or a large enough pre-allocated memory resource to reduce or - * eliminate entirely allocations happening within `search`: - * @code{.cpp} - * ... - * // use default search parameters - * ivf_flat::search_params search_params; - * // Use the same allocator across multiple searches to reduce the number of - * // cuda memory allocations - * ivf_flat::search(handle, search_params, index, queries1, out_inds1, out_dists1); - * ivf_flat::search(handle, search_params, index, queries2, out_inds2, out_dists2); - * ivf_flat::search(handle, search_params, index, queries3, out_inds3, out_dists3); - * ... - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle - * @param[in] params configure the search - * @param[in] index ivf-flat constructed index - * @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()] - * @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset - * [n_queries, k] - * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] - */ -template -void search(raft::resources const& handle, - const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) -{ - search_with_filtering(handle, - params, - index, - queries, - neighbors, - distances, - raft::neighbors::filtering::none_ivf_sample_filter()); -} - -/** @} */ - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat.cuh deleted file mode 100644 index 8fd9628a41..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat.cuh +++ /dev/null @@ -1,24 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY -#include "ivf_flat-inl.cuh" -#endif - -#ifdef RAFT_COMPILED -#include "ivf_flat-ext.cuh" -#endif diff --git a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp b/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp deleted file mode 100644 index 5379788ab4..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat_codepacker.hpp +++ /dev/null @@ -1,90 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_flat::codepacker { - -/** - * Write one flat code into a block by the given offset. The offset indicates the id of the record - * in the list. This function interleaves the code and is intended to later copy the interleaved - * codes over to the IVF list on device. NB: no memory allocation happens here; the block must fit - * the record (offset + 1). - * - * @tparam T - * - * @param[in] flat_code input flat code - * @param[out] block block of memory to write interleaved codes to - * @param[in] dim dimension of the flat code - * @param[in] veclen size of interleaved data chunks - * @param[in] offset how many records to skip before writing the data into the list - */ -template -_RAFT_HOST_DEVICE void pack_1( - const T* flat_code, T* block, uint32_t dim, uint32_t veclen, uint32_t offset) -{ - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = neighbors::detail::div_utils; - - // Interleave dimensions of the source vector while recording it. - // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = interleaved_group::roundDown(offset); - auto ingroup_id = interleaved_group::mod(offset) * veclen; - - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - block[group_offset * dim + l * kIndexGroupSize + ingroup_id + j] = flat_code[l + j]; - } - } -} - -/** - * Unpack 1 record of a single list (cluster) in the index to fetch the flat code. The offset - * indicates the id of the record. This function fetches one flat code from an interleaved code. - * - * @tparam T - * - * @param[in] block interleaved block. The block can be thought of as the whole inverted list in - * interleaved format. - * @param[out] flat_code output flat code - * @param[in] dim dimension of the flat code - * @param[in] veclen size of interleaved data chunks - * @param[in] offset fetch the flat code by the given offset - */ -template -_RAFT_HOST_DEVICE void unpack_1( - const T* block, T* flat_code, uint32_t dim, uint32_t veclen, uint32_t offset) -{ - // The data is written in interleaved groups of `index::kGroupSize` vectors - using interleaved_group = neighbors::detail::div_utils; - - // NB: such `veclen` is selected, that `dim % veclen == 0` - auto group_offset = interleaved_group::roundDown(offset); - auto ingroup_id = interleaved_group::mod(offset) * veclen; - - for (uint32_t l = 0; l < dim; l += veclen) { - for (uint32_t j = 0; j < veclen; j++) { - flat_code[l + j] = block[group_offset * dim + l * kIndexGroupSize + ingroup_id + j]; - } - } -} -} // namespace raft::neighbors::ivf_flat::codepacker \ No newline at end of file diff --git a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh b/cpp/include/raft/neighbors/ivf_flat_helpers.cuh deleted file mode 100644 index 883e72c839..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat_helpers.cuh +++ /dev/null @@ -1,176 +0,0 @@ -/* - * Copyright (c) 2023-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_flat::helpers { -using namespace raft::spatial::knn::detail; // NOLINT -/** - * @defgroup ivf_flat_helpers Helper functions for manipulationg IVF Flat Index - * @{ - */ - -namespace codepacker { - -/** - * Write flat codes into an existing list by the given offset. - * - * NB: no memory allocation happens here; the list must fit the data (offset + n_vec). - * - * Usage example: - * @code{.cpp} - * auto list_data = index.lists()[label]->data.view(); - * // allocate the buffer for the input codes - * auto codes = raft::make_device_matrix(res, n_vec, index.dim()); - * ... prepare n_vecs to pack into the list in codes ... - * // write codes into the list starting from the 42nd position - * ivf_pq::helpers::codepacker::pack( - * res, make_const_mdspan(codes.view()), index.veclen(), 42, list_data); - * @endcode - * - * @tparam T - * @tparam IdxT - * - * @param[in] res - * @param[in] codes flat codes [n_vec, dim] - * @param[in] veclen size of interleaved data chunks - * @param[in] offset how many records to skip before writing the data into the list - * @param[inout] list_data block to write into - */ -template -void pack( - raft::resources const& res, - device_matrix_view codes, - uint32_t veclen, - uint32_t offset, - device_mdspan::list_extents, row_major> list_data) -{ - raft::neighbors::ivf_flat::detail::pack_list_data(res, codes, veclen, offset, list_data); -} - -/** - * @brief Unpack `n_take` consecutive records of a single list (cluster) in the compressed index - * starting at given `offset`. - * - * Usage example: - * @code{.cpp} - * auto list_data = index.lists()[label]->data.view(); - * // allocate the buffer for the output - * uint32_t n_take = 4; - * auto codes = raft::make_device_matrix(res, n_take, index.dim()); - * uint32_t offset = 0; - * // unpack n_take elements from the list - * ivf_pq::helpers::codepacker::unpack(res, list_data, index.veclen(), offset, codes.view()); - * @endcode - * - * @tparam T - * @tparam IdxT - * - * @param[in] res raft resource - * @param[in] list_data block to read from - * @param[in] veclen size of interleaved data chunks - * @param[in] offset - * How many records in the list to skip. - * @param[inout] codes - * the destination buffer [n_take, index.dim()]. - * The length `n_take` defines how many records to unpack, - * it must be <= the list size. - */ -template -void unpack( - raft::resources const& res, - device_mdspan::list_extents, row_major> list_data, - uint32_t veclen, - uint32_t offset, - device_matrix_view codes) -{ - raft::neighbors::ivf_flat::detail::unpack_list_data( - res, list_data, veclen, offset, codes); -} -} // namespace codepacker - -/** - * @brief Public helper API to reset the data and indices ptrs, and the list sizes. Useful for - * externally modifying the index without going through the build stage. The data and indices of the - * IVF lists will be lost. - * - * Usage example: - * @code{.cpp} - * raft::resources res; - * using namespace raft::neighbors; - * // use default index parameters - * ivf_flat::index_params index_params; - * // initialize an empty index - * ivf_flat::index index(res, index_params, D); - * // reset the index's state and list sizes - * ivf_flat::helpers::reset_index(res, &index); - * @endcode - * - * @tparam IdxT - * - * @param[in] res raft resource - * @param[inout] index pointer to IVF-PQ index - */ -template -void reset_index(const raft::resources& res, index* index) -{ - auto stream = resource::get_cuda_stream(res); - - utils::memzero( - index->accum_sorted_sizes().data_handle(), index->accum_sorted_sizes().size(), stream); - utils::memzero(index->list_sizes().data_handle(), index->list_sizes().size(), stream); - utils::memzero(index->data_ptrs().data_handle(), index->data_ptrs().size(), stream); - utils::memzero(index->inds_ptrs().data_handle(), index->inds_ptrs().size(), stream); -} - -/** - * @brief Helper exposing the re-computation of list sizes and related arrays if IVF lists have been - * modified. - * - * Usage example: - * @code{.cpp} - * using namespace raft::neighbors; - * raft::resources res; - * // use default index parameters - * ivf_flat::index_params index_params; - * // initialize an empty index - * ivf_flat::index index(res, index_params, D); - * ivf_flat::helpers::reset_index(res, &index); - * // recompute the internal state of the index - * ivf_flat::helpers::recompute_internal_state(res, &index); - * @endcode - * - * @tparam T - * @tparam IdxT - * - * @param[in] res raft resource - * @param[inout] index pointer to IVF-FLAT index - */ -template -void recompute_internal_state(const raft::resources& res, index* index) -{ - auto& list = index->lists()[0]; - ivf::detail::recompute_internal_state(res, *index); -} - -/** @} */ -} // namespace raft::neighbors::ivf_flat::helpers diff --git a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/ivf_flat_serialize.cuh deleted file mode 100644 index 311c31040e..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat_serialize.cuh +++ /dev/null @@ -1,154 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "detail/ivf_flat_serialize.cuh" - -namespace raft::neighbors::ivf_flat { - -/** - * \defgroup ivf_flat_serialize IVF-Flat Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = ivf_flat::build(...);` - * raft::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, std::ostream& os, const index& index) -{ - detail::serialize(handle, os, index); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = ivf_flat::build(...);` - * raft::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index IVF-Flat index - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index) -{ - detail::serialize(handle, filename, index); -} - -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return raft::neighbors::ivf_flat::index - */ -template -index deserialize(raft::resources const& handle, std::istream& is) -{ - return detail::deserialize(handle, is); -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return raft::neighbors::ivf_flat::index - */ -template -index deserialize(raft::resources const& handle, const std::string& filename) -{ - return detail::deserialize(handle, filename); -} - -/**@}*/ - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp deleted file mode 100644 index 2cafceb512..0000000000 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ /dev/null @@ -1,397 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include "ann_types.hpp" - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include // std::max -#include -#include -#include - -namespace raft::neighbors::ivf_flat { -/** - * @addtogroup ivf_flat - * @{ - */ - -/** Size of the interleaved group (see `index::data` description). */ -constexpr static uint32_t kIndexGroupSize = 32; - -struct index_params : ann::index_params { - /** The number of inverted lists (clusters) */ - uint32_t n_lists = 1024; - /** The number of iterations searching for kmeans centers (index building). */ - uint32_t kmeans_n_iters = 20; - /** The fraction of data to use during iterative kmeans building. */ - double kmeans_trainset_fraction = 0.5; - /** - * By default (adaptive_centers = false), the cluster centers are trained in `ivf_flat::build`, - * and never modified in `ivf_flat::extend`. As a result, you may need to retrain the index - * from scratch after invoking (`ivf_flat::extend`) a few times with new data, the distribution of - * which is no longer representative of the original training set. - * - * The alternative behavior (adaptive_centers = true) is to update the cluster centers for new - * data when it is added. In this case, `index.centers()` are always exactly the centroids of the - * data in the corresponding clusters. The drawback of this behavior is that the centroids depend - * on the order of adding new data (through the classification of the added data); that is, - * `index.centers()` "drift" together with the changing distribution of the newly added data. - */ - bool adaptive_centers = false; - /** - * By default, the algorithm allocates more space than necessary for individual clusters - * (`list_data`). This allows to amortize the cost of memory allocation and reduce the number of - * data copies during repeated calls to `extend` (extending the database). - * - * The alternative is the conservative allocation behavior; when enabled, the algorithm always - * allocates the minimum amount of memory required to store the given number of records. Set this - * flag to `true` if you prefer to use as little GPU memory for the database as possible. - */ - bool conservative_memory_allocation = false; -}; - -struct search_params : ann::search_params { - /** The number of clusters to search. */ - uint32_t n_probes = 20; -}; - -static_assert(std::is_aggregate_v); -static_assert(std::is_aggregate_v); - -template -struct list_spec { - using value_type = ValueT; - using list_extents = matrix_extent; - using index_type = IdxT; - - SizeT align_max; - SizeT align_min; - uint32_t dim; - - constexpr list_spec(uint32_t dim, bool conservative_memory_allocation) - : dim(dim), - align_min(kIndexGroupSize), - align_max(conservative_memory_allocation ? kIndexGroupSize : 1024) - { - } - - // Allow casting between different size-types (for safer size and offset calculations) - template - constexpr explicit list_spec(const list_spec& other_spec) - : dim{other_spec.dim}, align_min{other_spec.align_min}, align_max{other_spec.align_max} - { - } - - /** Determine the extents of an array enough to hold a given amount of data. */ - constexpr auto make_list_extents(SizeT n_rows) const -> list_extents - { - return make_extents(n_rows, dim); - } -}; - -template -using list_data = ivf::list; - -/** - * @brief IVF-flat index. - * - * @tparam T data element type - * @tparam IdxT type of the indices in the source dataset - * - */ -template -struct index : ann::index { - static_assert(!raft::is_narrowing_v, - "IdxT must be able to represent all values of uint32_t"); - - public: - /** - * Vectorized load/store size in elements, determines the size of interleaved data chunks. - * - * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum - * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711 - */ - [[nodiscard]] constexpr inline auto veclen() const noexcept -> uint32_t { return veclen_; } - /** Distance metric used for clustering. */ - [[nodiscard]] constexpr inline auto metric() const noexcept -> raft::distance::DistanceType - { - return metric_; - } - /** Whether `centers()` change upon extending the index (ivf_pq::extend). */ - [[nodiscard]] constexpr inline auto adaptive_centers() const noexcept -> bool - { - return adaptive_centers_; - } - /** - * Inverted list data [size, dim]. - * - * The data consists of the dataset rows, grouped by their labels (into clusters/lists). - * Within each list (cluster), the data is grouped into blocks of `kIndexGroupSize` interleaved - * vectors. Note, the total index length is slightly larger than the source dataset length, - * because each cluster is padded by `kIndexGroupSize` elements. - * - * Interleaving pattern: - * within groups of `kIndexGroupSize` rows, the data is interleaved with the block size equal to - * `veclen * sizeof(T)`. That is, a chunk of `veclen` consecutive components of one row is - * followed by a chunk of the same size of the next row, and so on. - * - * __Example__: veclen = 2, dim = 6, kIndexGroupSize = 32, list_size = 31 - * - * x[ 0, 0], x[ 0, 1], x[ 1, 0], x[ 1, 1], ... x[14, 0], x[14, 1], x[15, 0], x[15, 1], - * x[16, 0], x[16, 1], x[17, 0], x[17, 1], ... x[30, 0], x[30, 1], - , - , - * x[ 0, 2], x[ 0, 3], x[ 1, 2], x[ 1, 3], ... x[14, 2], x[14, 3], x[15, 2], x[15, 3], - * x[16, 2], x[16, 3], x[17, 2], x[17, 3], ... x[30, 2], x[30, 3], - , - , - * x[ 0, 4], x[ 0, 5], x[ 1, 4], x[ 1, 5], ... x[14, 4], x[14, 5], x[15, 4], x[15, 5], - * x[16, 4], x[16, 5], x[17, 4], x[17, 5], ... x[30, 4], x[30, 5], - , - , - * - */ - /** Sizes of the lists (clusters) [n_lists] - * NB: This may differ from the actual list size if the shared lists have been extended by another - * index - */ - inline auto list_sizes() noexcept -> device_vector_view - { - return list_sizes_.view(); - } - [[nodiscard]] inline auto list_sizes() const noexcept - -> device_vector_view - { - return list_sizes_.view(); - } - - /** k-means cluster centers corresponding to the lists [n_lists, dim] */ - inline auto centers() noexcept -> device_matrix_view - { - return centers_.view(); - } - [[nodiscard]] inline auto centers() const noexcept - -> device_matrix_view - { - return centers_.view(); - } - - /** - * (Optional) Precomputed norms of the `centers` w.r.t. the chosen distance metric [n_lists]. - * - * NB: this may be empty if the index is empty or if the metric does not require the center norms - * calculation. - */ - inline auto center_norms() noexcept -> std::optional> - { - if (center_norms_.has_value()) { - return std::make_optional>(center_norms_->view()); - } else { - return std::nullopt; - } - } - [[nodiscard]] inline auto center_norms() const noexcept - -> std::optional> - { - if (center_norms_.has_value()) { - return std::make_optional>(center_norms_->view()); - } else { - return std::nullopt; - } - } - - /** - * Accumulated list sizes, sorted in descending order [n_lists + 1]. - * The last value contains the total length of the index. - * The value at index zero is always zero. - * - * That is, the content of this span is as if the `list_sizes` was sorted and then accumulated. - * - * This span is used during search to estimate the maximum size of the workspace. - */ - inline auto accum_sorted_sizes() noexcept -> host_vector_view - { - return accum_sorted_sizes_.view(); - } - [[nodiscard]] inline auto accum_sorted_sizes() const noexcept - -> host_vector_view - { - return accum_sorted_sizes_.view(); - } - - /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT - { - return accum_sorted_sizes()(n_lists()); - } - - /** Dimensionality of the data. */ - [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t - { - return centers_.extent(1); - } - /** Number of clusters/inverted lists. */ - [[nodiscard]] constexpr inline auto n_lists() const noexcept -> uint32_t { return lists_.size(); } - - // Don't allow copying the index for performance reasons (try avoiding copying data) - index(const index&) = delete; - index(index&&) = default; - auto operator=(const index&) -> index& = delete; - auto operator=(index&&) -> index& = default; - ~index() = default; - - /** Construct an empty index. It needs to be trained and then populated. */ - [[deprecated("Use cuVS instead")]] index(raft::resources const& res, - raft::distance::DistanceType metric, - uint32_t n_lists, - bool adaptive_centers, - bool conservative_memory_allocation, - uint32_t dim) - : ann::index(), - veclen_(calculate_veclen(dim)), - metric_(metric), - adaptive_centers_(adaptive_centers), - conservative_memory_allocation_{conservative_memory_allocation}, - lists_{n_lists}, - list_sizes_{make_device_vector(res, n_lists)}, - centers_(make_device_matrix(res, n_lists, dim)), - center_norms_(std::nullopt), - data_ptrs_{make_device_vector(res, n_lists)}, - inds_ptrs_{make_device_vector(res, n_lists)}, - accum_sorted_sizes_{make_host_vector(n_lists + 1)} - { - check_consistency(); - accum_sorted_sizes_(n_lists) = 0; - } - - /** Construct an empty index. It needs to be trained and then populated. */ - [[deprecated("Use cuVS instead")]] index(raft::resources const& res, - const index_params& params, - uint32_t dim) - : index(res, - params.metric, - params.n_lists, - params.adaptive_centers, - params.conservative_memory_allocation, - dim) - { - } - - /** Pointers to the inverted lists (clusters) data [n_lists]. */ - inline auto data_ptrs() noexcept -> device_vector_view { return data_ptrs_.view(); } - [[nodiscard]] inline auto data_ptrs() const noexcept -> device_vector_view - { - return data_ptrs_.view(); - } - - /** Pointers to the inverted lists (clusters) indices [n_lists]. */ - inline auto inds_ptrs() noexcept -> device_vector_view - { - return inds_ptrs_.view(); - } - [[nodiscard]] inline auto inds_ptrs() const noexcept -> device_vector_view - { - return inds_ptrs_.view(); - } - /** - * Whether to use convervative memory allocation when extending the list (cluster) data - * (see index_params.conservative_memory_allocation). - */ - [[nodiscard]] constexpr inline auto conservative_memory_allocation() const noexcept -> bool - { - return conservative_memory_allocation_; - } - - void allocate_center_norms(raft::resources const& res) - { - switch (metric_) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2Unexpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - center_norms_ = make_device_vector(res, n_lists()); - break; - default: center_norms_ = std::nullopt; - } - } - - /** Lists' data and indices. */ - inline auto lists() noexcept -> std::vector>>& - { - return lists_; - } - [[nodiscard]] inline auto lists() const noexcept - -> const std::vector>>& - { - return lists_; - } - - /** Throw an error if the index content is inconsistent. */ - void check_consistency() - { - auto n_lists = lists_.size(); - RAFT_EXPECTS(dim() % veclen_ == 0, "dimensionality is not a multiple of the veclen"); - RAFT_EXPECTS(list_sizes_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(data_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS(inds_ptrs_.extent(0) == n_lists, "inconsistent list size"); - RAFT_EXPECTS( // - (centers_.extent(0) == list_sizes_.extent(0)) && // - (!center_norms_.has_value() || centers_.extent(0) == center_norms_->extent(0)), - "inconsistent number of lists (clusters)"); - } - - private: - /** - * TODO: in theory, we can lift this to the template parameter and keep it at hardware maximum - * possible value by padding the `dim` of the data https://github.com/rapidsai/raft/issues/711 - */ - uint32_t veclen_; - raft::distance::DistanceType metric_; - bool adaptive_centers_; - bool conservative_memory_allocation_; - std::vector>> lists_; - device_vector list_sizes_; - device_matrix centers_; - std::optional> center_norms_; - - // Computed members - device_vector data_ptrs_; - device_vector inds_ptrs_; - host_vector accum_sorted_sizes_; - - static auto calculate_veclen(uint32_t dim) -> uint32_t - { - // TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a - // template parameter (https://github.com/rapidsai/raft/issues/711) - - // NOTE: keep this consistent with the select_interleaved_scan_kernel logic - // in detail/ivf_flat_interleaved_scan-inl.cuh. - uint32_t veclen = std::max(1, 16 / sizeof(T)); - if (dim % veclen != 0) { veclen = 1; } - return veclen; - } -}; - -/** @} */ - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/include/raft/neighbors/ivf_list.hpp b/cpp/include/raft/neighbors/ivf_list.hpp deleted file mode 100644 index 08879ed059..0000000000 --- a/cpp/include/raft/neighbors/ivf_list.hpp +++ /dev/null @@ -1,193 +0,0 @@ -/* - * Copyright (c) 2022-2024, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include - -#include -#include -#include - -namespace raft::neighbors::ivf { - -/** The data for a single IVF list. */ -template